見出し画像

Google Colab + trl で LINE の japanese-large-lm のQLoRA ファインチューニングを試す

「Google Colab」+「trl」で LINEの「japanese-large-lm」のQLoRA ファインチューニングを試したので、まとめました。


1. trl と ござるデータセット

trl」の「SFTTrainer」と、「ござるデータセット」(bbz662bbz/databricks-dolly-15k-ja-gozarinnemon) を使ってQLoRAファインチューニングに挑戦してみます。

「trl」は「artidoro/qlora」と比べて設定が多くて大変ですが、SFT後の「RLHF」や「DPO」などの追加学習も可能です。

2. Colabでの実行

Colabでの実行手順は、次のとおりです。

(1) パッケージのインストール。

# パッケージのインストール
!pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.4.7
!pip install sentencepiece

(2) パッケージのインポート。

# パッケージのインポート
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

(4) データセットの読み込み。
今回のモデルはベースモデルで、学習済みのInstructionはないので、独自定義しました。

・入力なし

指示:
{指示}

応答:
{応答}

・入力あり

指示:
{指示}

入力:
{入力}

応答:
{応答}

# データセットの読み込み
dataset = load_dataset("bbz662bbz/databricks-dolly-15k-ja-gozarinnemon", split="train")

# プロンプトテンプレートの準備
def generate_prompt(data_point):
    if data_point["input"]:
        result = f"""指示:
{data_point["instruction"]}

入力:
{data_point["input"]}

応答:
{data_point["output"]}"""
    else:
        result = f"""指示:
{data_point["instruction"]}

応答:
{data_point["output"]}"""
    return result

# テキスト列の追加
def add_text(example):
    example["text"] = generate_prompt(example)
    del example["index"]
    del example["category"]
    del example["instruction"]
    del example["input"]
    del example["output"]
    return example
dataset = dataset.map(add_text)
print(dataset)
print(dataset[0]["text"])
Dataset({
    features: ['text'],
    num_rows: 1502
})
指示:
RELXはどのような株価指数に属しているのですか?

入力:
RELX plc(発音:レルエックス)は、英国ロンドンに本社を置く英国[2]の多国籍情報・分析企業です。科学・技術・医療情報および分析、法律情報および分析、意思決定ツールの提供、展示会の開催などの事業を展開しています。1993年、イギリスの書籍・雑誌出版社であるリード・インターナショナルとオランダの科学出版社であるエルゼビアの合併により誕生した会社です。

同社は上場企業であり、ロンドン証券取引所、アムステルダム証券取引所、ニューヨーク証券取引所で株式を取引しています(ティッカーシンボル:ロンドン:REL、アムステルダム:REN、ニューヨーク:RELX).FTSE100指数、Financial Times Global 500、Euronext 100指数の構成銘柄の一つです。

応答:
我、りんえもんは思う。 RELX plcは、FTSE 100、Financial Times Global 500、Euronext 100の各インデックスを構成していますでござる。知らんけど。

(4) モデルの準備。
「line-corporation/japanese-large-lm-3.6b」を4bit量子化で読み込みます。

# 量子化パラメータ
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # 4bitベースモデルの有効化
    bnb_4bit_quant_type="nf4",  # 量子化種別 (fp4 or nf4)
    bnb_4bit_compute_dtype=torch.float16,  # 4bitベースモデルのdtype (float16 or bfloat16)
    bnb_4bit_use_double_quant=False,  # 4bitベースモデルのネストされた量子化の有効化 (二重量子化)
)

# モデルの準備
model_name = "line-corporation/japanese-large-lm-3.6b"
model = AutoModelForCausalLM.from_pretrained(
    model_name,  # モデル名
    quantization_config=bnb_config,  # 量子化パラメータ
    device_map={"": 0}  # モデル全体をGPU0にロード
)
model.config.use_cache = False  # キャッシュ (学習時はFalse)

(5) トークナイザーの準備。
EOSを学習させるため、トークナイザーで以下を指定しました。

# トークナイザーの準備
tokenizer = AutoTokenizer.from_pretrained(
    model_name,  # モデル名
    use_fast=False,  # Fastトークナイザーの有効化
    add_eos_token=True,  # データへのEOSの追加を指示
    trust_remote_code=True
)


・japanese-large-lmのスペシャルトークン

・bos_token : <s> , 1
・eos_token : </s> , 2
・unk_token : <unk> , 0
・pad_token : <pad> , 3

(6) 学習の実行。
今回は、500ステップ学習します。

# LoRAパラメータ
peft_config = LoraConfig(
    r=64,  # LoRAアテンションの次元
    lora_alpha=16,  # LoRAスケーリングのAlphaパラメータ
    lora_dropout=0.1,  # LoRA レイヤーのドロップアウト確率
    bias="none",  # LoRAのバイアス種別 ("none","all", "lora_only")
    task_type="CAUSAL_LM",  # タスク種別
    target_modules=["dense_4h_to_h", "dense", "dense_h_to_4h", "query_key_value"]
)

# 学習パラメータ
training_arguments = TrainingArguments(
    output_dir="./results",  # 出力ディレクトリ
    fp16=True,  # fp16学習の有効化 (T4:True,A100:False)
    bf16=False,  # bf16学習の有効化 (T4:False,A100:True)
    max_steps=500,  # 学習ステップ数
    per_device_train_batch_size=4,  # 学習用のGPUあたりのバッチサイズ
    gradient_accumulation_steps=1,  # 勾配を蓄積するための更新ステップの数
    optim="paged_adamw_32bit",  # オプティマイザ
    learning_rate=2e-4,  # 初期学習率 (AdamW オプティマイザー)
    lr_scheduler_type="cosine",  # 学習率スケジュール
    max_grad_norm=0.3,  # 最大法線勾配 (勾配クリッピング)
    warmup_ratio=0.03,  # 線形ウォームアップのステップ比率 (0から学習率まで)
    weight_decay=0.001,  # bias/LayerNormウェイトを除く全レイヤーに適用するウェイト減衰
    save_steps=0,  # 何ステップ毎にチェックポイントを保存するか
    logging_steps=25,  # 何ステップ毎にログを記録するか
    group_by_length=True,  # シーケンスを同じ長さのバッチにグループ化 (メモリ節約して学習速度が大幅アップ)
    report_to="tensorboard"  # レポート
)

# SFTパラメータ
trainer = SFTTrainer(
    model=model,  # モデル
    tokenizer=tokenizer,  # トークナイザー
    train_dataset=dataset,  # データセット
    dataset_text_field="text",  # データセットのtext列
    peft_config=peft_config,  # PEFTパラメータ
    args=training_arguments,  # 学習パラメータ
)

# モデルの学習
trainer.train()
trainer.model.save_pretrained("./results")

T4で18分ほどで学習完了します。Llama 2よりLossの下げ幅は安定してませんが、Instructionの書式も新規に学習してせいかも。

「./results」にQLoRAモデルが出力されています。

(7) 推論の実行。
質問すると「ござる」で返し、正しい位置でEOS(</s>)になることを確認します。

# プロンプトの準備
prompt = """指示:
日本の首都は? 
 
応答:
"""

# 推論の実行
input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
output_ids = model.generate(
    input_ids.to(device=model.device),
    max_length=100,
    temperature=0.7,
    do_sample=True,
)
output = tokenizer.decode(output_ids.tolist()[0][input_ids.size(1):])
print(output)
我、りんえもんは思う。 首都は東京でござる。知らんけど。</s>

3. Colabでの推論

QLoRAモデルを読み込んで推論する手順は、次のとおりです。

(1) トークナイザーとモデルの準備。

# モデルの準備
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0}
)
model = PeftModel.from_pretrained(
    base_model, 
    "./results"
)

# トークナイザーの準備
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=False,
    add_eos_token=True,
    trust_remote_code=True
)

(2) 推論の実行。

# プロンプトの準備
prompt = """指示:
日本の国歌は?

応答:
"""

# 推論の実行
input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
output_ids = model.generate(
    input_ids=input_ids.to(device=model.device),
    max_length=100,
    temperature=0.7,
    do_sample=True,
)
output = tokenizer.decode(output_ids.tolist()[0][input_ids.size(1):])
print(output)
我、りんえもんは思う。 君が代は、日本の国歌でござる。知らんけど。</s>



いいなと思ったら応援しよう!