
Trying Unsloth's GRPO Training on WSL2

Since Unsloth says you can "reproduce DeepSeek-R1's reasoning on your own local device" and "experience the aha moment with just 7GB of VRAM", I'm going to try Unsloth's GRPO (Group Relative Policy Optimization) training.

This time I'll try it with Phi-4 (14B).

The PC I'm using is Dospara's "GALLERIA UL9C-R49", with the following specs:
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU (16GB)
・GPU: NVIDIA® GeForce RTX™ 4090 (24GB) ※ external
・OS: Ubuntu 22.04 on WSL2 (Windows 11)


1. Setup

First, create a venv.

python3 -m venv unsloth-grpo
cd $_
source bin/activate

2. Trying it out

The published Colab notebook has all the commands, so let's work through it while cross-checking.

(1) Installation

Pairing with vLLM supposedly gives faster inference and a big reduction in memory usage.

pip install unsloth vllm
pip install --upgrade pillow

pip install diffusers
pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
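
Before going further, it is worth confirming that CUDA is actually visible from inside WSL2. This quick check is my own addition, not part of the Colab:

import torch

# Sanity check (my own addition): make sure the GPU is visible from this
# WSL2 venv before trying to load a 14B model.
print(torch.cuda.is_available())      # expect True
print(torch.cuda.get_device_name(0))  # e.g. "NVIDIA GeForce RTX 4090"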

(2) Configuring Unsloth

PatchFastRL has to be called beforehand to patch GRPO support into Unsloth.
Then we set the parameters for loading Phi-4 (14B).

from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

from unsloth import is_bfloat16_supported
import torch
max_seq_length = 512 # Can increase for longer reasoning traces
lora_rank = 16 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Phi-4",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["gate_proj", "up_proj", "down_proj",],
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

With the model-load parameter gpu_memory_utilization set to 0.7, an OOM error occurred as shown below, so I lowered it to 0.6.

**Step 3: Determine the score Ahmed needs on the final to exceed Emily's average
[rank0]: Traceback (most recent call last):
[rank0]:   File "<stdin>", line 1, in <module>
[rank0]:   File "/mnt/data/shoji_noguchi/venv/unsloth-grpo/lib/python3.10/site-packages/transformers/trainer.py", line 2171, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "<string>", line 382, in _fast_inner_training_loop
[rank0]:   File "<string>", line 68, in _unsloth_training_step
[rank0]:   File "/mnt/data/shoji_noguchi/venv/unsloth-grpo/lib/python3.10/site-packages/accelerate/accelerator.py", line 2246, in backward
[rank0]:     loss.backward(**kwargs)
[rank0]:   File "/mnt/data/shoji_noguchi/venv/unsloth-grpo/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
[rank0]:   File "/mnt/data/shoji_noguchi/venv/unsloth-grpo/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]:   File "/mnt/data/shoji_noguchi/venv/unsloth-grpo/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank0]: RuntimeError: CUDA driver error: unknown error
>>>

At this point, VRAM usage is 14,738 MiB.

+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        On  |   00000000:0C:00.0 Off |                  Off |
| 30%   35C    P8             35W /  450W |   14738MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
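
Incidentally, free VRAM can also be checked from Python, which is handy when picking a value for gpu_memory_utilization. This snippet is my own, not from the notebook:

import torch

# My own helper (not from the notebook): report free/total VRAM on the
# current CUDA device before deciding on gpu_memory_utilization.
free, total = torch.cuda.mem_get_info()  # both values are in bytes
print(f"free: {free / 1024**2:.0f} MiB / total: {total / 1024**2:.0f} MiB")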

(3) Data preparation

The data preparation and all of the reward functions are taken from @willccbb's work. Apparently you can also plug in your own functions.

import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

Let's take a moment to read through what these reward functions actually do.

This code implements reward functions for a math question-answering setup. They are used to steer the model's output toward the desired format and content.

  1. Dataset preparation and format

    • Loads the GSM8K (Grade School Math 8K) dataset with load_dataset

    • The system prompt asks for answers in an XML-like format using <reasoning> and <answer>

  2. Answer extraction functions

    • extract_xml_answer : extracts the answer part from the XML format

    • extract_hash_answer : extracts the answer from the ####-delimited format

  3. The reward functions

    1. correctness_reward_func :

      • Checks whether the predicted answer matches the reference answer

      • Returns 2.0 on a match, 0.0 otherwise

    2. int_reward_func :

      • Checks whether the extracted answer is an integer

      • Returns 0.5 if it is, 0.0 otherwise

    3. strict_format_reward_func :

      • Checks adherence to the strict XML format

      • Validates the format with a regular expression

    4. soft_format_reward_func :

      • A more lenient XML format check

      • Does not require the exact newlines

    5. xmlcount_reward_func :

      • Counts the XML tags

      • Adds 0.125 for each tag (<reasoning>, </reasoning>, and so on) that appears exactly once

      • Subtracts a small penalty for extra text trailing after the tags

So the rewards evaluate not only whether the answer is correct but also whether it is presented in the expected format; the model is trained not just to solve the problem but to present its answer in an easy-to-follow structure. A quick sanity check of these functions is shown right below.
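
As that sanity check (my own snippet, not part of the notebook), here is how the functions above score a hand-written, well-formatted completion:

# My own sanity check, not part of the notebook: score a hand-written
# completion with the reward functions defined above.
sample_prompts = [[{"role": "user", "content": "What is 9 * 7?"}]]
sample_completions = [[{"content": XML_COT_FORMAT.format(
    reasoning="9 * 7 = 63.",
    answer="63",
)}]]

print(xmlcount_reward_func(sample_completions))       # [0.5]  all four tags appear exactly once
print(strict_format_reward_func(sample_completions))  # [0.5]  matches the strict pattern
print(soft_format_reward_func(sample_completions))    # [0.0]  without re.DOTALL, '.' does not cross the newlines inside the tags
print(int_reward_func(sample_completions))            # [0.5]  "63" is all digits
print(correctness_reward_func(sample_prompts, sample_completions, ["63"]))  # [2.0] (also prints its debug block)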

(4) Training the model

These are the settings for the GRPO trainer.

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

Then we run the GRPOTrainer. You can see the five reward functions defined earlier passed in the reward_funcs list.

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

VRAM usage during training is 17,568 MiB. It looks like gpu_memory_utilization = 0.7 would have been enough after all...

+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        On  |   00000000:0C:00.0 Off |                  Off |
| 63%   72C    P2            384W /  450W |   17568MiB /  24564MiB |     98%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

Training completed in just under 30 minutes.

-------------------- Question:
Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been 9 assignments and Ahmed has a 91 in the class. Emily has a 92. The final assignment is worth the same amount as all the other assignments. Emily got a 90 on the final assignment. What is the minimum grade Ahmed needs to get to beat Emily if all grades are whole numbers?
Answer:
100
Response:
To solve this problem, we need to determine the minimum grade Ahmed needs on the final assignment to achieve a higher average than Emily.

**Step 1: Calculate Emily's total score.**

Emily's current average is 92 over 9 assignments. Therefore, her total score for these assignments is:

\[
92 \times 9 = 828
\]

Emily scored 90 on the final assignment, so her total score including the final assignment is:

\[
828 + 90 = 918
\]

Emily's new average, including the final assignment, is:

\[
\frac{918}{10} = 91.8
\]

**Step 2: Calculate Ahmed's current total score.**

Ahmed's current average is 91 over 9 assignments. Therefore, his total score for these assignments is:

\[
91 \times 9 = 819
\]

**Step 3: Determine the score Ahmed needs on the final to exceed Emily's average
Extracted:
To solve this problem, we need to determine the minimum grade Ahmed needs on the final assignment to achieve a higher average than Emily.

**Step 1: Calculate Emily's total score.**

Emily's current average is 92 over 9 assignments. Therefore, her total score for these assignments is:

\[
92 \times 9 = 828
\]

Emily scored 90 on the final assignment, so her total score including the final assignment is:

\[
828 + 90 = 918
\]

Emily's new average, including the final assignment, is:

\[
\frac{918}{10} = 91.8
\]

**Step 2: Calculate Ahmed's current total score.**

Ahmed's current average is 91 over 9 assignments. Therefore, his total score for these assignments is:

\[
91 \times 9 = 819
\]

**Step 3: Determine the score Ahmed needs on the final to exceed Emily's average
{'loss': 0.0, 'grad_norm': 0.15737351775169373, 'learning_rate': 5.000000000000001e-07, 'rewards/xmlcount_reward_func': 0.0625, 'rewards/soft_format_reward_func': 0.0, 'rewards/strict_format_reward_func': 0.0, 'rewards/int_reward_func': 0.0, 'rewards/correctness_reward_func': 0.0, 'reward': 0.0625, 'reward_std': 0.06846532225608826, 'completion_length': 200.0, 'kl': 0.0, 'epoch': 0.0}
  1%|▉                                                                                                | 1/100 [00:18<31:19, 18.99s/it]

(snip)

-------------------- Question:
Felix is chopping down trees in his backyard. For every 13 trees he chops down he needs to get his axe resharpened. It cost him $5 to sharpen his axe. If he spends $35 on axe sharpening, at least how many trees has he chopped down?
Answer:
91
Response:
<reasoning>
To determine the minimum number of trees Felix has chopped down, we need to calculate how many times he has had his axe resharpened based on the total amount he has spent on sharpening.

1. Each axe sharpening costs $5.
2. Felix spends $35 in total on sharpening.
3. We calculate the number of times the axe is sharpened by dividing the total amount spent by the cost per sharpening:
   \[
   \text{Number of sharpenings} = \frac{$35}{$5} = 7
   \]

4. Felix needs to get his axe resharpened after chopping down every 13 trees. So, each sharpening indicates that he has chopped down 13 trees.

5. To find the minimum number of trees chopped down, multiply the number of sharpenings by the number of trees per sharpening:
   \[
   \text{Minimum number of trees} = 7 \times
Extracted:
<reasoning>
To determine the minimum number of trees Felix has chopped down, we need to calculate how many times he has had his axe resharpened based on the total amount he has spent on sharpening.

1. Each axe sharpening costs $5.
2. Felix spends $35 in total on sharpening.
3. We calculate the number of times the axe is sharpened by dividing the total amount spent by the cost per sharpening:
   \[
   \text{Number of sharpenings} = \frac{$35}{$5} = 7
   \]

4. Felix needs to get his axe resharpened after chopping down every 13 trees. So, each sharpening indicates that he has chopped down 13 trees.

5. To find the minimum number of trees chopped down, multiply the number of sharpenings by the number of trees per sharpening:
   \[
   \text{Minimum number of trees} = 7 \times
{'loss': 0.0, 'grad_norm': 0.07222171127796173, 'learning_rate': 0.0, 'rewards/xmlcount_reward_func': 0.0014999981503933668, 'rewards/soft_format_reward_func': 0.0, 'rewards/strict_format_reward_func': 0.0, 'rewards/int_reward_func': 0.0833333358168602, 'rewards/correctness_reward_func': 0.3333333432674408, 'reward': 0.4181666970252991, 'reward_std': 0.9484413862228394, 'completion_length': 198.83334350585938, 'kl': 0.00023041400709189475, 'epoch': 0.01}
{'train_runtime': 1782.9937, 'train_samples_per_second': 0.056, 'train_steps_per_second': 0.056, 'train_loss': 2.3081666255784228e-05, 'epoch': 0.01}
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [29:42<00:00, 17.83s/it]
TrainOutput(global_step=100, training_loss=2.3081666255784228e-05, metrics={'train_runtime': 1782.9937, 'train_samples_per_second': 0.056, 'train_steps_per_second': 0.056, 'total_flos': 0.0, 'train_loss': 2.3081666255784228e-05})

As soon as training starts, the file grpo_trainer_lora_model/adapter_config.json is generated:

{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "unsloth/phi-4-bnb-4bit",
  "bias": "none",
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": false,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 16,
  "lora_bias": false,
  "lora_dropout": 0,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "up_proj",
    "gate_proj",
    "down_proj"
  ],
  "task_type": "CAUSAL_LM",
  "use_dora": false,
  "use_rslora": false

Here are the files from the checkpoint at iteration 100.

$ ls -l outputs/checkpoint-100/
total 270444
-rw-r--r-- 1 shoji_noguchi shoji_noguchi      5096 Feb  8 11:34 README.md
-rw-r--r-- 1 shoji_noguchi shoji_noguchi       737 Feb  8 11:34 adapter_config.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 176979000 Feb  8 11:34 adapter_model.safetensors
-rw-r--r-- 1 shoji_noguchi shoji_noguchi    916646 Feb  8 11:34 merges.txt
-rw-r--r-- 1 shoji_noguchi shoji_noguchi  90141690 Feb  8 11:34 optimizer.pt
-rw-r--r-- 1 shoji_noguchi shoji_noguchi     14244 Feb  8 11:34 rng_state.pth
-rw-r--r-- 1 shoji_noguchi shoji_noguchi      1064 Feb  8 11:34 scheduler.pt
-rw-r--r-- 1 shoji_noguchi shoji_noguchi       570 Feb  8 11:34 special_tokens_map.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi   7153430 Feb  8 11:34 tokenizer.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi     17990 Feb  8 11:34 tokenizer_config.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi     55763 Feb  8 11:34 trainer_state.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi      5688 Feb  8 11:34 training_args.bin
-rw-r--r-- 1 shoji_noguchi shoji_noguchi   1612637 Feb  8 11:34 vocab.json
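
If you ever want to resume from this checkpoint instead of starting over, the standard Hugging Face Trainer API should apply, though I have not tried it with the Unsloth-patched GRPOTrainer:

# Standard Trainer API (untested here with the patched GRPOTrainer):
trainer.train(resume_from_checkpoint = True)  # resume from the latest checkpoint in output_dir
# trainer.train(resume_from_checkpoint = "outputs/checkpoint-100")  # or name a specific one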

(5) Inference

First, inference without the GRPO training applied (no lora_request specified).

text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "Which is bigger? 9.11 or 9.9?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

Here is the result.

Processed prompts:   0%|                                    | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████████████████████| 1/1 [00:01<00:00,  1.61s/it, est. speed input: 13.01 toks/s, output: 50.80 toks/s]
>>>
>>> output
'9.11 is bigger than 9.9. When comparing numbers with the same number of whole parts, you compare the decimal parts. In this case, 0.11 is greater than 0.9 when considering their positions after the decimal point. To further illustrate, 9.9 is equivalent to 9.90, and 9.11 is greater than 9.90.'

9.11 is bigger than 9.9. When comparing numbers with the same number of whole parts, you compare the decimal parts. In this case, 0.11 is greater than 0.9 when considering their positions after the decimal point. To further illustrate, 9.9 is equivalent to 9.90, and 9.11 is greater than 9.90.

Phi-4 without LoRA

Next, inference after GRPO training. But before that, let's save the LoRA.

model.save_lora("grpo_saved_lora")

A grpo_saved_lora directory is created in the current directory, with three files saved in it.

$ ls -l grpo_saved_lora/
total 86444
-rw-r--r-- 1 shoji_noguchi shoji_noguchi     5096 Feb  8 11:37 README.md
-rw-r--r-- 1 shoji_noguchi shoji_noguchi      737 Feb  8 11:37 adapter_config.json
-rw-r--r-- 1 shoji_noguchi shoji_noguchi 88505400 Feb  8 11:37 adapter_model.safetensors
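
As an aside, Unsloth's other notebooks also show how to merge the LoRA into the base model and save a 16-bit version for serving. I did not run this here, so treat it as an untested note:

# From Unsloth's other notebooks (not run in this post): merge the LoRA into
# the base model and save it in 16-bit for serving with vLLM etc.
model.save_pretrained_merged("phi4-grpo-merged", tokenizer, save_method = "merged_16bit")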

Now let's run inference using this LoRA.

text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "Which is bigger? 9.11 or 9.9?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

There are two differences from the plain inference above:

  • The saved LoRA is passed via the lora_request parameter

  • SYSTEM_PROMPT is passed as the system role in apply_chat_template. This is the value defined back in "(3) Data preparation":

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

Now, the result of the inference is...

Processed prompts: 100%|██████████████████████████| 1/1 [00:03<00:00,  3.61s/it, est. speed input: 13.00 toks/s, output: 54.23 toks/s]
>>>
>>> output
'<reasoning>\nTo determine which number is larger between 9.11 and 9.9, we need to compare the numbers digit by digit, starting from the leftmost digit.\n\n1. Compare the whole number part of both numbers:\n   - The whole number part of 9.11 is 9.\n   - The whole number part of 9.9 is also 9.\n   Since the whole number parts are equal, we move to the decimal part.\n\n2. Compare the tenths place:\n   - The tenths digit of 9.11 is 1.\n   - The tenths digit of 9.9 is 9.\n   Since 1 is less than 9, we can already determine that 9.11 is smaller than 9.9 without needing to compare further digits.\n\nTherefore, 9.9 is larger than 9.11.\n</reasoning>\n\n<answer>\n9.9\n</answer>'

※ Line breaks inserted at each \n for readability.

<reasoning>
To determine which number is larger between 9.11 and 9.9, we need to compare the numbers digit by digit, starting from the leftmost digit.

1. Compare the whole number part of both numbers:
- The whole number part of 9.11 is 9.
- The whole number part of 9.9 is also 9.
Since the whole number parts are equal, we move to the decimal part.

2. Compare the tenths place:
- The tenths digit of 9.11 is 1.
- The tenths digit of 9.9 is 9.
Since 1 is less than 9, we can already determine that 9.11 is smaller than 9.9 without needing to compare further digits.

Therefore, 9.9 is larger than 9.11.
</reasoning>

<answer>
9.9
</answer>

Phi-4 with LoRA

Aha!
