ローカルでGPT2-mediumをファインチューニングしてみた
上記を参考にすればまぁできます。以上!
だけだとあれなので自分のちょっと詰まったところをメモ。
前回の記事を読んでいること前提です。
①必要なライブラリをインストールする
pip install git+https://github.com/huggingface/transformers
git clone https://github.com/huggingface/transformers
pip install -r ./transformers/examples/pytorch/language-modeling/requirements.txt
cd models/rinna
git clone https://huggingface.co/rinna/japanese-gpt2-medium
cd C:\convogpt
②必要なディレクトリ及びファイルを用意する
train_dataフォルダ作成
C:\convogpt\train_data\XXX.txt
学習元テキストを入れる(文字コードUTF-8にするの忘れないように)
自分は以下のクトゥルフ神話シナリオを学習させてみました。
(17956文字)
outputフォルダ作成
C:\convogpt\output
③W&BのAPIキーを取得
以下参照
④W&BのAPIキーを仮想環境に登録する
wandb login
取得したAPIキーを入力してEnter
⑤スクリプトを実行(Win環境の場合)
python ./transformers/examples/pytorch/language-modeling/run_clm.py ^
--model_name_or_path=./models/rinna/japanese-gpt2-medium ^
--train_file=./train_data/XXX.txt ^
--do_train ^
--num_train_epochs=3 ^
--save_steps=10000 ^
--block_size 512 ^
--save_total_limit=3 ^
--per_device_train_batch_size=1 ^
--output_dir=./output ^
--overwrite_output_dir ^
--use_fast_tokenizer=False
⑥スクリプト作成
"C:\convogpt\train.py"
>ValueError: The following `model_kwargs` are not used by the model: ['bad_word_ids'] (note: typos in the generate arguments will also show up in this list)エラーがでたのでbad_word_ids引数は使っていない
from transformers import T5Tokenizer, AutoModelForCausalLM
import torch
tokenizer = T5Tokenizer.from_pretrained("models/rinna/japanese-gpt2-medium")
tokenizer.do_lower_case = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 学習したモデルを読み込む
model = AutoModelForCausalLM.from_pretrained("output/")
model.to(device)
model.eval()
# 初めの文章
prompt = "九州の玄界灘に浮かぶ、島民のほとんどが漁業に従事している有人島、恵美ヶ島"
# 生成する文章の数
num = 1
input_ids = tokenizer.encode(prompt, return_tensors="pt",add_special_tokens=False).to(device)
with torch.no_grad():
output = model.generate(
input_ids,
max_length=100, # 最長の文章長
min_length=100, # 最短の文章長
do_sample=True,
top_k=500, # 上位{top_k}個の文章を保持
top_p=0.95, # 上位{top_p}%の単語から選択する。例)上位95%の単語から選んでくる
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
num_return_sequences=num # 生成する文章の数
)
decoded = tokenizer.batch_decode(output,skip_special_tokens=True)
for i in range(num):
print(decoded[i])
⑦スクリプト実行
python train.py
おお!!!それっぽいです!!学習されているかはいまいちわかりませんが動いたのでとりあえず満足です!!!!
【追記】
再び仮想環境に簡単に入れるcmdファイル作っておきました。
venv.cmdを実行すると簡単に仮想環境に入れます。
"C:\convogpt\venv.cmd"
cmd /k env\Scripts\activate