見出し画像

PEFTでLoRAマージしてCTranslate2で遊ぶメモ

このメモを読むと

・ファインチューニング後のモデルを爆速で動かせる

検証環境

・Windows11
・VRAM24GB
・ローカル(Anaconda)
・2023/6/M時点

事前準備

Anacondaを使うメモ|おれっち (note.com)
Gitを使うメモ|おれっち (note.com)

モデルのマージ

PEFTを使うことで手軽にファインチューニングを行うことができます
そして、得られたLoRAモデルとベースモデルを合体させることでマージモデルを作成できます。
マージモデルを使ってCTranslate2で文章爆速生成してみましょう。

すること

 ・LoRA作成
 ・モデルマージ
 ・CTranslate2で文章生成

環境構築

とても簡単です!

1. 仮想環境を作成し、環境切替

conda create -n mergetest python=3.10
activate mergetest

2. 追加パッケージのインストール

pip install datasets accelerate loralib sentencepiece transformers ctranslate2 protobuf==3.20.0
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.39.0-py3-none-any.whl
pip install git+https://github.com/huggingface/peft.git

3. 以下からCUDA Toolkit11.8を導入します。(未導入の場合)

完了です!

PEFTでモデルマージ

LoRAモデル作成

 下記記事の"学習"を行うことで作成できます。

 データセットは以下からお借りしました。

モデルマージ

 好きな名前で下記スクリプトを作成し実行します。

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

peft_name = "test-rinna-3.6b"   #学習済みadapter_config.jsonのパス指定
output_dir = "test-rinna-3.6b/test-rinna-3.6b-merge"  #マージモデルの出力先

# PEFT(LoRA)の指定
peft_config = PeftConfig.from_pretrained(peft_name)
# ベースモデルの読み込み
model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    return_dict=True,
    torch_dtype=torch.float16,
)
# Rinnaのトークナイザーでは、「use_fast=False」も必要になる
tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path,use_fast=False)
# PEFT(LoRA)の読み込み
model = PeftModel.from_pretrained(model, peft_name)
# マージモデル作成
merged_model = model.merge_and_unload()
# 出力
merged_model.save_pretrained(output_dir)  
tokenizer.save_pretrained(output_dir)
print(f"Saving to {output_dir}")  
いろいろ生成されます。

これでマージ完了。

CTranslate2で文章生成

モデル変換

 マージ後のモデルをCTranslate2用に変換

import subprocess

base_model = "test-rinna-3.6b/test-rinna-3.6b-merge" #マージモデルの出力先を指定
output_dir = "test-rinna-3.6b/rinnna-gozaru-int8"  #CTranslate2への変換先
quantization_type = "int8"

command = f"ct2-transformers-converter --model {base_model} --quantization {quantization_type} --output_dir {output_dir}"
subprocess.run(command, shell=True)

推論

import ctranslate2
import transformers
import time

base_model = "rinna/japanese-gpt-neox-3.6b-instruction-ppo"
quant_model = "test-rinna-3.6b/rinnna-gozaru-int8"    #CTranslate2への変換先を指定

generator = ctranslate2.Generator(quant_model, device="cuda")
# Rinnaのトークナイザーでは、「use_fast=False」も必要になる
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, use_fast=False)

# プロンプトを作成する
def prompt(msg):
    p = [
        {"speaker": "ユーザー", "text": msg},
    ]
    p = [f"{uttr['speaker']}: {uttr['text']}" for uttr in p]
    p = "<NL>".join(p)
    p = p + "<NL>" + "システム: "
    return p

# 返信を作成する
def reply(msg):
    tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt(msg),add_special_tokens=False,))

    results = generator.generate_batch(
        [tokens],
        max_length=256,
        sampling_topk=10,
        sampling_temperature=0.9,
        include_prompt_in_result=False,
    )

    text = tokenizer.decode(results[0].sequences_ids[0])
    print("A: " + text)
    return text

if __name__ == "__main__":
    while True:
        msg = input("Q: ")
        start_time = time.time()
        reply(msg)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Elapsed time: {elapsed_time} seconds")

Q: カレーにジャガイモは入れるべき?
A: いいえ、入れないでござる。
Elapsed time: 0.1907789707183838 seconds

出力

おわり

爆速でござれました。

参考資料

CTranslate2でrinna instructionをquantizeして動かす|if001 (note.com)
RinnaのppoモデルをCTranslate2で高速に動かす (zenn.dev)

いいなと思ったら応援しよう!