見出し画像

DPO による Llama 2 のファインチューニング

以下の記事が面白かったので、かるくまとめました。

Fine-tune Llama 2 with DPO

1. はじめに

「RLHF」は「GPT-4」「Claude」などのLLMの事実上の最後の学習ステップとなっており、LLM出力の饒舌さや安全さが人間の期待と一致していることを確認します。ただし、RLの複雑さが持ち込まれます。適切な報酬関数を設定し、状態を推定するようにモデルを学習する必要があります。同時に、元のモデルから離れすぎないよう注意する必要があります。このようなプロセスは非常に複雑で、正しく行うのは容易ではありません。

Rafailov、Sharma、Mitchellらによる最近の論文「Direct Preference Optimization」では、既存の手法で使用されているRLベースの目標を、単純なバイナリクロスエントロピー損失を介して直接最適化できる目標に切り替えることを提案しています。これにより、LLMを改良するこのプロセスが大幅に簡素化できます。

この記事では、TRL で利用できるようになった「DPO」(Direct Preference Optimization) を使って、stack-exchange preference データセットで 「Llama v2 7B」をファインチューニングする手順を紹介します。

2. DPO vs PPO

RLによって人間の好みを最適化する伝統的なモデルでは、補助的な報酬モデルを使用し、RLによって与えられた報酬を最大化するように、対象のモデルをファインチューニングするのが一般的です。直感的には、報酬モデルを使用して、最適化しているモデルにフィードバックを与え、高報酬のサンプルをより頻繁に生成し、低報酬のサンプルをより頻繁に生成しないようにします。同時に、生成されるものが大きく逸脱しないように、また生成の多様性を維持し続けるように、固定された参照モデルを使用します。これは通常、参照モデルを介して完全な報酬最大化目標にKLペナルティを追加することで行われ、モデルが報酬モデルをごまかすことを学習しないようにします。

DPOの定式化は、報酬モデリングのステップをバイパスし、重要な洞察を介して嗜好データの言語モデルを直接最適化します。つまり、報酬関数から最適なRLポリシーへの分析マッピングにより、作成者は報酬モデルと参照モデルに対する RL 損失を変換できます。参照モデルを直接上回る損失が発生します。このマッピングは、特定の報酬関数が特定の嗜好データとどの程度一致しているかを直感的に測定します。したがって、DPOはRLHF損失の最適解から開始し、変数の変更を通じて参照モデルのみの損失を導き出します。

この目標は、報酬モデルを必要とせず、または潜在的に面倒なRLベースの最適化を実行する必要もなく、最適化できます。

3. TRLを使った学習方法

通常、RLHF パイプラインは次の異なるパーツで構成されます。

(1) 教師ありファインチューニング (SFT)
(2) データへの設定ラベルの付加
(3) 嗜好データに基づいて報酬モデルを学習
(4) RL最適化

TRLにはこれらすべてのパーツのヘルパーが付属していますが、DPOでは報酬モデリングと RL ((3)(4)) が不要になり、プリファレンスの注釈付きデータに基づいて DPOオブジェクトが直接最適化されます。

(1) を実行する必要がありますが、(3)(4) の代わりに、非常に特殊な形式、つまり次の3つのキーを含む辞書を持つ(2)の設定データをTRLの「DPOTrainer」に渡する必要があります。

・prompt : テキスト生成の推論時にモデルに与えられるプロンプト
・chosen : プロンプトに対する好ましい応答
・rejected : プロンプトに対する好ましくない応答

stack-exchange preference pairs データセットの場合、次のヘルパーを介して目的の辞書を返し、元の列をすべて削除するようにデータセットエントリをマップできます。

def return_prompt_and_responses(samples) -> Dict[str, str, str]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"],   # rated better than k
        "rejected": samples["response_k"], # rated worse than j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)

データセットを並べ替えると、DPO損失は基本的に、参照モデルを介して暗黙的な報酬を取得する教師あり損失となるため、高レベルでは、「DPOTrainer」には最適化したい「ベースモデル」と「参照モデル」が必要です。

dpo_trainer = DPOTrainer(
    model,                 # ベースモデル
    model_ref,             # 参照モデル (ベースモデルのコピー)
    beta=0.1,              # DPO温度
    train_dataset=dataset, # データセット
    tokenizer=tokenizer,   # トークナイザー
    args=training_args,    # 学習パラメータ (バッチサイズ、学習率など)
)

ここで、「beta」はDPO損失の温度であり、通常は 0.1 ~ 0.5 の範囲にあります。 これは、betaが小さくなるほど参照モデルを無視するという意味で、参照モデルにどの程度注意を払うかを指定します。トレーナーを初期化したら、次のメソッドを呼び出すだけで、学習開始できます。

dpo_trainer.train()

4. Llama 2 のファインチューニング

TRLにDPO トレーナーを実装する利点は、TRL とその依存ライブラリ (Peft や Accelerate など) に付属する大規模なLLMの学習に関する追加機能をすべて利用できることです。これらのライブラリを使用すると、bitsandbytesによって提供されるQLoRAを使用してLlama v2を学習することもできます。

4-1. 教師ありファインチューニング

「Llama 2 7B」で QLoRA を使用する教師ありファインチューニングを行います。

# 4bit量子化でベースモデルをロード
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,        # "meta-llama/Llama-2-7b-hf"
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False

# 量子化されたベースモデルに LoRAレイヤーを追加
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args,         # HFトレーナー引数
)
trainer.train()

4-2. DPO トレーニング

SFTが完了したら、結果のモデルを保存し、DPOトレーニングに進むことができます。通常行われているように、前のSFT ステップで保存したモデルを、DPOのベースモデルと参照モデルの両方に利用します。次に、これらを使用して、上に示したデータの DPO目標でモデルを学習できます。 モデルは LoRaアダプター経由で学習されたため、Peftの AutoPeftModelForCausalLM 経由でロードします。

model = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # 保存された SFT モデルの場所
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,  # メインと同じモデル
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

モデルを4bitでロードし、peft_config引数を介してQLoraで学習します。トレーナーは、評価データセットに関してト学習中の進捗状況も評価し、たとえば wandbを介して記録および表示できる暗黙的報酬などの多くの重要な指標をレポートします。その後、最終的に学習されたモデルを HuggingFace Hub にプッシュできます。

5. おわりに

SFTおよびDPOの学習スクリプトの完全なソースコードは、examples/stack_llama_2 にあり、マージされたアダプターを使用して学習されたモデルは、こちらにあります。

DPOトレーニング実行の wandbログはここで参照できます。学習と評価中に、DPOTrainerは次の報酬メトリックを記録します。

・rewards/chosen : betaでスケールされた、選択された応答のポリシーモデルと参照モデルの対数確率間の平均差
・rewards/rejected : ポリシー モデルと拒否された応答の参照モデルの対数確率間の平均差 (betaでスケール)
・rewards/accuracies : 選択された報酬が対応する拒否された報酬よりも高い頻度の平均
・rewards/margins : 選択された報酬と対応する拒否された報酬の間の平均差

直感的には、学習中にマージンが増加し、精度が 1.0 になること、つまり、選択された報酬が拒否された報酬よりも大きくなるようにしたいと考えます (またはマージンがゼロより大きい場合)。 これらのメトリクスは、何らかの評価データセットに対して計算できます。



いいなと思ったら応援しよう!