見出し画像

UnslothでGRPO


Unslothによってメモリ使用量を抑え、実行時間も短縮してGRPOを実行してみます。
Unslothを使わずにGRPOを実行したバージョンは以下の記事。

この記事で使用するデータセットと報酬関数は上の記事と同じものを使用します。

準備

from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

from unsloth import is_bfloat16_supported
import torch
from datasets import concatenate_datasets, load_dataset
import dotenv
dotenv.load_dotenv()
import wandb
import os
os.environ["WANDB_PROJECT"] = "grpo"

max_seq_length = 512 # Can increase for longer reasoning traces
prompt_length = 384
lora_rank = 8 # Larger rank = smarter, but slower
model_name = "Rakuten/RakutenAI-2.0-mini-instruct"
ds_path = "./datasets"

今回のモデルはRakutenAI 2.0を使用します。

モデルロード

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 16,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

RakutenAI 2.0モデルをそのままUnslothでGRPOを実行しようとすると、以下のようなエラーが発生します。

RuntimeError: Prefix caching is not supported with sliding window. Run with --disable-sliding-window to use prefix caching.

このエラーを解消するには、モデルのconfig.jsonで”sliding_window”の値を以下のようにNullにする必要があります。

vim ~/.cache/huggingface/hub/models--Rakuten--RakutenAI-2.0-mini-instruct/snapshots/3f9c4d5fd99bb51aca44a41a5d3750c0a7a60554/config.json
・・・
  "rope_theta": 100000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
・・・

プロンプト

システムプロンプトは以下のように設定しておきます。

system_prompt = '''
あなたは優秀なAIアシスタントです。
ユーザーの質問に答えるために、推論を行ってください。
まず推論プロセスを考え、その答えを推論プロセスとともにユーザーに提供してください。
推論のプロセスと答えは、それぞれ<think></think>と<answer></answer> タグで囲ってください。
以下のような形式で回答を提供してください。
<think> ここに推論プロセスを記載してください</think>
<answer> ここに回答を記載してください </answer>
ただし、<answer></answer>には単位も含まない数値のみを記載してください。
'''

報酬関数

import re

def start_think_tag(string):
    tag = "<think>"
    return tag in string

def end_think_tag(string):
    tag = "</think>"
    return tag in string

def start_answer_tag(string):
    tag = "<answer>"
    return tag in string

def end_answer_tag(string):
    tag = "</answer>"
    return tag in string

def thinking_text_length(string):
    pattern = r'<think>(.*?)</think>'
    match = re.search(pattern, string, re.DOTALL)
    if match:
        return len(match.group(1)) / max_seq_length
    return 0

def is_think_format(string):
    pattern = r'<think>(.*?)</think>'
    match = re.findall(pattern, string, re.DOTALL)
    if match:
        if len(match) == 1:
            return True
        return False
    return False

def is_answer_format(string):
    pattern = r'<answer>(.*?)</answer>'
    match = re.findall(pattern, string, re.DOTALL)
    if match:
        if len(match) == 1:
            return True
        return False
    return False

def is_think_answer_format(string):
    pattern = r"<think>(.*?)</think>.<answer>(.*?)</answer>$"
    match = re.fullmatch(pattern, string, re.DOTALL)
    if match:
        try:
            return bool(match.group(1)) and bool(match.group(2))
        except:
            return False
    else:
        return False

def get_format_reward(completions, **kwargs):
    '''
    Format reward
    '''
    rewards = []
    for completion in completions:
        print(completion)
        completion = completion[0]['content']
        r= 0
        r += int(start_think_tag(completion))
        r += int(end_think_tag(completion))
        r += int(start_answer_tag(completion))
        r += int(end_answer_tag(completion))
        r += int(is_think_format(completion))
        r += int(is_answer_format(completion))
        r += int(is_think_answer_format(completion))
        # r += thinking_text_length(completion)
        rewards.append(r)
    return rewards

def extract_answer_text(string):
    pattern = r'<answer>(.*?)</answer>'
    match = re.search(pattern, string, re.DOTALL)
    if match:
        return match.group(1)
    return None

def get_correct_reward(completions, answers, **kwargs):
    '''
    Correct Reward
    '''
    rewards = []
    for completion, ground_truth in zip(completions, answers):
        rewards.append(int(extract_answer_text(completion[0]['content']) == ground_truth) * 10)
    return rewards

データセット

task1_format = '''
{question}
以下の選択肢から回答を選んでください。
{options}
'''

def transfer_dataset(example):
    prompt = example['prompt']
    options = f"{example['A']}, {example['B']}, {example['C']}, {example['D']}"
    task_text = task1_format.format(question=prompt, options=options)
    correct = example[example['Correct']]
    example['prompt'] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": task_text}      
    ]
    example['answers'] = correct
    return example

csv_files = [
    f'{ds_path}/elementary_mathematics.csv',
    f'{ds_path}/high_school_mathematics.csv',
    f'{ds_path}/college_mathematics.csv'
]
ds1 = load_dataset("csv", data_files=csv_files, split='train')
ds1 = ds1.map(transfer_dataset)

task2_format = '''
{question}
'''

def transfer_dataset_json(example):
    prompt = example['question'].replace('解答:', '')
    task_text = task2_format.format(question=prompt)
    correct = example['answer']
    example['prompt'] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": task_text}      
    ]
    example['answers'] = correct
    return example

ds2 = load_dataset("json", data_files=f'{ds_path}/test.json', split='train')
ds2 = ds2.map(transfer_dataset_json)

task3_format = '''
{question}
'''

def transfer_dataset_json(example):
    prompt = example['question'].replace('解答:', '')
    task_text = task3_format.format(question=prompt)
    correct = example['answer']
    example['prompt'] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": task_text}      
    ]

    example['answers'] = correct
    return example

ds3 = load_dataset("csv", data_files=f'{ds_path}/mgsm_ja.tsv', split='train', delimiter="\t")
ds3 = ds3.map(transfer_dataset_json)

ds = concatenate_datasets([ds1, ds2, ds3])

学習設定

学習の設定値はUnslothが提供しているサンプルからほとんどを持ってきています。

"temprature"は0.5として、今回は学習を実行しましたがもっと大きな値で実行した方が学習がよく進みそうな印象でした。

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    temperature=0.5,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 8,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 8, # Decrease if out of memory
    max_prompt_length = prompt_length,
    max_completion_length = max_seq_length - prompt_length,
    num_train_epochs = 3, # Set to 1 for a full training run
    max_grad_norm = 1.0,
    report_to = "wandb", # Can use Weights & Biases
    output_dir = "outputs",
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[get_format_reward, get_correct_reward],
    args=training_args,
    train_dataset=ds,
)

学習実行・保存

実行や保存は以下の記載の通り。

trainer.train()
wandb.finish()

model.save_pretrained_merged("rakuten2_grpo_lora", tokenizer, save_method = "lora",)

まとめ

RakutenAI 2.0は1.5Bモデルでありながら、元から一定の推論性能を持っていそうでした。そのため、今回のGRPOによる学習では性能向上が確認できませんでした。
性能改善のためには高品質なReasoningデータでSFTを実行したあと、GRPOで推論性能の向上を目指すような手間のかかるやり方を取る必要がありそうですね。

いいなと思ったら応援しよう!