見出し画像

Unsloth で独自の R1 Reasoningモデルを学習

以下の記事が面白かったので、簡単にまとめました。

Train your own R1 reasoning model with Unsloth


1. はじめに

DeepSeek」の研究では、「R1-Zero」 が「GRPO」(Group Relative Policy Optimization) を使用して、人間のフィードバックなしでより多くの思考時間を割り当てることを自律的に学習したという「aha moment」が明らかになりました。

「UInsloth」では「GRPO」プロセス全体を強化し、「Hugging Face + FA2」よりも80%のVRAMの使用量が削減できました。これにより、「Qwen2.5 1.5B」を使用して、わずか7GBのVRAMで「R1-Zero」の「aha moment」を再現できました。

Llama 3.1 (8B) GRPO Colab notebook

「Phi-4」などの他のモデルを備えたGRPOノートブックについては、以下を参照してください。

2. Unsloth で独自の R1 Reasoningモデルを学習

Unsloth」を使用すると、15GBのVRAMで、「Llama 3.1」(8B)、「Phi-4」(14B)、「Mistral」(7B)、「Qwen2.5」(7B) などの最大15BをReasoningモデルに変換できます。最小要件は7GBのVRAMで、独自のReasoningモデルをローカルで学習できます。

Tiny-Zero」の素晴らしいチームは、「Qwen2.5」(1.5B) で「aha moment」を実現できることを実証しました。ただし、そのためには 2 つの「A100 GPU」(160GB VRAM) が必要でした。「Unsloth」を使用すると、1つの7GB VRAMだけで同じ「aha moment」を実現できます。

以前は、「GRPO」は完全なファインチューニングにのみサポートされていましたが、「QLoRA」と「LoRA」で動作するようにしました。

これは、「DeepSeekのR1蒸留モデル」のファインチューニングや、「Unsloth」がすでにサポートしているチューニングにR1の蒸留データを使用しているわけではないことに注意してください。これは、「GRPO」を使用して標準モデルを本格的なReasoningモデルに変換することです。

3. GRPO + aha moment

「DeepSeek」の研究者は、純粋な強化学習で「R1-Zero」を学習するときに「aha moment」を観察しました。このモデルは、人間の指導や事前定義された指示なしに、最初のアプローチを再評価することで、思考時間を延長することを学びました。

テスト例では、「GRPO」を使用して100ステップで「Phi-4」を学習しただけですが、結果はすでに明確です。「GRPO」のないモデルには「thinkingトークン」がありませんが、「GRPO」で学習されたモデルには「thinkingトークン」があり、正解もあります。

このマジックは、Value関数に依存する「PPO」とは異なり、Value関数を必要とせずに応答を効率的に最適化する強化学習アルゴリズムである「GRPO」を介して再現できます。Unslothのノートブックでは、「GRPO」でモデルを学習し、独自の自己検証と検索能力を自律的に開発することを目指し、ミニ「aha moment」を作成します。

しくみは、次のとおりです。

(1) モデルは応答のグループを生成します。
(2) 各応答は、LLM報酬モデルではなく、いくつかのセット報酬関数によって作成された正確性または別のメトリックに基づいて採点されます。
(3) グループの平均スコアが計算されます。
(4) 各回答のスコアは、グループ平均と比較されます。
(5) このモデルは、より高いスコアリングの応答を支持するように強化されています。

例として、モデルで以下を解決したいと仮定します。

1+1とは何ですか?>> 思考の連鎖/ワークアウト >> 答えは2です。
2+2とは何ですか?>> 思考の連鎖/ワークアウト >> 答えは4です。

もともと、「思考の連鎖/ワークアウト」を埋めるために、大量のデータを収集する必要がありました。しかし、「GRPO」または他の強化学習アルゴリズムは、モデルを誘導して推論能力を自動的に表示し、Reasoningトレースを作成できます。代わりに、優れた報酬機能または検証者を作成する必要があります。「正解した場合、1点を与える」「いくつかの単語がスペルミスの場合、マイナス0.1」などです。プロセスに報酬を与えるために、多くの機能を提供できます。

4. Unsloth の GRPO

報酬が実際に増加するまで、少なくとも300ステップを待ちます。「Unsloth」で 「GRPO」をローカルで使用している場合は、「pip install diffuser」もお願いします。最新バージョンの「vLLM」を使用する必要があります。Colabの例は1時間で学習されたばかりなので、結果は標準以下であることを覚えておいてください。良い結果を得るには、少なくとも12時間学習する必要がありますが、いつでも停止できるため、これは必須ではないことに注意してください。

「thinkingトークン」を正しく生成するには、少なくとも1.5Bを持つモデルに 「GRPO」を適用することをお勧めします。小さいモデルでは正しく生成されない可能性があります。基本モデルを使用している場合は、チャットテンプレートがあることを確認してください。「GRPO」の学習損失追跡は、「Unsloth」に直接組み込まれているため、wandbなどの外部ツールは不要です。

「GRPO」サポートの追加に加えて、「Online DPO」「PPO」「RLOO」もサポートします。詳しくはこちらの投稿そしてブログを参照してください。「Unsloth」の「Online DPO」と「Hugging Face + FA2」のVRAM消費の比較については、以下のグラフを参照してください。

5. Unsloth x vLLM

vLLM」をファインチューニングスタックで直接使用できるようになり、スループットが大幅に向上し、モデルのファインチューニングと推論を同時に実行できるようになりました。1x A100 40GB では、「Unsloth」の「Llama 3.2 3B Instruct」 の「dynamic 4bit quant」で 4000トークン/秒が期待できます。16GB Tesla T4 (無料のColab GPU) では、300トークン/秒が得られます。

また、「vLLM」と「Unsloth」を一緒にロードするときに、魔法のようにメモリ使用量を二重に削除し、「Llama 3.1 8B」で5GB、「Llama 3.2 3B」で3GBを節約できます。「Unsloth」はもともと、「Llama 3.3 70B Instruct」を 1x 48GB GPUでファインチューニングすることができ、「Llama 3.3 70B」は 40GBのVRAMを取ります。二重メモリ使用量を削除しない場合、「Unsloth」と「vLLM」を一緒にロードするときに>= 80GBのVRAMが必要になります。

しかし、「Unsloth」を使用すると、48GB未満のVRAMで1つのパッケージでファインチューニング整し、高速推論の利点を得ることができます。高速推論を使用するには、まず「vllm」をインストールし、fast_inference で「Unsloth」をインスタンス化します。

pip install unsloth vllm
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-Instruct",
    fast_inference = True,
)
model.fast_generate(["Hello!"])

6. Unsloth での vLLM の調査結果

「vLLM」は、Unsloth Dynamic 4-bitをロードできるようになりました。1.58bit Dynamic R1 GGUFと同様に、特定のレイヤーを4bitに、一部を16bitに動的量子化することで、モデルを小さく保ちながら精度を大幅に向上できることを示しました。

RAMとVRAMの効率、最大スループット (チャンクされたプレフィルトークンの数、最大シーケンス数など) を考慮して、複数のパラメータを自動的に選択します。「vLLM」ではデフォルトで -O3 を有効にし、プレフィックスキャッシュを有効にします。古いGPUでのFlashinferは実際には10%遅いことがわかりました。FP8 KVキャッシュでは10%遅くなりますが、スループットの可能性は2倍になります。

ディスクからロードする代わりに、状態辞書を解析することで、「vLLM」にLoRAをロードできるようになりました。これにより、「GRPO」の学習の実行速度が1.5倍になります。現在研究が進められている分野は、「vLLM 」で LoRAアダプタを直接編集することです (方法はまだわかりません)。現在、不要なGPUデータの移動が行われているため、これにより速度が大幅に向上します。

「vLLM」では、特にバッチ生成中に、VRAMスパイクがランダムに発生します。メモリスパイクを削減するために、バッチ生成機能を追加しました。

関連



いいなと思ったら応援しよう!