RTX3060上でJapanese StableLM Base AlphaをLoRAファインチューニングする

2023/08/12: Windows版のbitsandbytesのインストール方法について追記しました。
2023/08/12: PEFTのインストール方法と、学習後に保存されたLoRAモデルの読み込み方法について追記しました。

導入

先日、Stability AI Japanから日本語の大規模言語モデル(LLM)として、70億パラメータの「Japanese StableLM Alpha 7B」が公開されました。
本記事の執筆時点(2023/8/11)で公開されているオープンな日本語モデルの中で、今回公開されたJapanese StableLM Alphaはベンチマークで最も高い精度を出しており、注目を集めています。

詳細については以下を参照してください
https://ja.stability.ai/blog/japanese-stablelm-alpha

概要

本記事では、Japanese StableLM Base Alpha 7BをRTX3060 (VRAM 12GB) 上で動作させ、LoRAを用いたファインチューニング学習を行います。
Japanese StableLM Alphaは約70億個のパラメータを持ち、通常のfloat (32bit浮動小数点数)型で読み込むと28GBのメモリ容量を消費することになり、RTX3060のVRAMに収まりません。しかし、本モデルはint8型に量子化されたバージョンが公開されており、こちらを利用することで12GBのVRAMに収めることができます。
RTX3060は執筆時点で4万円前後で入手可能なグラフィックボードであり、本記事の手法を用いることで一般家庭でも比較的手軽にLoRA学習が可能になります。

Japanese StableLM Alphaには基盤モデルの「Japanese StableLM Base Alpha 7B」と、対話チューニングされた「Japanese StableLM Instruct Alpha 7B」の二種類のモデルがあり、前者はApache License 2.0で公開されていますが、後者は研究目的の利用に限定されています。本記事では前者のBaseモデルを用います。

基盤モデルはそのまま用いるよりも、特定のタスクに合わせたファインチューニングという学習を行ってから利用するのが普通です。LoRA (Low Rank Adaptation) とは、重み行列全体を訓練する代わりに、低ランクの行列で学習を行うことで効率的にファインチューニングを行う手法です。7Bパラメータのモデルのファインチューニングには通常30GB以上のVRAMが必要となりますが、LoRAを用いることでVRAM消費量を大幅に抑えてファインチューニングを行うことができます。


準備

まず、必要なライブラリをインストールします。
なお、PyTorch (GPU版), CUDA (11.0以上), CUDNNは既にインストールされているものとします。

pip install 'transformers>=4.31.0' datasets sentencepiece 'accelerate>=0.20.3'
pip install git+https://github.com/huggingface/peft.git

int8型でモデルを読み込む際に利用するbitsandbytesライブラリは、windows版のDLLファイルがpipからダウンロードできないため、bitsandbytes-windows-webuiで公開されているWindows版ビルドを利用します。
古いバージョンには保存時にVRAMを過剰に消費する問題があるため、0.41.0以上のバージョンを利用してください。

# windowsの場合
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl

ライブラリのインポート

ここからはpythonのコードを記述します。
初めに、必要なライブラリを読み込みます。
Transformersのバージョンが古いとLlamaTokenizerがインポートできない旨のエラーが表示されることがあります。4.31.0以上のバージョンをインストールしてください。

import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM

モデルの読み込み

続いて、Japanese StableLM Base Alphaのモデルをロードします。int8版では約7GB、fp16版では14GBのサイズがあるので注意してください。

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

model_id = "stabilityai/japanese-stablelm-base-alpha-7b" # instruction版を使用する場合は "stabilityai/japanese-stablelm-instruct-alpha-7b" に設定してください
load_in = "int8" # "fp32", "fp16", "int8" のいずれかを設定してください

cache_dir = "./model_cache"    # モデルのキャッシュを保存するフォルダ

model_kwargs = {"trust_remote_code": True, "device_map": "auto", "low_cpu_mem_usage": True, "cache_dir": cache_dir}

if load_in == "fp16":
    model_kwargs["variant"] = "fp16"
    model_kwargs["torch_dtype"] = torch.float16
elif load_in == "int8":
    model_kwargs["variant"] = "int8"
    model_kwargs["load_in_8bit"] = True

tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1")
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)      #モデルを読み込む

if load_in != "int8":
    model.to(device)

model.eval()

データセットの読み込み

日本語の質疑応答データセットである「databricks-dolly-15k-ja」を読み込みます。
長い文章はVRAMを多く消費するため、入力トークン数を1024に制限しています。これを超過する長いプロンプトは正しく学習されない可能性があるため、VRAMに余裕がある場合はdataset_max_lengthを大きく(最大2048)することを推奨します。

# データセットを読み込みトークン化する

import datasets

dataset_name = "kunishou/databricks-dolly-15k-ja"
dataset = datasets.load_dataset(dataset_name)

data_max_length = 512 #VRAM消費量を減らすため512トークンに制限

prompt_with_context_format = """The following text is the task instruction and the context for it.
Write a response that satisfies the instruction based on context.

### Instruction:
{instruction}

### Context:
{context}

### Response:
{response}
"""

prompt_no_context_format = """The following text is the task instruction.
Write a response that satisfies the instruction based on context.

### Instruction:
{instruction}

### Response:
{response}
"""

def tokenize(samples):
    prompts = []

    # データセットの instruction 列と input 列と output 列を組み合わせてプロンプトを組み立てます。
    for instruction, input, output in zip(samples["instruction"], samples["input"], samples["output"]):
        if input:
            prompt = prompt_with_context_format.format(instruction=instruction, context=input, response=output)
        else:
            prompt = prompt_no_context_format.format(instruction=instruction, response=output)
        prompts.append(prompt + tokenizer.eos_token)

    result = tokenizer(prompts, padding=False, truncation=True, max_length=data_max_length)
    return result

dataset = dataset.map(lambda samples: tokenize(samples), batched=True)

PEFT学習のための調整

gradient checkpointingという機能を有効にすることでVRAM使用量を節約します。
また、一部のパラメータをfloat32型に変換することで学習を安定させます。

# ベースモデルをフリーズ

for param in model.parameters():
    param.requires_grad = False
    if param.ndim == 1:
        param.data = param.data.to(torch.float32)

# VRAM消費量を節約するための調整

model.gradient_checkpointing_enable()

model.enable_input_require_grads()

class CastOutputToFloat(torch.nn.Sequential):
   def forward(self, x):
      return super().forward(x).to(torch.float32)

model.embed_out = CastOutputToFloat(model.embed_out)

print(model)

学習の設定を行います。
rはLoRA層の中間層の次元数(rank)であり、rを大きくすることで性能が上がりますが、学習時間と必要なVRAM容量が増加します。

# PEFTを用いたLoRA学習の設定

import peft

peft_config = peft.LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["query_key_value"],
    lora_dropout=0.05,
    bias="none",
    fan_in_fan_out=False,
    task_type=peft.TaskType.CAUSAL_LM
)
model = peft.get_peft_model(model, peft_config)

model.print_trainable_parameters()

学習を実行

LoRA学習を行います。
学習途中のモデルは.checkpointフォルダ、学習後のモデルはoutputフォルダに保存されます。

from transformers import TrainingArguments, DataCollatorForLanguageModeling, Trainer

training_args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    warmup_steps=20,
    max_steps=200,
    learning_rate=2e-4,
    fp16=True,
    num_train_epochs=1,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=10,
    output_dir=".checkpoints",
    evaluation_strategy="no",
    logging_dir="logs",
    logging_steps=25,
    gradient_checkpointing=True,
    push_to_hub=False
)

data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    train_dataset=dataset["train"],
    args=training_args,
    data_collator=data_collator,
)

model.config.use_cache = False

# 学習を途中から再開する場合はここへチェックポイント名を記入します。
checkpoint = None
# checkpoint = "checkpoint-100"

trainer.train(checkpoint)       # 学習を実行

model.save_pretrained("output")     # outputフォルダに学習後のモデルを保存

学習されたモデルをテスト

学習が終了した後、質問を入力して正しく学習できているか確かめます。

# 学習されたモデルをテストする

input_data = {
    "input": [""],  # 回答時に参考にする入力データ (必須ではない)
    "instruction": ["REST APIとは何ですか?"],  # 好きな質問に置き換えてください
    "output": [""]  # ここは空にする
}

input_tokens = tokenize(input_data)
input_tokens["input_ids"][0] = input_tokens["input_ids"][0][:-1]        # remove eos

model.config.use_cache = False
with torch.autocast("cuda"):

    # seedを設定することで毎回同じ結果が得られる
    seed = 23
    torch.manual_seed(seed)

    tokens = model.generate(
        input_ids=torch.tensor(input_tokens["input_ids"]).to(device=model.device),
        max_new_tokens=128,
        temperature=0.8,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )

    out = tokenizer.decode(tokens[0], skip_special_tokens=False)

print(out)

私の環境では、結果は下のようになりました。

<|startoftext|> The following text is the task instruction.
Write a response that satisfies the instruction based on context.

### Instruction:
REST APIとは何ですか?

### Response:
RESTは、Representational State Transferの略で、インターネットを介してコンピュータにデータや命令を転送するために使用される一連の標準を定義する。 RESTは、インターネットのプロトコルに準拠したWebサービスによって使用される通信プロトコルである。 RESTは、インターネット上のWebアプリケーションの相互運用性を確立するために使用されます。<|endoftext|>

保存したLoRAモデルを読み込む

outputフォルダに保存した学習後のモデルを読み込むには、上記の「モデルの読み込み」の節のコードを実行してベースモデルを読み込んでから、下記のコードを実行します。

import peft

adapter_path = "./output"

model = peft.PeftModel.from_pretrained(model, adapter_path)

参考文献

この記事を執筆するにあたって、以下の記事を参考にしました。

・大規模言語モデル open-calm-7b を Windows かつ GeForce RTX 3060 (12GB) を搭載したローカル PC で動かしてみた
https://qiita.com/yasusun/items/6418f5558ea9993b725b

・mkshingさんのColab用ノートブック
https://twitter.com/mk1stats/status/1689551167284789248

いいなと思ったら応援しよう!