
UnslothでGRPO
Unslothによってメモリ使用量を抑え、実行時間も短縮してGRPOを実行してみます。
Unslothを使わずにGRPOを実行したバージョンは以下の記事。
この記事で使用するデータセットと報酬関数は上の記事と同じものを使用します。
準備
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)
from unsloth import is_bfloat16_supported
import torch
from datasets import concatenate_datasets, load_dataset
import dotenv
dotenv.load_dotenv()
import wandb
import os
os.environ["WANDB_PROJECT"] = "grpo"
max_seq_length = 512 # Can increase for longer reasoning traces
prompt_length = 384
lora_rank = 8 # Larger rank = smarter, but slower
model_name = "Rakuten/RakutenAI-2.0-mini-instruct"
ds_path = "./datasets"
今回のモデルはRakutenAI 2.0を使用します。
モデルロード
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
load_in_4bit = True, # False for LoRA 16bit
fast_inference = True, # Enable vLLM fast inference
max_lora_rank = lora_rank,
gpu_memory_utilization = 0.6, # Reduce if out of memory
)
model = FastLanguageModel.get_peft_model(
model,
r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha = 16,
use_gradient_checkpointing = "unsloth", # Enable long context finetuning
random_state = 3407,
)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
RakutenAI 2.0モデルをそのままUnslothでGRPOを実行しようとすると、以下のようなエラーが発生します。
RuntimeError: Prefix caching is not supported with sliding window. Run with --disable-sliding-window to use prefix caching.
このエラーを解消するには、モデルのconfig.jsonで”sliding_window”の値を以下のようにNullにする必要があります。
vim ~/.cache/huggingface/hub/models--Rakuten--RakutenAI-2.0-mini-instruct/snapshots/3f9c4d5fd99bb51aca44a41a5d3750c0a7a60554/config.json
・・・
"rope_theta": 100000.0,
"sliding_window": null,
"tie_word_embeddings": false,
・・・
プロンプト
システムプロンプトは以下のように設定しておきます。
system_prompt = '''
あなたは優秀なAIアシスタントです。
ユーザーの質問に答えるために、推論を行ってください。
まず推論プロセスを考え、その答えを推論プロセスとともにユーザーに提供してください。
推論のプロセスと答えは、それぞれ<think></think>と<answer></answer> タグで囲ってください。
以下のような形式で回答を提供してください。
<think> ここに推論プロセスを記載してください</think>
<answer> ここに回答を記載してください </answer>
ただし、<answer></answer>には単位も含まない数値のみを記載してください。
'''
報酬関数
import re
def start_think_tag(string):
tag = "<think>"
return tag in string
def end_think_tag(string):
tag = "</think>"
return tag in string
def start_answer_tag(string):
tag = "<answer>"
return tag in string
def end_answer_tag(string):
tag = "</answer>"
return tag in string
def thinking_text_length(string):
pattern = r'<think>(.*?)</think>'
match = re.search(pattern, string, re.DOTALL)
if match:
return len(match.group(1)) / max_seq_length
return 0
def is_think_format(string):
pattern = r'<think>(.*?)</think>'
match = re.findall(pattern, string, re.DOTALL)
if match:
if len(match) == 1:
return True
return False
return False
def is_answer_format(string):
pattern = r'<answer>(.*?)</answer>'
match = re.findall(pattern, string, re.DOTALL)
if match:
if len(match) == 1:
return True
return False
return False
def is_think_answer_format(string):
pattern = r"<think>(.*?)</think>.<answer>(.*?)</answer>$"
match = re.fullmatch(pattern, string, re.DOTALL)
if match:
try:
return bool(match.group(1)) and bool(match.group(2))
except:
return False
else:
return False
def get_format_reward(completions, **kwargs):
'''
Format reward
'''
rewards = []
for completion in completions:
print(completion)
completion = completion[0]['content']
r= 0
r += int(start_think_tag(completion))
r += int(end_think_tag(completion))
r += int(start_answer_tag(completion))
r += int(end_answer_tag(completion))
r += int(is_think_format(completion))
r += int(is_answer_format(completion))
r += int(is_think_answer_format(completion))
# r += thinking_text_length(completion)
rewards.append(r)
return rewards
def extract_answer_text(string):
pattern = r'<answer>(.*?)</answer>'
match = re.search(pattern, string, re.DOTALL)
if match:
return match.group(1)
return None
def get_correct_reward(completions, answers, **kwargs):
'''
Correct Reward
'''
rewards = []
for completion, ground_truth in zip(completions, answers):
rewards.append(int(extract_answer_text(completion[0]['content']) == ground_truth) * 10)
return rewards
データセット
task1_format = '''
{question}
以下の選択肢から回答を選んでください。
{options}
'''
def transfer_dataset(example):
prompt = example['prompt']
options = f"{example['A']}, {example['B']}, {example['C']}, {example['D']}"
task_text = task1_format.format(question=prompt, options=options)
correct = example[example['Correct']]
example['prompt'] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": task_text}
]
example['answers'] = correct
return example
csv_files = [
f'{ds_path}/elementary_mathematics.csv',
f'{ds_path}/high_school_mathematics.csv',
f'{ds_path}/college_mathematics.csv'
]
ds1 = load_dataset("csv", data_files=csv_files, split='train')
ds1 = ds1.map(transfer_dataset)
task2_format = '''
{question}
'''
def transfer_dataset_json(example):
prompt = example['question'].replace('解答:', '')
task_text = task2_format.format(question=prompt)
correct = example['answer']
example['prompt'] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": task_text}
]
example['answers'] = correct
return example
ds2 = load_dataset("json", data_files=f'{ds_path}/test.json', split='train')
ds2 = ds2.map(transfer_dataset_json)
task3_format = '''
{question}
'''
def transfer_dataset_json(example):
prompt = example['question'].replace('解答:', '')
task_text = task3_format.format(question=prompt)
correct = example['answer']
example['prompt'] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": task_text}
]
example['answers'] = correct
return example
ds3 = load_dataset("csv", data_files=f'{ds_path}/mgsm_ja.tsv', split='train', delimiter="\t")
ds3 = ds3.map(transfer_dataset_json)
ds = concatenate_datasets([ds1, ds2, ds3])
学習設定
学習の設定値はUnslothが提供しているサンプルからほとんどを持ってきています。
"temprature"は0.5として、今回は学習を実行しましたがもっと大きな値で実行した方が学習がよく進みそうな印象でした。
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
use_vllm = True, # use vLLM for fast inference!
learning_rate = 5e-6,
adam_beta1 = 0.9,
adam_beta2 = 0.99,
weight_decay = 0.1,
warmup_ratio = 0.1,
temperature=0.5,
lr_scheduler_type = "cosine",
optim = "paged_adamw_8bit",
logging_steps = 1,
bf16 = is_bfloat16_supported(),
fp16 = not is_bfloat16_supported(),
per_device_train_batch_size = 8,
gradient_accumulation_steps = 1, # Increase to 4 for smoother training
num_generations = 8, # Decrease if out of memory
max_prompt_length = prompt_length,
max_completion_length = max_seq_length - prompt_length,
num_train_epochs = 3, # Set to 1 for a full training run
max_grad_norm = 1.0,
report_to = "wandb", # Can use Weights & Biases
output_dir = "outputs",
)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[get_format_reward, get_correct_reward],
args=training_args,
train_dataset=ds,
)
学習実行・保存
実行や保存は以下の記載の通り。
trainer.train()
wandb.finish()
model.save_pretrained_merged("rakuten2_grpo_lora", tokenizer, save_method = "lora",)
まとめ
RakutenAI 2.0は1.5Bモデルでありながら、元から一定の推論性能を持っていそうでした。そのため、今回のGRPOによる学習では性能向上が確認できませんでした。
性能改善のためには高品質なReasoningデータでSFTを実行したあと、GRPOで推論性能の向上を目指すような手間のかかるやり方を取る必要がありそうですね。