見出し画像

UnslothでDPOをやってみた


はじめに

Unslothを使ってDPO(Direct Preference Optimization)をやってみます。
DPOは通常、指示チューニング後にLLMの出力を、よりユーザーの好みに調整するための手法です。
詳しい説明はHuggingFaceの記事に任せることにします。

この記事では実装のみを載せることにします。

実装

準備

まず、ライブラリ群のインポートとパラメータの設定、モデルのロードまでを行います。

import os
from unsloth import is_bfloat16_supported
from unsloth import FastLanguageModel, PatchDPOTrainer
PatchDPOTrainer()
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
import wandb
os.environ["WANDB_PROJECT"] = "llama3.2-dpo" 

# 基本設定
max_seq_length = 768
dtype = None
load_in_4bit = True
num_proc = 4
sft_model = 'sft_model'
dpo_output = 'dpo_output'
random_seed = 3407

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = sft_model,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

独自モデルを読み込んでいますが、このモデルの継続事前学習と指示チューニングについては以下の記事にコードを公開しています。

上の記事に記載していますが、元のモデルはLlama3.2 1Bモデルを使用しています。

データセット読み込み

データセットは以下を使用しています。

DPOでは、データセットのキーとしてchose, rejected, promptを取る必要があります。choseはLLMが応答すべき回答で、rejectedはLLMが応答すべきでない回答となっています。promptはユーザーの入力です。

# データダウンロード
dataset = load_dataset("ryota39/truthy-dpo-ja")
print(dataset['train'][0])

system_template = """
<|start_header_id|>system<|end_header_id|>
{system}
<|eot_id|>
"""
user_template = """
<|start_header_id|>user<|end_header_id|>
{user}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
assistant_template = """
{assistant}
<|eot_id|>
"""

def apply_chat_template(example):
    prompt = system_template.format(system=example['system'])
    prompt += user_template.format(user=example['prompt'])
    
    example["chosen"] = assistant_template.format(assistant=example['chosen'])
    example["rejected"] = assistant_template.format(assistant=example['rejected'])
    example["prompt"] = prompt
    return example

dataset = dataset.map(
    apply_chat_template,
    num_proc = num_proc,
)

dataset

学習設定

学習パラメータは以下の通り。

model = FastLanguageModel.get_peft_model(
    model,
    r = 8,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 8,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    use_gradient_checkpointing=False,
    random_state = random_seed,
    max_seq_length = max_seq_length,
    use_rslora=True,
)

trainer = DPOTrainer(
    model = model,
    ref_model = None,
    train_dataset = dataset['train'],
    tokenizer = tokenizer,
    beta=0.1,
    max_length = max_seq_length,
    max_prompt_length = 256,

    args = DPOConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 10,
        save_total_limit = 2,
        output_dir = dpo_output,
        learning_rate = 5e-6,
        warmup_steps=10,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        seed = random_seed,
        report_to = "wandb",
  ),
)

指示チューニング(SFT)と比べて特筆すべきパラメータはref_modelとbetaかと思います。
ref_modelの箇所は、LLMの応答文に対して応答の良さを数値化するための参照モデルを取ることができます。しかし、最近のTRLやUnslothでは参照モデルの指定をNoneとすることができます。(この場合、ベースモデルのトップ層にアダプターを付与して参照モデルとして使用しているようです)
betaは、値が小さくなるほど参照モデルを無視するという意味で、参照モデルにどの程度注意を払うかを指定します。

学習の実行

trainer_stats = trainer.train()
wandb.finish()

モデルの保存

モデルの保存はLoRAのアダプターのみを保存します。

model.save_pretrained_merged("dpo_model_lora", tokenizer, save_method = "lora",)

いいなと思ったら応援しよう!