見出し画像

Long-context GRPO

以下の記事が面白かったので、簡単にまとめました。

Long-context GRPO


1. Long-context GRPO

「Qwen2.5」(1.5B) がわずか 5GB VRAM で、独自のReasoningモデルを学習できるようになりました。2週間前のGRPOリリースの7GBから減少しました。

現在、より長いコンテキストの長さを達成することは、「GRPO」の最大の課題の1つです。新しく派生した「Unsloth Efficient GRPOアルゴリズム」は、「Flash Attention 2」(FA2) を利用するものであっても、他のすべてのGRPO LoRA/QLoRA実装と比較して、90%少ないVRAMで、コンテキスト長を10倍長くすることができます。

「TRL + FA2」を使用した「GRPO」セットアップでは、20Kコンテキスト長の「Llama 3.1」(8B) の学習には510.8GBのVRAMが必要です。しかし、UnslothのVRAMが90%削減されたため、要件は同じセットアップでわずか54.3GBに短縮されます。

10倍長いコンテキストを無料のGRPOノートブックで試すことができます。

ColabのLlama 3.1(8B)

読むことを強くお勧めしますガイドを読むことを推奨します。「Phi-4」のような他のモデルのGRPOノートブックはこちらにあります。楽しんだらgithub.com/unslothai/unslothにスターをつけてください。

2. 長いコンテキストでVRAMを90%少なくする

UnslothでGRPOを行う場合、複数のトリックを使用して、「Flash Attention 2」の標準実装と比較して、VRAMの使用率を90%以上削減します。たとえば、プロンプトごとに8生成の20Kコンテキストの長さでは、Unslothは「Llama 3.1 8B」で54.3GBのVRAMしか使用していませんが、標準実装では510.8GB(Unslothでは90%少ない)を使用します。

(1) 「GRPO」用の新しいメモリ効率リニアアルゴリズムは、メモリ使用量を8倍以上削減します。これは68.5GBのメモリを削減しますが、num_generations = 8および20Kコンテキストの長さのtorch.compileの助けを借りて実際に高速になります。

(2) 少し前にリリースした「Unsloth gradient checkpointing」を活用しています。このアルゴリズムは、中間アクティベーションをシステムRAMに非同期でオフロードし、わずか1%の低速化を実現すします。これにより、num_generations = 8が必要なので、なんと372GBのVRAMを節約することができます。中間勾配の累積によって、このメモリ使用量をさらに削減することができます。

(3) Unslothはまた、他のパッケージの実装とは異なり、基礎となる推論エンジン(vLLM)と同じGPU / CUDAメモリスペースを使用します。これにより、16GBのVRAMが削減されます。

典型的な標準のGRPO実装では、GRPO損失を計算するために、サイズ (8, 20K)のロジットを2つ作成する必要があります。これには、2 * 2 バイト * 8 (生成数) * 20K (コンテキスト長) * 128256 (語彙サイズ) = VRAM で 78.3GB が必要です。

Unslothは、長いコンテキストGRPOのメモリ使用量を8倍に削減するため、20Kのコンテキストの長さで追加のVRAMで9.8GBだけ必要です。

また、16bitのKVキャッシュからも必要です。「Llama 3.1 8B」は32層あり、KとVの両方のサイズは1024です。したがって、20Kコンテキスト長のメモリ使用量 = 2 * 2バイト * 32レイヤー * 20Kコンテキスト長 * 1024 = バッチあたり2.5GBになります。vLLMのバッチサイズを8に設定しますが、VRAMを節約するために計算するために1のままにします。それ以外の場合は、KVキャッシュに20GBが必要です。

3. Unsloth Efficient GRPOアルゴリズム

Unslothでは、Horace He’s linear cross entropyの実装からヒントを得て、GRPOで機能させることに成功しました。そして、いくつかの驚くべき発見をしました。

・参照GRPO実装は、前方KL発散ではなく、逆KL発散を使用します。
・自動混合精度スケーリングメカニズムを備えたfloat16混合精度 (およびfloat8) で線形クロスエントロピーをナイニーに実装すると、適切に処理しないと壊れます。
・GRPO損失の実施に関して、主に逆KL発散の定式化に関して、他の癖を発見しました。

4. GRPOの数学 と 見つかった問題

「GRPO」は、2024年2月から2024年4月にかけて「DeepSeek」が発表した数学論文で初めて紹介されました。「DeepSeek」はその後、論文で言及したように、「DeepSeek R1」を作成する際に「GRPO」を活用しました。

「Hugging Face」のTRLのGRPO実装を活用します。TRLの実装が実行していることがわかります。

ここで、逆KL発散 (前方KL発散ではない) を利用します。ベータは0.04に設定されたスケーリング係数で、Aはすべての報酬関数を考慮して得られた利点です。Qは新しい学習モデルで、Pは元の参照モデルです。

興味深いことに、実装は逆KL発散を次のように計算します。

しかし、これは実際に正しいでしょうか。まず、それを導出し、次のような用語を収集しようとします。

つまり、実装にはQ (新しい分布項) の乗算が欠落している可能性があるということでしょうか。しかし、これは、最初に「GRPO」を紹介したDeepSeek Mathの論文に見られるように正しいようです (page 14 および Schulman's blog)。また、逆KL用語の偏りのない推定器は、実際には追加のQ用語を必要としないと述べています。ブログでは次のように述べています。

また、興味深いことに、次のことがわかりました。

torch.exp(q - q.detach()) * advantages.unsqueeze(1)

使用されていますが、どれが1に評価されるべきでしょうか。
実際にこれが必要であることがわかりました。 オートグラッドエンジンが勾配を正しく伝播していないようです。

そのため、4つの実験を行います。

・参照実装を介して通常のGRPOを行う (赤い線)
・分離コードの削除 (青い線)
・前述したように、追加用語付きのフルリバースKL (黄色の線)
・代わりにKL発散を前進させる (緑の線)

一般的に、切り離すことは間違いなくすべての学習を中断するので、そのままにしておく必要があります。これには、おそらくさらに調査が必要です。他のすべての実装は似ているように見えますか?異なる効果を見るために、モデルを長く実行する必要があるかもしれません。

すべての実装で、logsumexpトリックも利用します。

5. GRPOのフルロギング

すべての報酬機能の完全なログの詳細も提供しています。以前は、合計された報酬関数自体のみを示しました。

また、「GRPO」にパッチを当てるために関数を呼び出す必要もなくなりました。つまり、上部からこれを削除します (自動的に行います)。

from unsloth import PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

6. vLLM推論オプション

「vLLM」に FP8 KVキャッシュを使用できるようになりました。これにより、新しい GPU (RTX 3090、A100 以降) で KVキャッシュスペースの使用量が2倍減少します。

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
    float8_kv_cache = True, # Enable float8 KV cache
)

「vLLM」でmin_p = 0.1、または他のサンプリングパラメータを使用する場合は、「vLLM」のSamplingParams引数で何かを渡すこともできます。

max_prompt_length = 256
from trl import GRPOConfig, GRPOTrainer
from unsloth import vLLMSamplingParams
vllm_sampling_params = vLLMSamplingParams(
    min_p = 0.1,
    seed = 3407,
    ...
)
training_args = GRPOConfig(
    ...
    vllm_sampling_params = vllm_sampling_params,
    temperature = 1.5,
)

7. その他のアップデート

7-1. vLLM で Unsloth Dynamic 4bit を直接実行

「vLLM」で直接動的量子化を実行して推論することができます。こちらでUnslothの動的量子化で標準の4bitよりも精度を大幅に向上させる方法を紹介しています。

7-2. PerplexityのR1-1776の実行

R1-1776 Dynamic GGUF」は、推論能力を維持しながらすべての検閲を削除する「DeepSeek-R1」のファインチューニングモデルです。自分のデバイスでローカルに実行できます。

7-3. GitHubユニバースインタビュー

2024年10月のGitHubユニバースでのインタビューの動画が公開されました。



いいなと思ったら応援しよう!