PEFTでLoRAマージしてCTranslate2で遊ぶメモ
このメモを読むと
・ファインチューニング後のモデルを爆速で動かせる
検証環境
・Windows11
・VRAM24GB
・ローカル(Anaconda)
・2023/6/M時点
事前準備
Anacondaを使うメモ|おれっち (note.com)
Gitを使うメモ|おれっち (note.com)
モデルのマージ
PEFTを使うことで手軽にファインチューニングを行うことができます。
そして、得られたLoRAモデルとベースモデルを合体させることでマージモデルを作成できます。
マージモデルを使ってCTranslate2で文章爆速生成してみましょう。
すること
・LoRA作成
・モデルマージ
・CTranslate2で文章生成
環境構築
とても簡単です!
1. 仮想環境を作成し、環境切替
conda create -n mergetest python=3.10
activate mergetest
2. 追加パッケージのインストール
pip install datasets accelerate loralib sentencepiece transformers ctranslate2 protobuf==3.20.0
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.39.0-py3-none-any.whl
pip install git+https://github.com/huggingface/peft.git
3. 以下からCUDA Toolkit11.8を導入します。(未導入の場合)
完了です!
PEFTでモデルマージ
LoRAモデル作成
下記記事の"学習"を行うことで作成できます。
データセットは以下からお借りしました。
モデルマージ
好きな名前で下記スクリプトを作成し実行します。
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
peft_name = "test-rinna-3.6b" #学習済みadapter_config.jsonのパス指定
output_dir = "test-rinna-3.6b/test-rinna-3.6b-merge" #マージモデルの出力先
# PEFT(LoRA)の指定
peft_config = PeftConfig.from_pretrained(peft_name)
# ベースモデルの読み込み
model = AutoModelForCausalLM.from_pretrained(
peft_config.base_model_name_or_path,
return_dict=True,
torch_dtype=torch.float16,
)
# Rinnaのトークナイザーでは、「use_fast=False」も必要になる
tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path,use_fast=False)
# PEFT(LoRA)の読み込み
model = PeftModel.from_pretrained(model, peft_name)
# マージモデル作成
merged_model = model.merge_and_unload()
# 出力
merged_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Saving to {output_dir}")
これでマージ完了。
CTranslate2で文章生成
モデル変換
マージ後のモデルをCTranslate2用に変換
import subprocess
base_model = "test-rinna-3.6b/test-rinna-3.6b-merge" #マージモデルの出力先を指定
output_dir = "test-rinna-3.6b/rinnna-gozaru-int8" #CTranslate2への変換先
quantization_type = "int8"
command = f"ct2-transformers-converter --model {base_model} --quantization {quantization_type} --output_dir {output_dir}"
subprocess.run(command, shell=True)
推論
import ctranslate2
import transformers
import time
base_model = "rinna/japanese-gpt-neox-3.6b-instruction-ppo"
quant_model = "test-rinna-3.6b/rinnna-gozaru-int8" #CTranslate2への変換先を指定
generator = ctranslate2.Generator(quant_model, device="cuda")
# Rinnaのトークナイザーでは、「use_fast=False」も必要になる
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, use_fast=False)
# プロンプトを作成する
def prompt(msg):
p = [
{"speaker": "ユーザー", "text": msg},
]
p = [f"{uttr['speaker']}: {uttr['text']}" for uttr in p]
p = "<NL>".join(p)
p = p + "<NL>" + "システム: "
return p
# 返信を作成する
def reply(msg):
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt(msg),add_special_tokens=False,))
results = generator.generate_batch(
[tokens],
max_length=256,
sampling_topk=10,
sampling_temperature=0.9,
include_prompt_in_result=False,
)
text = tokenizer.decode(results[0].sequences_ids[0])
print("A: " + text)
return text
if __name__ == "__main__":
while True:
msg = input("Q: ")
start_time = time.time()
reply(msg)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")
おわり
爆速でござれました。
参考資料
・CTranslate2でrinna instructionをquantizeして動かす|if001 (note.com)
・RinnaのppoモデルをCTranslate2で高速に動かす (zenn.dev)