
Trying Unsloth's GRPO Training on WSL2
Unsloth says you can "reproduce DeepSeek-R1's reasoning on your own local device" and "experience the aha moment with just 7GB VRAM", so let's try Unsloth's GRPO (Group Relative Policy Optimization) training.
This time I'll try it with Phi-4 (14B).
You can now reproduce DeepSeek-R1's reasoning on your own local device!
Experience the "Aha" moment with just 7GB VRAM.
Unsloth reduces GRPO training memory use by 80%.
15GB VRAM can transform Llama-3.1 (8B) & Phi-4 (14B) into reasoning models.
Blog: https://t.co/pjvgXOeHZQ
— Unsloth AI (@UnslothAI) February 6, 2025
The PC is Dospara's "GALLERIA UL9C-R49". Specs:
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU (16GB)
・GPU: NVIDIA® GeForce RTX™ 4090 (24GB), external
・OS: Ubuntu 22.04 on WSL2 (Windows 11)
1. Setup
Create a venv and activate it.
python3 -m venv unsloth-grpo
cd $_
source bin/activate
2. Trying it out
The published Colab notebook has all the commands, so let's work through it while checking each step.
(1) Installation
Pairing with vLLM is said to speed things up and cut memory use substantially.
pip install unsloth vllm
pip install --upgrade pillow
pip install diffusers
pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
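Before going any further, it's worth confirming what actually landed in the venv. This is just a minimal sanity-check sketch using importlib.metadata (not part of the Unsloth instructions); versions will of course differ by environment.
import importlib.metadata as md
import torch
# Print the resolved versions of the main packages in this venv.
for pkg in ["unsloth", "vllm", "trl", "transformers"]:
    print(pkg, md.version(pkg))
# Confirm PyTorch sees the GPU inside WSL2.
print("torch", torch.__version__, "CUDA available:", torch.cuda.is_available())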
(2) Unsloth configuration
PatchFastRL has to be called up front so that GRPO gets patched into FastLanguageModel.
Then set the parameters for loading Phi-4 (14B).
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

from unsloth import is_bfloat16_supported
import torch

max_seq_length = 512 # Can increase for longer reasoning traces
lora_rank = 16 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Phi-4",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["gate_proj", "up_proj", "down_proj",],
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)
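To double-check how much of the model is actually trainable with this LoRA setup, the object returned by get_peft_model should expose peft's standard print_trainable_parameters() (a hedged sketch; the notebook itself doesn't call this):
# Show trainable vs. total parameter counts for the LoRA-wrapped model.
# The numbers depend on the lora_rank and target_modules chosen above.
model.print_trainable_parameters()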
With the model-load parameter gpu_memory_utilization set to 0.7, an out-of-memory error occurred as shown below, so I changed it to 0.6.
**Step 3: Determine the score Ahmed needs on the final to exceed Emily's average
[rank0]: Traceback (most recent call last):
[rank0]: File "<stdin>", line 1, in <module>
[rank0]: File "/mnt/data/shoji_noguchi/venv/unsloth-grpo/lib/python3.10/site-packages/transformers/trainer.py", line 2171, in train
[rank0]: return inner_training_loop(
[rank0]: File "<string>", line 382, in _fast_inner_training_loop
[rank0]: File "<string>", line 68, in _unsloth_training_step
[rank0]: File "/mnt/data/shoji_noguchi/venv/unsloth-grpo/lib/python3.10/site-packages/accelerate/accelerator.py", line 2246, in backward
[rank0]: loss.backward(**kwargs)
[rank0]: File "/mnt/data/shoji_noguchi/venv/unsloth-grpo/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
[rank0]: File "/mnt/data/shoji_noguchi/venv/unsloth-grpo/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]: File "/mnt/data/shoji_noguchi/venv/unsloth-grpo/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank0]: RuntimeError: CUDA driver error: unknown error
>>>
VRAM usage at this point was 14,738 MiB.
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 4090 On | 00000000:0C:00.0 Off | Off |
| 30% 35C P8 35W / 450W | 14738MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
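gpu_memory_utilization tells vLLM what fraction of the GPU's memory it may claim, so when tuning it, checking free memory from inside Python is handy too. A minimal sketch with plain PyTorch (nothing Unsloth-specific):
import torch
# Free / total VRAM on the current CUDA device, in MiB.
free_bytes, total_bytes = torch.cuda.mem_get_info()
print(f"free: {free_bytes / 1024**2:.0f} MiB / total: {total_bytes / 1024**2:.0f} MiB")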
(3) Data preparation
The data prep and all the reward functions come from @willccbb's work; apparently you can also supply your own functions.
import re
from datasets import load_dataset, Dataset
# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
Let's briefly unpack what these reward functions actually do.
This code implements reward functions for a math problem-solving setup; they are used to steer the model's output toward the desired format and content.
Dataset preparation and format
load_dataset loads the GSM8K (Grade School Math 8K) dataset
The system prompt asks for answers in an XML-like format with <reasoning> and <answer> tags
Answer extraction functions
extract_xml_answer : extracts the answer part from the XML format
extract_hash_answer : extracts the answer from the ####-delimited GSM8K format
The reward functions
correctness_reward_func :
checks whether the predicted answer matches the ground truth
returns 2.0 on a match, 0.0 otherwise
int_reward_func :
checks whether the answer is an integer
returns 0.5 if it is, 0.0 otherwise
strict_format_reward_func :
checks strict adherence to the XML format
validates the format with a regular expression
soft_format_reward_func :
a looser XML format check
does not require exact newline placement
xmlcount_reward_func :
counts the expected XML tags
adds 0.125 for each expected element (<reasoning>, </reasoning>, and so on)
deducts points for extra trailing text
So the rewards score not just right-or-wrong but also whether the answer is presented in the expected format; the model is being trained to lay out its reasoning clearly, not just to solve the problems. A quick hand-run sanity check of these functions is sketched below.
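As that sanity check, here is a small hand-written example (not from the notebook) that pushes one fabricated completion through the reward functions defined above; the completion text, prompt, and answer are made up purely for illustration.
# Hypothetical, hand-written completion purely for illustration (not real model output).
sample_completions = [[{"content": "<reasoning>\nThe answer is 42.\n</reasoning>\n<answer>\n42\n</answer>\n"}]]
sample_prompts = [[{"role": "user", "content": "What is 6 x 7?"}]]
sample_answers = ["42"]
# Each reward function returns one score per completion in the batch.
print(xmlcount_reward_func(sample_completions))        # tag-structure score
print(soft_format_reward_func(sample_completions))     # loose format check
print(strict_format_reward_func(sample_completions))   # strict format check
print(int_reward_func(sample_completions))             # is the extracted answer an integer?
print(correctness_reward_func(sample_prompts, sample_completions, sample_answers))  # matches the reference answer?
A well-formed completion like this one picks up points from several functions at once, and it is the sum of these scores that GRPO then compares across the group of generations.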
(4) Model training
Configuration for the GRPO trainer.
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)
Then run GRPOTrainer. Note that the five reward functions defined earlier are passed in the reward_funcs list.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()
VRAM usage during training was 17,568 MiB, so it looks like gpu_memory_utilization = 0.7 should actually have fit...
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 4090 On | 00000000:0C:00.0 Off | Off |
| 63% 72C P2 384W / 450W | 17568MiB / 24564MiB | 98% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
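Out of curiosity, the peak memory tracked by PyTorch's allocator can also be read from inside the session. This is only a rough sketch and won't necessarily match nvidia-smi, since it only covers allocations PyTorch knows about.
import torch
# Peak memory on the current device as seen by PyTorch's allocator, in MiB.
print(f"max reserved:  {torch.cuda.max_memory_reserved() / 1024**2:.0f} MiB")
print(f"max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MiB")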
Training finished in just under 30 minutes.
-------------------- Question:
Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been 9 assignments and Ahmed has a 91 in the class. Emily has a 92. The final assignment is worth the same amount as all the other assignments. Emily got a 90 on the final assignment. What is the minimum grade Ahmed needs to get to beat Emily if all grades are whole numbers?
Answer:
100
Response:
To solve this problem, we need to determine the minimum grade Ahmed needs on the final assignment to achieve a higher average than Emily.
**Step 1: Calculate Emily's total score.**
Emily's current average is 92 over 9 assignments. Therefore, her total score for these assignments is:
\[
92 \times 9 = 828
\]
Emily scored 90 on the final assignment, so her total score including the final assignment is:
\[
828 + 90 = 918
\]
Emily's new average, including the final assignment, is:
\[
\frac{918}{10} = 91.8
\]
**Step 2: Calculate Ahmed's current total score.**
Ahmed's current average is 91 over 9 assignments. Therefore, his total score for these assignments is:
\[
91 \times 9 = 819
\]
**Step 3: Determine the score Ahmed needs on the final to exceed Emily's average
Extracted:
To solve this problem, we need to determine the minimum grade Ahmed needs on the final assignment to achieve a higher average than Emily.
**Step 1: Calculate Emily's total score.**
Emily's current average is 92 over 9 assignments. Therefore, her total score for these assignments is:
\[
92 \times 9 = 828
\]
Emily scored 90 on the final assignment, so her total score including the final assignment is:
\[
828 + 90 = 918
\]
Emily's new average, including the final assignment, is:
\[
\frac{918}{10} = 91.8
\]
**Step 2: Calculate Ahmed's current total score.**
Ahmed's current average is 91 over 9 assignments. Therefore, his total score for these assignments is:
\[
91 \times 9 = 819
\]
**Step 3: Determine the score Ahmed needs on the final to exceed Emily's average
{'loss': 0.0, 'grad_norm': 0.15737351775169373, 'learning_rate': 5.000000000000001e-07, 'rewards/xmlcount_reward_func': 0.0625, 'rewards/soft_format_reward_func': 0.0, 'rewards/strict_format_reward_func': 0.0, 'rewards/int_reward_func': 0.0, 'rewards/correctness_reward_func': 0.0, 'reward': 0.0625, 'reward_std': 0.06846532225608826, 'completion_length': 200.0, 'kl': 0.0, 'epoch': 0.0}
1%|▉ | 1/100 [00:18<31:19, 18.99s/it]
(snip)
-------------------- Question:
Felix is chopping down trees in his backyard. For every 13 trees he chops down he needs to get his axe resharpened. It cost him $5 to sharpen his axe. If he spends $35 on axe sharpening, at least how many trees has he chopped down?
Answer:
91
Response:
<reasoning>
To determine the minimum number of trees Felix has chopped down, we need to calculate how many times he has had his axe resharpened based on the total amount he has spent on sharpening.
1. Each axe sharpening costs $5.
2. Felix spends $35 in total on sharpening.
3. We calculate the number of times the axe is sharpened by dividing the total amount spent by the cost per sharpening:
\[
\text{Number of sharpenings} = \frac{$35}{$5} = 7
\]
4. Felix needs to get his axe resharpened after chopping down every 13 trees. So, each sharpening indicates that he has chopped down 13 trees.
5. To find the minimum number of trees chopped down, multiply the number of sharpenings by the number of trees per sharpening:
\[
\text{Minimum number of trees} = 7 \times
Extracted:
<reasoning>
To determine the minimum number of trees Felix has chopped down, we need to calculate how many times he has had his axe resharpened based on the total amount he has spent on sharpening.
1. Each axe sharpening costs $5.
2. Felix spends $35 in total on sharpening.
3. We calculate the number of times the axe is sharpened by dividing the total amount spent by the cost per sharpening:
\[
\text{Number of sharpenings} = \frac{$35}{$5} = 7
\]
4. Felix needs to get his axe resharpened after chopping down every 13 trees. So, each sharpening indicates that he has chopped down 13 trees.
5. To find the minimum number of trees chopped down, multiply the number of sharpenings by the number of trees per sharpening:
\[
\text{Minimum number of trees} = 7 \times
{'loss': 0.0, 'grad_norm': 0.07222171127796173, 'learning_rate': 0.0, 'rewards/xmlcount_reward_func': 0.0014999981503933668, 'rewards/soft_format_reward_func': 0.0, 'rewards/strict_format_reward_func': 0.0, 'rewards/int_reward_func': 0.0833333358168602, 'rewards/correctness_reward_func': 0.3333333432674408, 'reward': 0.4181666970252991, 'reward_std': 0.9484413862228394, 'completion_length': 198.83334350585938, 'kl': 0.00023041400709189475, 'epoch': 0.01}
{'train_runtime': 1782.9937, 'train_samples_per_second': 0.056, 'train_steps_per_second': 0.056, 'train_loss': 2.3081666255784228e-05, 'epoch': 0.01}
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [29:42<00:00, 17.83s/it]
TrainOutput(global_step=100, training_loss=2.3081666255784228e-05, metrics={'train_runtime': 1782.9937, 'train_samples_per_second': 0.056, 'train_steps_per_second': 0.056, 'total_flos': 0.0, 'train_loss': 2.3081666255784228e-05})
As soon as training started, a grpo_trainer_lora_model/adapter_config.json file was generated:
{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "unsloth/phi-4-bnb-4bit",
  "bias": "none",
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": false,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 16,
  "lora_bias": false,
  "lora_dropout": 0,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "up_proj",
    "gate_proj",
    "down_proj"
  ],
  "task_type": "CAUSAL_LM",
  "use_dora": false,
  "use_rslora": false
}
Here are the files in the checkpoint at step 100.
$ ls -l outputs/checkpoint-100/
total 270444
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 5096 Feb 8 11:34 README.md
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 737 Feb 8 11:34 adapter_config.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 176979000 Feb 8 11:34 adapter_model.safetensors
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 916646 Feb 8 11:34 merges.txt
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 90141690 Feb 8 11:34 optimizer.pt
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 14244 Feb 8 11:34 rng_state.pth
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 1064 Feb 8 11:34 scheduler.pt
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 570 Feb 8 11:34 special_tokens_map.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 7153430 Feb 8 11:34 tokenizer.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 17990 Feb 8 11:34 tokenizer_config.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 55763 Feb 8 11:34 trainer_state.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 5688 Feb 8 11:34 training_args.bin
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 1612637 Feb 8 11:34 vocab.json
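trainer_state.json in the checkpoint keeps the per-step log history (the same reward breakdown that was printed during training), so the reward curve can be pulled back out of it. A small sketch, assuming the standard Hugging Face Trainer state layout:
import json
# Read the per-step logs saved with the checkpoint.
with open("outputs/checkpoint-100/trainer_state.json") as f:
    state = json.load(f)
# Each logged step carries the total reward and the per-function breakdown.
for entry in state["log_history"]:
    if "reward" in entry:
        print(entry["step"], entry["reward"], entry.get("rewards/correctness_reward_func"))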
(5) Inference
First, inference without the GRPO training applied (no lora_request specified).
text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "Which is bigger? 9.11 or 9.9?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text
output
Here is the result.
Processed prompts: 0%| | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████████████████████| 1/1 [00:01<00:00, 1.61s/it, est. speed input: 13.01 toks/s, output: 50.80 toks/s]
>>>
>>> output
'9.11 is bigger than 9.9. When comparing numbers with the same number of whole parts, you compare the decimal parts. In this case, 0.11 is greater than 0.9 when considering their positions after the decimal point. To further illustrate, 9.9 is equivalent to 9.90, and 9.11 is greater than 9.90.'
9.11 is bigger than 9.9. When comparing numbers with the same number of whole parts, you compare the decimal parts. In this case, 0.11 is greater than 0.9 when considering their positions after the decimal point. To further illustrate, 9.9 is equivalent to 9.90, and 9.11 is greater than 9.90.
Next, inference after GRPO training. Before that, save the LoRA adapter.
model.save_lora("grpo_saved_lora")
A grpo_saved_lora directory is created in the current directory, containing three files.
$ ls -l grpo_saved_lora/
total 86444
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 5096 Feb 8 11:37 README.md
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 737 Feb 8 11:37 adapter_config.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 88505400 Feb 8 11:37 adapter_model.safetensors
Now let's run inference using this LoRA.
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "Which is bigger? 9.11 or 9.9?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
output
There are two differences from the plain inference above:
The saved LoRA is now passed via the lora_request parameter
apply_chat_template now includes SYSTEM_PROMPT as the system-role message; this is the value defined back in "(3) Data preparation":
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
So, what does the inference produce...
Processed prompts: 100%|██████████████████████████| 1/1 [00:03<00:00, 3.61s/it, est. speed input: 13.00 toks/s, output: 54.23 toks/s]
>>>
>>> output
'<reasoning>\nTo determine which number is larger between 9.11 and 9.9, we need to compare the numbers digit by digit, starting from the leftmost digit.\n\n1. Compare the whole number part of both numbers:\n - The whole number part of 9.11 is 9.\n - The whole number part of 9.9 is also 9.\n Since the whole number parts are equal, we move to the decimal part.\n\n2. Compare the tenths place:\n - The tenths digit of 9.11 is 1.\n - The tenths digit of 9.9 is 9.\n Since 1 is less than 9, we can already determine that 9.11 is smaller than 9.9 without needing to compare further digits.\n\nTherefore, 9.9 is larger than 9.11.\n</reasoning>\n\n<answer>\n9.9\n</answer>'
(For readability, line breaks are inserted at each \n below.)
<reasoning>
To determine which number is larger between 9.11 and 9.9, we need to compare the numbers digit by digit, starting from the leftmost digit.
1. Compare the whole number part of both numbers:
- The whole number part of 9.11 is 9.
- The whole number part of 9.9 is also 9.
Since the whole number parts are equal, we move to the decimal part.
2. Compare the tenths place:
- The tenths digit of 9.11 is 1.
- The tenths digit of 9.9 is 9.
Since 1 is less than 9, we can already determine that 9.11 is smaller than 9.9 without needing to compare further digits.
Therefore, 9.9 is larger than 9.11.
</reasoning>
<answer>
9.9
</answer>
Aha!