![見出し画像](https://assets.st-note.com/production/uploads/images/172869289/rectangle_large_type_2_3e3c7ade2530d86f83e715469555abf7.jpeg?width=1200)
GRPOを試してみた
今回はGRPO(Group Relative Policy Optimization)を試してみました。GRPOとは、LLMのチューニング手法でDeepSeek R1が学習するための中心となる手法です。
この記事ではアルゴリズムの詳しい説明はしませんが、よさそうな解説記事を載せておきます。
実装
早速、実装について記述していきます。
準備
最初に必要ライブラリのインポートとパラメータの設定を行います。
import os
import torch
from trl.trainer import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from peft import LoraConfig
from datasets import concatenate_datasets, load_dataset
import dotenv
dotenv.load_dotenv()
import wandb
os.environ["WANDB_PROJECT"] = "grpo"
# 基本設定
max_seq_length = 256
prompt_length = 384
num_proc = 4
random_seed = 3407
model_name = "meta-llama/Llama-3.2-1B-Instruct"
save_path = 'GRPO_Model'
is_4bit = True
is_8bit = False if is_4bit else True
ds_path = "./datasets"
# Llammaプロンプト
prompt_template = '''
<|start_header_id|>system<|end_header_id|>
あなたは優秀なAIアシスタントです。
ユーザーの質問に答えるために、推論を行ってください。
まず推論プロセスを考え、その答えを推論プロセスとともにユーザーに提供してください。
推論のプロセスと答えは、それぞれ<think></think>と<answer></answer> タグで囲ってください。
以下のような形式で回答を提供してください。
<think> ここに推論プロセスを記載してください</think>
<answer> ここに回答を記載してください </answer>
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
{user_input}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
'''
TRLライブラリにGRPO Trainerが入っているので、それを利用します。
モデルは、Llama3.2 1Bを使用しています。
報酬関数
次に、報酬関数を定義します。
報酬関数はDeepSeekがやっていたように、「回答形式に対する報酬」と「回答の正解/不正解」の2種類を使います。
import re
def start_think_tag(string):
tag = "<think>"
return tag in string
def end_think_tag(string):
tag = "</think>"
return tag in string
def start_answer_tag(string):
tag = "<answer>"
return tag in string
def end_answer_tag(string):
tag = "</answer>"
return tag in string
def thinking_text_length(string):
pattern = r'<think>(.*?)</think>'
match = re.search(pattern, string, re.DOTALL)
if match:
return len(match.group(1)) / max_seq_length
return 0
def is_think_format(string):
pattern = r'<think>(.*?)</think>'
match = re.findall(pattern, string, re.DOTALL)
if match:
if len(match) == 1:
return True
return False
return False
def is_answer_format(string):
pattern = r'<answer>(.*?)</answer>'
match = re.findall(pattern, string, re.DOTALL)
if match:
if len(match) == 1:
return True
return False
return False
def is_think_answer_format(string):
pattern = r"<think>(.*?)</think>.<answer>(.*?)</answer>$"
match = re.fullmatch(pattern, string, re.DOTALL)
if match:
try:
return bool(match.group(1)) and bool(match.group(2))
except:
return False
else:
return False
def get_format_reward(completions, **kwargs):
'''
Format reward
'''
rewards = []
for completion in completions:
print(completion)
r= 0
r += int(start_think_tag(completion))
r += int(end_think_tag(completion))
r += int(start_answer_tag(completion))
r += int(end_answer_tag(completion))
r += int(is_think_format(completion))
r += int(is_answer_format(completion))
r += int(is_think_answer_format(completion))
# r += thinking_text_length(completion)
rewards.append(r)
return rewards
def extract_answer_text(string):
pattern = r'<answer>(.*?)</answer>'
match = re.search(pattern, string, re.DOTALL)
if match:
return match.group(1)
return None
def get_correct_reward(completions, ground_truths, **kwargs):
'''
Correct Reward
'''
rewards = []
for completion, ground_truth in zip(completions, ground_truths):
rewards.append(int(extract_answer_text(completion) == ground_truth) * 10)
return rewards
データセット
データは数学推論タスク用に集められたデータを利用します。
task1_format = '''
{question}
以下の選択肢から回答を選んでください。
{options}
'''
def transfer_dataset(example):
prompt = example['prompt']
options = f"{example['A']}, {example['B']}, {example['C']}, {example['D']}"
task_text = task1_format.format(question=prompt, options=options)
input_text = prompt_template.format(user_input=task_text)
correct = example[example['Correct']]
example['prompt'] = input_text
example['ground_truths'] = correct
return example
csv_files = [
f'{ds_path}/elementary_mathematics.csv',
f'{ds_path}/high_school_mathematics.csv',
f'{ds_path}/college_mathematics.csv'
]
ds1 = load_dataset("csv", data_files=csv_files, split='train')
ds1 = ds1.map(transfer_dataset)
ds1[0]
task2_format = '''
{question}
'''
def transfer_dataset_json(example):
prompt = example['question'].replace('解答:', '')
task_text = task2_format.format(question=prompt)
input_text = prompt_template.format(user_input=task_text)
correct = example['answer']
example['prompt'] = input_text
example['ground_truths'] = correct
return example
ds2 = load_dataset("json", data_files=f'{ds_path}/test.json', split='train')
ds2 = ds2.map(transfer_dataset_json)
ds2[0]
task3_format = '''
{question}
'''
def transfer_dataset_json(example):
prompt = example['question'].replace('解答:', '')
task_text = task3_format.format(question=prompt)
input_text = prompt_template.format(user_input=task_text)
correct = example['answer']
example['prompt'] = input_text
example['ground_truths'] = correct
return example
ds3 = load_dataset("csv", data_files=f'{ds_path}/mgsm_ja.tsv', split='train', delimiter="\t")
ds3 = ds3.map(transfer_dataset_json)
ds3[0]
モデル・学習準備
モデルのロードは4bitに量子化して、transformersで読み込みます。以前の記事でUnslothを使っていたので、今回もUnslothを使いたかったのですが、まだGRPOには適用できないらしくエラーになりました。
quantization_config = BitsAndBytesConfig(
load_in_8bit=is_8bit,
load_in_4bit=is_4bit
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
]
config = LoraConfig(
r=4,
lora_alpha=16,
target_modules=target_modules,
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM"
)
training_args = GRPOConfig(
output_dir="GRPO_TEST",
learning_rate=1e-4,
logging_steps=1,
num_train_epochs=2,
gradient_accumulation_steps=4,
per_device_train_batch_size=1,
max_prompt_length=prompt_length,
max_completion_length=max_seq_length,
num_generations=4,
remove_unused_columns=False,
temperature=0.9,
optim="adamw_8bit",
lr_scheduler_type="constant",
)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[get_format_reward, get_correct_reward],
args=training_args,
train_dataset=ds,
peft_config=config,
)
モデルのチューニングはLoRAで、少量のパラメータのみを最適化します。
num_generationsの値を4としていますが、ここは一つのプロンプトの対していくつの回答文を生成するかを扱うパラメータでメモリ容量に大きく影響します。より大きな値の方がよさそうです。
学習の実行
trainer.train()
wandb.finish()
結果
以下のような質問で結果を確認しました。
queries = [
"あるサッカーチームがサッカーボールを買うのに90円持っている。サッカーボール1個の値段が15円だとすると、チームが購入できるサッカーボールの最大数はいくつか",
"原価3000円の商品に20%の利益をつけて売りました。定価はいくら?",
"-40 ÷ (-8) の商を求めよ。"
]
回答形式はうまく学習できていますが、推論内容はうまくできる場合とできない場合が出ている状態でした。
<think>
90 / 15 = 《90/15=6》6 個のボールを持っていきます。
</think>
<answer>6</answer>
<think>
3000 * (1 + 20/100) = 3000 * 1.2 = 3600
</think>
<answer>3600</answer>
<think>-40 ÷ -8 = 5</think>
<answer>5</answer>