rinna-3.6Bをオリジナル小説でLoRAファインチューニングしてみた【RTX3060 (VRAM 12GB)】
動作確認のために、お試しでやってみました。
概要
背景
AITuberを含めた創作活動への活用のためにrinna-3.6Bでのファインチューニングを勉強したかったのですが、せっかくなら持ってるRTX3060を使ってローカルでやりたいと思っていました。
偉大なる先駆者の方々によって方法が開拓されていたので、ありがたく参考にさせていただいた次第です。
本記事でやったこと
・ローカルのRTX3060 (VRAM 12GB)でrinna-3.6BのLoRAファインチューニングを実行。
・オリジナルの小説をデータセットとして利用。文の続きを書けるか試す。
参考記事
以下の記事を参考にさせていただきました。ほとんどそのままです。環境構築など、詳細な手法はこれらの記事を参照してください。
実装
環境構築
上記の参考記事そのままです。
データセットの準備
1)筆者のオリジナルの小説である『ミューズ・クロニクル ―第十七学芸課は眠らない―』(約16万字)を使用しました。(古い下手くそな文章ですが、長編はこれくらいしか書いたことがないので……)
2)本文のみを抽出してテキストファイルに保存。
3)下記コードにてjsonに保存。今回は、5行分を入力として、続きの5行を出力するようにデータを整形しました。
import pandas as pd
import json
with open('./data/N9764DK_text.txt', encoding='utf8') as f:
raw_text_lines = f.readlines()
# 空行を削除
text_lines = []
for line in raw_text_lines:
if line != "\n":
text_lines.append(line)
result = []
instruction_text = ""
output_text = ""
instruction_num = 5 # 入力する小説の行数
output_num = 5 # 出力する小説の行数
multiline_num = instruction_num + output_num
instruction_count = 0
instruction_flg = True
for i in range(len(text_lines)):
if i % multiline_num == 0:
formatted = {
"input": instruction_text,
"completion": output_text
}
result.append(formatted)
instruction_text = ""
output_text = ""
instruction_flg = True
instruction_text += text_lines[i]
else:
if instruction_flg:
instruction_count += 1
instruction_text += text_lines[i]
if instruction_count == instruction_num-1:
instruction_flg = False
instruction_count = 0
else:
output_text += text_lines[i]
with open('./data/formatted_muse_chronicle.json', 'w', encoding='utf-8') as f:
json.dump(result, f, indent=4)
ファインチューニング
下記コードを実行。上記の参考記事をほとんどそのまま使わせていただいております。
モデルは、japanese-gpt-neox-3.6bを使用。
メモリ使用量は最大7.9GB程度。使用率は100%が体感7割くらいで、時々数十~数%に落ちる感じ。30エポックで1時間半程度かかりました。
import os
import torch
import torch.nn as nn
import bitsandbytes as bnb
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
model_name = "rinna/japanese-gpt-neox-3.6b"
dataset = "./data/formatted_muse_chronicle.json"
peft_name = "lora-rinna-3.6b"
output_dir = "lora-rinna-3.6b-results"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
CUTOFF_LEN = 256
def tokenize(prompt, tokenizer):
result = tokenizer(
prompt,
truncation=True,
max_length=CUTOFF_LEN,
padding=False,
)
return {
"input_ids": result["input_ids"],
"attention_mask": result["attention_mask"],
}
# データセット
import json
with open(dataset, "r", encoding='utf-8') as f:
data = json.load(f)
def generate_prompt(data_point):
result = f"""### 指示:
{data_point["input"]}
### 回答:
{data_point["completion"]}
"""
# 改行→<NL>
result = result.replace('\n', '<NL>')
return result
train_dataset = []
val_dataset = []
for i in range(len(data)):
if i % 5 == 0:
x = tokenize(generate_prompt(data[i]), tokenizer)
val_dataset.append(x)
else:
x = tokenize(generate_prompt(data[i]), tokenizer)
train_dataset.append(x)
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
model_name,
load_in_8bit=True,
device_map="auto",
)
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
lora_config = LoraConfig(
r= 8,
lora_alpha=16,
target_modules=["query_key_value"],
lora_dropout=0.05,
bias="none",
task_type=TaskType.CAUSAL_LM
)
model = prepare_model_for_int8_training(model)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
eval_steps = 200
save_steps = 200
logging_steps = 20
trainer = transformers.Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=val_dataset,
args=transformers.TrainingArguments(
num_train_epochs=30,
learning_rate=3e-4,
logging_steps=logging_steps,
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=eval_steps,
save_steps=save_steps,
output_dir=output_dir,
save_total_limit=3,
push_to_hub=False,
auto_find_batch_size=True
),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()
model.config.use_cache = True
trainer.model.save_pretrained(peft_name)
推論の実行
下記コードを実行。こちらも上記の参考記事ほぼそのままです。複数文を出力するために、最後のEOSトークンまでデコードしています。
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "rinna/japanese-gpt-neox-3.6b"
peft_name = "lora-rinna-3.6b"
output_dir = "lora-rinna-3.6b-results"
model = AutoModelForCausalLM.from_pretrained(
model_name,
load_in_8bit=True,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = PeftModel.from_pretrained(
model,
peft_name,
# device_map="auto"
)
model.eval()
def generate_prompt(data_point):
if data_point["input"]:
result = f"""### 指示:
{data_point["instruction"]}
### 入力:
{data_point["input"]}
### 回答:
"""
else:
result = f"""### 指示:
{data_point["instruction"]}
### 回答:
"""
# 改行→<NL>
result = result.replace('\n', '<NL>')
return result
def generate(instruction, input=None, maxTokens=256) -> str:
prompt = generate_prompt({'instruction': instruction, 'input': input})
input_ids = tokenizer(prompt,
return_tensors="pt",
truncation=True,
add_special_tokens=False).input_ids.cuda()
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=maxTokens,
do_sample=True,
temperature=0.7,
top_p=0.75,
top_k=40,
no_repeat_ngram_size=2,
)
outputs = outputs[0].tolist()
# 最後のEOSトークンまでデコード
if tokenizer.eos_token_id in outputs:
eos_list = [i for i, x in enumerate(outputs) if x == tokenizer.eos_token_id]
decoded = tokenizer.decode(outputs[:eos_list[len(eos_list) - 1]])
sentinel = "### 回答:"
sentinelLoc = decoded.find(sentinel)
if sentinelLoc >= 0:
result = decoded[sentinelLoc + len(sentinel):]
return result.replace("<NL>", "\n") # <NL>→改行
else:
return 'Warning: Expected prompt template to be emitted. Ignoring output.'
else:
return 'Warning: no <eos> detected ignoring output'
query = """「なら、スカーフが何色だったか分かるかい?」
「……それは、そのー、つまり、優雅で繊細な色合いでしたね」
「私は何色だったかと聞いているんだよ」
「まぁ、それは言うだけ野暮ってもんじゃないですか」
そこにハルが口を挟んだ。"""
print(query)
for i in range(3):
print("**********")
print(generate(query))
print("**********")
結果
データセットとした小説内に含まれる文を入力にしました。ちなみに、この文は以下のように続きます。
この場面では、子供がいるかのように対応する店主に対して、子供が見えなかった主人公ハルと同僚のジムは困惑しています。子供のスカーフの色を質問する店主に、ジムは空気を読んで子供が見えていたかのように振る舞いますが、ハルは店主にも子供が見えていないのだと推察するという流れになっています。
今回は、データセットの整形時に5行を入力、続きの5行を出力としているため、全体の約半分は入力用データになっていません。さらにそこから訓練データとテストデータに分けているので、今回の入力部分がそっくりそのまま学習された確率は低そうです。
とはいえ「スカーフ」という単語はこのエピソードにしか出てこないので、学習されていればうまく引っかかるのではないかと期待しました。また、見えているかのような会話から逆に見えていないことを質問するという流れが再現できるかも気になりました。
結果は以下の通りです。3回繰り返して生成を試しています。
『「あなた、何かやましいことでも……』から続く文章が生成されましたが、学習データを確認したところ、そのような文章は含まれていませんでした。そのあとの文章も、意味的つながりが弱くなっている印象です。
一方、良い点としては、上手く言い逃れようとしているジムに対してハルが何かを言おうとしているという文脈にも捉えられるので、「何かやましいことでも?」と尋ねる言葉が来る可能性は高いかもしれません。
文体も(他の人には分かりづらいかもしれませんが)筆者らしさが表現されているように感じます。特に私の文章の特徴の一つとして、Twitter小説を執筆していた影響から短い文章で簡潔に表現する傾向があります。
特に下記の表現は簡潔でありつつ比喩的な想像力を感じる文章で、評価できると思います。
『銃弾が皮膚をえぐるように、朝の空気が鼓膜を揺さぶった。』
『きっと本棚の奥にしまってあるはずのメモを見落としたのだ。』
ちなみに「パァン」は作中で銃の発射音として一度使われていました。「ペトラ」は登場人物の一人です。
おわりに
ひとまず動作確認ができてよかったです。
3.6Bなのでまだまだパラメータ数が小さいですが、いずれより大きなパラメータ数のモデルを広く普及しているGPUで使えるようになると思うので、今後も様々な活用方法を探っていきたいと思います。