DPO(選好チューニング)でLLMを訓練する方法
この記事は、
DPOという手法を使って、
LLM(大規模言語モデル)を訓練する方法についてご紹介します。
DPOって何?
DPOはDirect Preference Optimizationの略で、
LLMが出力する回答を、
人間の好みに近づける手法のことです。
別名、選好チューニングと呼ばれています。
なぜDPOが必要なのか?
LLMは、膨大なデータセットで事前学習することで
一般的な知識やパターンを習得しますが、
あくまで、文法的に正しい文章を出力してくれるにとどまるだけで、人間の好みに合った応答を生成するわけではありません。
例えば、LLMのプロンプトに
「明日の会議について詳細を教えてください。」と入力した場合
LLMは、
「明日の会議は14時から始まります。議題は以下の通りです。」
と答える可能性があります。
文法的には正しいのですが、形式的すぎて、相手との親しみやすいトーンが欠けていて、応答が冷たく感じられないでしょうか?
また、LLMのプロンプトに
「最近、仕事で失敗してすごく落ち込んでいます。」と入力した場合、
LLMは、
「それは残念です。次回は成功することを願っています。」
と答える可能性があります。
これも文法的には正しいのですが、
ユーザーの感情に寄り添ったり(例: 「それは辛いですね。どんなことがあったのですか?」)してほしいところです。
これだと、冷淡な印象を与えたり、応答が共感的でないため、ユーザーの信頼感を損ねる可能性があります。
他にも、文法的には正しいが、
不適切、非倫理的、あるいは無関係な出力、有害な発言をするリスクもあります。
こうした問題を解決するのがDPOです。
DPOによって、出力が人間の好みにあうように事後訓練するわけです。
従来、LLMを人間の選好に合わせる手法としては、
強化学習(Reinforcement Learning)を利用するの一般的でした。
その具体的な例として、
「強化学習を利用した人間のフィードバックによる微調整(RLHF: Reinforcement Learning with Human Feedback)」があります。
RLHFは効果的な手法ですが、
モデルの設計や調整が複雑だったり、
挙動が不安定になったり、膨大な計算資源が必要だったりと、
いくつかの課題があったのです。
そこで登場したのがDPOです。
強化学習とは違って、シンプルな分類問題として
人間の好みに沿ったモデル調整を実現します。
このアプローチは、
複雑なモデルの設計が不要で、
安定した学習も可能であり、
かつ計算コストを大幅に削減できる点が特徴です。
DPOの仕組み
DPOは、
簡単に言うと、人間がどのような回答を好むかを学び、
それに基づいて良い回答を選ぶ方法です。
このプロセスは次のように進みます:
❶データ準備:
まず、ユーザーの質問(prompt)に対して、
好ましい回答(chosen)と、
好ましくない回答(rejected)をペアで集めます。
例えば、「好きな映画は何ですか?」という質問に対して、
「私はアクション映画が好きです」という回答が好ましい回答で、
「映画?興味ないです」という回答が好ましくない回答です。
❷選好学習:
次に、好ましい回答が選ばれやすく、
好ましくない回答が選ばれにくいように、
モデルを訓練します。
これには「クロスエントロピー損失」という方法を使って、
好ましい回答が出やすくなるように調整します。
❸パフォーマンスの評価:
モデルの訓練が終わったら、
実際にどれくらい好ましい回答が選ばれるかをテストします。
これは、
新しい質問に対してモデルがどのような回答をするかを見て、
それがどれだけユーザーの期待に合っているかで評価されます。
このようにして、DPOは常にユーザーのニーズに合わせて進化し、より良い対話を目指します。
DPOの実装
こちらの記事を参考にしています。
1. 必要なライブラリーのインストール
%%capture
!pip install unsloth
# Also get the latest nightly Unsloth!
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git
!pip install -U accelerate # これを入れないとtrainerでエラーになる
!pip install --upgrade torch
!pip install --upgrade xformers
2. 必要なライブラリーのインポート
from unsloth import PatchDPOTrainer
PatchDPOTrainer()
from unsloth import FastLanguageModel
import torch
import json
from tqdm import tqdm
import re
from transformers import TrainingArguments
from trl import DPOTrainer, DPOConfig
from unsloth import is_bfloat16_supported
3. モデルのダウンロード
max_seq_length = 2048
dtype = None
load_in_4bit = True
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "llm-jp/llm-jp-3-3.7b-instruct",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)
4. LoRAの設定
model = FastLanguageModel.get_peft_model(
model,
r = 64,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 64,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
5.データセットの準備
正直、データの質はあまり良いとは言えません。なぜ、これをllm-jpが公開しているのかが疑問なぐらいです。
6.データセットのダウンロードと整形
from datasets import load_dataset
dataset = load_dataset("llm-jp/hh-rlhf-12k-ja")
dataset
EOS_TOKEN = tokenizer.eos_token
# datesetの整形
# conversations から最初の human 発話を instruction として抽出
def extract_instruction(example):
conversations = example['conversations']
for message in conversations:
if message['from'] == 'human':
return message['value']
return None
# 各行を整形して新しい形式を作成
def format_for_dpo(example):
instruction = extract_instruction(example) +EOS_TOKEN # 指示文を抽出
chosen = example['chosen'] +EOS_TOKEN # chosen カラム
rejected = example['rejected'] +EOS_TOKEN # rejected カラム
return {
'instruction': instruction,
'chosen': chosen,
'rejected': rejected
}
# データセットを変換
formatted_dataset = dataset.map(format_for_dpo)
# 必要ないカラムを削除
formatted_dataset = formatted_dataset.remove_columns(['conversations', 'source',])
dpo_datasets = formatted_dataset['train'].rename_column("instruction", "prompt")
dpo_datasets = dpo_datasets.train_test_split(train_size=0.30)["train"]
dpo_datasets
7.DPO学習用のパラメータの設定
#メモリオーバーになるのでこちらを使う
dpo_trainer = DPOTrainer(
model = model,
ref_model = None,
args = DPOConfig(
per_device_train_batch_size = 1, # バッチサイズを削減
gradient_accumulation_steps = 8, # 勾配蓄積を増加
warmup_ratio = 0.05, # ウォームアップ期間を短縮
num_train_epochs = 1,
learning_rate = 5e-5,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
optim = "adamw_8bit", # 8bit AdamW を維持
weight_decay = 0.0,
lr_scheduler_type = "linear",
seed = 42,
output_dir = "outputs",
report_to = "none",
),
beta = 0.1,
train_dataset = dpo_datasets,
tokenizer = tokenizer,
max_length = 512, # シーケンス長を短縮
max_prompt_length = 256, # プロンプト長を短縮
)
8. 学習の実行
dpo_trainer.train()
A100のGPUで20分ぐらいです
9. モデルの保存
HuggingFaceにLoRAアダプタだけを保存します。
# LoRAアダプタだけ保存
new_model_id = "llm-jp-3-3.7b-instruct-DPO"
model.push_to_hub_merged(
"*****/"+new_model_id+"_lora",
tokenizer=tokenizer,
save_method="lora",
token=HF_TOKEN,
private=True
)
10.推論
以下の3つのプロンプトで、DPO前とDPO後でどのような変化があったかを見ていきます。
# 推論モードに切り替え
FastLanguageModel.for_inference(model)
input = "読書って意味があるんですか?"
prompt = f"""
### 指示\n{input}\n### 回答\n
"""
inputs = tokenizer([prompt], return_tensors = "pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens = 256,
use_cache = True,
do_sample=False,
repetition_penalty=1.2)
prediction = tokenizer.decode(
outputs[0], skip_special_tokens=True).split('\n### 回答')[-1]
11. 推論結果
プロンプト❶:健康になりたい
DPO前
DPO後
プロンプト❷:暇で暇でしようがないのですが、どうしたらいいですか?
DPO前
DPO後
プロンプト❸:読書って意味があるんですか?
DPO前
DPO後
いかがでしょうか?
DPO前は、総じて文章が「冗長的」で、途中で切れてしまうこともあるのに対し、DPO後は、相対的に文章がコンパクトになっていますね。途中で切れるということもありません。
他にもいくつか試してみましたが、同じような傾向が見られました。
3,000程度の学習データなので評価は難しいですが、これはこれでスゴイなと思いました。