
UnslothでDPOをやってみた
はじめに
Unslothを使ってDPO(Direct Preference Optimization)をやってみます。
DPOは通常、指示チューニング後にLLMの出力を、よりユーザーの好みに調整するための手法です。
詳しい説明はHuggingFaceの記事に任せることにします。
この記事では実装のみを載せることにします。
実装
準備
まず、ライブラリ群のインポートとパラメータの設定、モデルのロードまでを行います。
import os
from unsloth import is_bfloat16_supported
from unsloth import FastLanguageModel, PatchDPOTrainer
PatchDPOTrainer()
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
import wandb
os.environ["WANDB_PROJECT"] = "llama3.2-dpo"
# 基本設定
max_seq_length = 768
dtype = None
load_in_4bit = True
num_proc = 4
sft_model = 'sft_model'
dpo_output = 'dpo_output'
random_seed = 3407
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = sft_model,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)
独自モデルを読み込んでいますが、このモデルの継続事前学習と指示チューニングについては以下の記事にコードを公開しています。
上の記事に記載していますが、元のモデルはLlama3.2 1Bモデルを使用しています。
データセット読み込み
データセットは以下を使用しています。
DPOでは、データセットのキーとしてchose, rejected, promptを取る必要があります。choseはLLMが応答すべき回答で、rejectedはLLMが応答すべきでない回答となっています。promptはユーザーの入力です。
# データダウンロード
dataset = load_dataset("ryota39/truthy-dpo-ja")
print(dataset['train'][0])
system_template = """
<|start_header_id|>system<|end_header_id|>
{system}
<|eot_id|>
"""
user_template = """
<|start_header_id|>user<|end_header_id|>
{user}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
assistant_template = """
{assistant}
<|eot_id|>
"""
def apply_chat_template(example):
prompt = system_template.format(system=example['system'])
prompt += user_template.format(user=example['prompt'])
example["chosen"] = assistant_template.format(assistant=example['chosen'])
example["rejected"] = assistant_template.format(assistant=example['rejected'])
example["prompt"] = prompt
return example
dataset = dataset.map(
apply_chat_template,
num_proc = num_proc,
)
dataset
学習設定
学習パラメータは以下の通り。
model = FastLanguageModel.get_peft_model(
model,
r = 8,
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha = 8,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
use_gradient_checkpointing=False,
random_state = random_seed,
max_seq_length = max_seq_length,
use_rslora=True,
)
trainer = DPOTrainer(
model = model,
ref_model = None,
train_dataset = dataset['train'],
tokenizer = tokenizer,
beta=0.1,
max_length = max_seq_length,
max_prompt_length = 256,
args = DPOConfig(
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
num_train_epochs=3,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 10,
save_total_limit = 2,
output_dir = dpo_output,
learning_rate = 5e-6,
warmup_steps=10,
lr_scheduler_type="linear",
optim="adamw_8bit",
seed = random_seed,
report_to = "wandb",
),
)
指示チューニング(SFT)と比べて特筆すべきパラメータはref_modelとbetaかと思います。
ref_modelの箇所は、LLMの応答文に対して応答の良さを数値化するための参照モデルを取ることができます。しかし、最近のTRLやUnslothでは参照モデルの指定をNoneとすることができます。(この場合、ベースモデルのトップ層にアダプターを付与して参照モデルとして使用しているようです)
betaは、値が小さくなるほど参照モデルを無視するという意味で、参照モデルにどの程度注意を払うかを指定します。
学習の実行
trainer_stats = trainer.train()
wandb.finish()
モデルの保存
モデルの保存はLoRAのアダプターのみを保存します。
model.save_pretrained_merged("dpo_model_lora", tokenizer, save_method = "lora",)