見出し画像

DPO(選好チューニング)でLLMを訓練する方法

この記事は、
DPOという手法を使って、
LLM(大規模言語モデル)
を訓練する方法についてご紹介します。




DPOって何?


DPOはDirect Preference Optimizationの略で、
LLMが出力する回答を、
人間の好みに近づける手法のことです。
別名、選好チューニングと呼ばれています。



なぜDPOが必要なのか?


LLMは、膨大なデータセットで事前学習することで
一般的な知識やパターンを習得しますが、
あくまで、文法的に正しい文章を出力してくれるにとどまるだけで、人間の好みに合った応答を生成するわけではありません。

例えば、LLMのプロンプトに
「明日の会議について詳細を教えてください。」と入力した場合
LLMは、
「明日の会議は14時から始まります。議題は以下の通りです。」
と答える可能性があります。
文法的には正しいのですが、形式的すぎて、相手との親しみやすいトーンが欠けていて、応答が冷たく感じられないでしょうか?

また、LLMのプロンプトに
「最近、仕事で失敗してすごく落ち込んでいます。」と入力した場合、
LLMは、
「それは残念です。次回は成功することを願っています。」
と答える可能性があります。
これも文法的には正しいのですが、
ユーザーの感情に寄り添ったり(例: 「それは辛いですね。どんなことがあったのですか?」)してほしいところです。
これだと、冷淡な印象を与えたり、応答が共感的でないため、ユーザーの信頼感を損ねる可能性があります。

他にも、文法的には正しいが、
不適切、非倫理的、あるいは無関係な出力、有害な発言をするリスクもあります。

こうした問題を解決するのがDPOです。
DPOによって、出力が人間の好みにあうように事後訓練するわけです。

従来、LLMを人間の選好に合わせる手法としては、
強化学習(Reinforcement Learning)
を利用するの一般的でした。
その具体的な例として、
「強化学習を利用した人間のフィードバックによる微調整(RLHF: Reinforcement Learning with Human Feedback)」があります。

RLHFは効果的な手法ですが、
モデルの設計や調整が複雑だったり、
挙動が不安定になったり、膨大な計算資源が必要だったりと、
いくつかの課題があったのです。

そこで登場したのがDPOです。
強化学習とは違って、シンプルな分類問題として
人間の好みに沿ったモデル調整を実現します。
このアプローチは、
複雑なモデルの設計が不要で、
安定した学習も可能であり、
かつ計算コストを大幅に削減できる点が特徴です。


DPOの仕組み

DPOは、
簡単に言うと、人間がどのような回答を好むかを学び、
それに基づいて良い回答を選ぶ方法です。

このプロセスは次のように進みます:

❶データ準備
まず、ユーザーの質問(prompt)に対して、
好ましい回答(chosen)と、
好ましくない回答(rejected)をペアで集めます。
例えば、「好きな映画は何ですか?」という質問に対して、
「私はアクション映画が好きです」という回答が好ましい回答で、
「映画?興味ないです」という回答が好ましくない回答です。

選好学習
次に、好ましい回答が選ばれやすく、
好ましくない回答が選ばれにくいように、
モデルを訓練します。
これには「クロスエントロピー損失」という方法を使って、
好ましい回答が出やすくなるように調整します。

❸パフォーマンスの評価
モデルの訓練が終わったら、
実際にどれくらい好ましい回答が選ばれるかをテストします。
これは、
新しい質問に対してモデルがどのような回答をするかを見て、
それがどれだけユーザーの期待に合っているかで評価されます。

このようにして、DPOは常にユーザーのニーズに合わせて進化し、より良い対話を目指します。


DPOの実装

こちらの記事を参考にしています。


1. 必要なライブラリーのインストール

公式ドキュメントには下3行が書かれていませんが、これらもインストール・アップグレードしないとエラーが出ます。

%%capture
!pip install unsloth
# Also get the latest nightly Unsloth!
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

!pip install -U accelerate # これを入れないとtrainerでエラーになる
!pip install --upgrade torch
!pip install --upgrade xformers


2. 必要なライブラリーのインポート


from unsloth import PatchDPOTrainer
PatchDPOTrainer()

from unsloth import FastLanguageModel
import torch
import json
from tqdm import tqdm
import re
from transformers import TrainingArguments
from trl import DPOTrainer, DPOConfig
from unsloth import is_bfloat16_supported


3. モデルのダウンロード

今回は、llm-jpが開発したllm-jp-3-3.7B-instruct(指示チューニング済みモデル)を使用します。モデルの大きさも手頃です。

max_seq_length = 2048
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "llm-jp/llm-jp-3-3.7b-instruct",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

4. LoRAの設定

get_peft_model()を使って、LoRAの設定をします。
LoRAの対象層は、"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"になります。

model = FastLanguageModel.get_peft_model(
    model,
    r = 64, 
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,
    lora_dropout = 0,
    bias = "none",    
    use_gradient_checkpointing = "unsloth", 
    random_state = 3407,
    use_rslora = False,  
    loftq_config = None, 
)


5.データセットの準備

DPOの学習用データとして、llm-jpが公開しているhh-rlhf-12k-jaというデータセットを使います。

正直、データの質はあまり良いとは言えません。なぜ、これをllm-jpが公開しているのかが疑問なぐらいです。

6.データセットのダウンロードと整形

12,000件の学習用データが格納されています。

from datasets import load_dataset
dataset = load_dataset("llm-jp/hh-rlhf-12k-ja")
dataset


DPOの学習に使えるようデータセットを整形します。
具体的には、conversationsカラムから質問文を抜き出し、
promptカラムに格納します。
また、全12,000データのうち3割(3600データ)をランダムサンプリングして、
学習用データセットにしています。

EOS_TOKEN = tokenizer.eos_token
# datesetの整形
# conversations から最初の human 発話を instruction として抽出
def extract_instruction(example):
    conversations = example['conversations']
    for message in conversations:
        if message['from'] == 'human':
            return message['value']
    return None

# 各行を整形して新しい形式を作成
def format_for_dpo(example):
    instruction = extract_instruction(example) +EOS_TOKEN  # 指示文を抽出
    chosen = example['chosen'] +EOS_TOKEN # chosen カラム
    rejected = example['rejected'] +EOS_TOKEN # rejected カラム
    return {
        'instruction': instruction,
        'chosen': chosen,
        'rejected': rejected
    }

# データセットを変換
formatted_dataset = dataset.map(format_for_dpo)

# 必要ないカラムを削除
formatted_dataset = formatted_dataset.remove_columns(['conversations', 'source',])

dpo_datasets = formatted_dataset['train'].rename_column("instruction", "prompt")
dpo_datasets = dpo_datasets.train_test_split(train_size=0.30)["train"]
dpo_datasets


7.DPO学習用のパラメータの設定

計算負荷を下げるためにいくつか調整しています。

#メモリオーバーになるのでこちらを使う
dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    args = DPOConfig(
        per_device_train_batch_size = 1,  # バッチサイズを削減
        gradient_accumulation_steps = 8,  # 勾配蓄積を増加
        warmup_ratio = 0.05,  # ウォームアップ期間を短縮
        num_train_epochs = 1,
        learning_rate = 5e-5,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",  # 8bit AdamW を維持
        weight_decay = 0.0,
        lr_scheduler_type = "linear",
        seed = 42,
        output_dir = "outputs",
        report_to = "none",
    ),
    beta = 0.1,
    train_dataset = dpo_datasets,
    tokenizer = tokenizer,
    max_length = 512,  # シーケンス長を短縮
    max_prompt_length = 256,  # プロンプト長を短縮
)

8. 学習の実行

dpo_trainer.train()

A100のGPUで20分ぐらいです


9. モデルの保存

HuggingFaceにLoRAアダプタだけを保存します。

# LoRAアダプタだけ保存
new_model_id = "llm-jp-3-3.7b-instruct-DPO"
model.push_to_hub_merged(
    "*****/"+new_model_id+"_lora",
    tokenizer=tokenizer,
    save_method="lora",
    token=HF_TOKEN,
    private=True
)

10.推論

以下の3つのプロンプトで、DPO前とDPO後でどのような変化があったかを見ていきます。

プロンプト❶:健康になりたい
プロンプト❷:暇で暇でしようがないのですが、どうしたらいいですか?
プロンプト❸:読書って意味があるんですか?

# 推論モードに切り替え
FastLanguageModel.for_inference(model)

input = "読書って意味があるんですか?"

prompt = f"""
### 指示\n{input}\n### 回答\n
"""

inputs = tokenizer([prompt], return_tensors = "pt").to(model.device)

outputs = model.generate(
    **inputs, 
    max_new_tokens = 256,
    use_cache = True, 
    do_sample=False, 
    repetition_penalty=1.2)

prediction = tokenizer.decode(
    outputs[0], skip_special_tokens=True).split('\n### 回答')[-1]

11. 推論結果



プロンプト❶:健康になりたい

DPO前

まず、自分の生活習慣や食生活を見直しましょう。運動不足が気になる場合は、ウォーキングなどの軽い運動から始めてみると良いでしょう。また、バランスの取れた食事と十分な睡眠も大切です。特に朝食をしっかり摂ることで、一日のエネルギーレベルを高めることができます。
次に、ストレス管理にも目を向けることが重要です。適度なリラクゼーション法(例えばヨガや瞑想)や、趣味に時間を使ってリフレッシュする方法を見つけてみてください。仕事や学業で忙しいかもしれませんが、定期的に休憩をとることが心身の健康維持には欠かせません。
さらに、専門家によるアドバイスを受けることもおすすめです。医師や栄養士など信頼できる人に相談することで、自分では気づかない問題点を指摘してもらえるかもしれません。必要であれば、サプリメントなども検討してみてください。ただし、自己判断での使用は避けるべきですので、必ず医療機関からの指導のもと行うようにしてください。
最後に、継続的な努力が必要ですが、無理をしない範囲で少しずつ改善していくことを心がけてください。小さな変化でも積み重ねることで大きな成果につながることが多いので、焦らず一歩ずつ進んでいきましょう。

DPO後

まず、自分の生活習慣や食生活を見直しましょう。運動不足が気になる場合は、ウォーキングなどの軽い運動から始めてみてはいかがでしょうか。また、バランスの取れた食事と十分な睡眠も大切です。規則正しい生活リズムを作ることで、体調管理がしやすくなります。
次に、ストレス解消法を見つけてみましょう。趣味に没頭したり、友人との時間を持つことでリフレッシュできます。時には専門家によるカウンセリングを受けることも有効かもしれません。
最後に、定期的な検診で体の状態をチェックしましょう。早期発見・治療のためにも、年一回程度の人間ドック受診などをお勧めします。



プロンプト❷:暇で暇でしようがないのですが、どうしたらいいですか?

DPO前

1. 新しい趣味を始める: 絵を描く、音楽を作る、料理をするなど、新しいことに挑戦してみましょう。
2. ボランティア活動に参加する: 地域のボランティア団体やNPOに登録して、社会貢献をしながら時間をつぶすことができます。
3. オンライン学習プラットフォームを利用する: Coursera, Udemyなどのオンラインコースで、自分の興味のある分野について学びましょう。
4. 読書と執筆: 本を読むだけでなく、自分でも文章を書いてみましょう。ブログを始めたり、短編小説を書いたりするのも良いアイデアです。
5. ゲームをする: 好きなゲームをしてリフレッシュしましょう。ただし、過度なプレイは避けてください。適度な休憩が必要です。
6. 運動をする: 散歩したり、ジョギングしたりして体を動かしましょう。運動することで気分転換になります。
7. 料理教室に行く: プロから直接教えてもらえるので、スキルアップにもなります。
8. 旅行プランを立てる: 近場でもいいので旅行計画を立てて、実際に行ってみるのも楽しいでしょう。

DPO後

1. 新しい趣味を始める: 絵描き、音楽演奏、料理など、新しいことに挑戦してみましょう。
2. ボランティア活動に参加する: 地域のボランティア団体やNPOに登録して、社会貢献をする時間を作るのも良いでしょう。
3. 読書・勉強: 本を読むことや、資格取得のための学習などを通じて、自己成長につながることをするのもおすすめです。


プロンプト❸:読書って意味があるんですか?

DPO前

本を読むことは、知識を得たり、新しい視点や考え方に触れたりするために非常に有意義な活動です。以下に、いくつかの理由から読書が持つ価値について説明します。1. 知識の獲得: 本はさまざまな分野に関する情報を提供してくれます。歴史、科学、技術、文化など多岐にわたるトピックについての理解を深めることができます。これにより、自己成長や問題解決能力も向上する可能性があります。2. 創造性と想像力の刺激: 物語の世界に没入することで、現実とは異なる世界での経験ができます。これは創造力や想像力を豊かにする助けとなります。また、フィクション作品では感情的な感受性も高まります。3. 批判的思考の育成: 読書によって批判的に物事を見る力がつきます。著者の意見だけでなく、その根拠となるデータや事実にも目を向けることで、より深い洞察を持つことができます。4. 多様な価値観へのアクセス: 同じテーマでも多種多様なアプローチがあることがわかるため、自分の視野を広げることができるでしょう。異文化に対する理解力も深まることがあります。5. ストレス軽減効果: 忙しい日常の中でリラックスできる時間を作ることもできます。特に静かで集中しやすい環境では、ストレス解消につながることが多いです。6. 自己啓発・モチベーションアップ: 成功した人々は、しばしば成功体験記などを読んでいます。これらの書籍(←ここで切れる)

DPO後

本を読むことは、知識を得たり、新しい視点や考え方に触れたりするために非常に有意義な活動です。以下に、いくつかの理由で読むことが重要であることを説明します。
1. 知識の獲得: 本はさまざまな分野についての情報、歴史的な出来事、そして科学的発見などを提供してくれます。これにより、読者がより深い理解を持つことができます。
2. 創造性と想像力の刺激: 物語の中で登場人物になりきったり、未知の世界を探検したりすることで、創造力や想像力が豊かになります。これは日常生活にも役立ちます。
3. 批判力・批判能力の向上: 本の内容に対して疑問を持って考えることが求められるため、批判的に考えるスキルが身につきます。このスキルは、他の人々とのコミュニケーションにおいても有用となります。
4. 自己成長: 本から学んだことを実践したり、自分の経験として取り入れることで、自己改善につながることがあります。また、異なる文化背景についての学びも、国際感覚を養う助けとなるでしょう。
5. ストレス軽減効果: 読書にはリラックス効果があり、ストレスレベルを低下させる働きがあります。特に物語などのフィクションでは、現実逃避ができ心のリフレッシュになることもあります。

いかがでしょうか?

DPO前は、総じて文章が「冗長的」で、途中で切れてしまうこともあるのに対し、DPO後は、相対的に文章がコンパクトになっていますね。途中で切れるということもありません。
他にもいくつか試してみましたが、同じような傾向が見られました。
3,000程度の学習データなので評価は難しいですが、これはこれでスゴイなと思いました。

いいなと思ったら応援しよう!

Non
よろしければサポートお願いします! いただいたサポートはクリエイターとしての活動費に使わせていただきます!