ローカルでGPT2-mediumをファインチューニングしてみた

上記を参考にすればまぁできます。以上!
だけだとあれなので自分のちょっと詰まったところをメモ。

前回の記事を読んでいること前提です。

①必要なライブラリをインストールする

pip install git+https://github.com/huggingface/transformers
git clone https://github.com/huggingface/transformers
pip install -r ./transformers/examples/pytorch/language-modeling/requirements.txt
cd models/rinna
git clone https://huggingface.co/rinna/japanese-gpt2-medium
cd C:\convogpt

②必要なディレクトリ及びファイルを用意する
train_dataフォルダ作成
C:\convogpt\train_data\XXX.txt
学習元テキストを入れる(文字コードUTF-8にするの忘れないように)
自分は以下のクトゥルフ神話シナリオを学習させてみました。
(17956文字)

outputフォルダ作成
C:\convogpt\output

③W&BのAPIキーを取得

以下参照

④W&BのAPIキーを仮想環境に登録する

wandb login

取得したAPIキーを入力してEnter

⑤スクリプトを実行(Win環境の場合)

python ./transformers/examples/pytorch/language-modeling/run_clm.py ^
  --model_name_or_path=./models/rinna/japanese-gpt2-medium ^
  --train_file=./train_data/XXX.txt ^
  --do_train ^
  --num_train_epochs=3 ^
  --save_steps=10000 ^
  --block_size 512 ^
  --save_total_limit=3 ^
  --per_device_train_batch_size=1 ^
  --output_dir=./output ^
  --overwrite_output_dir ^
  --use_fast_tokenizer=False

⑥スクリプト作成

"C:\convogpt\train.py"
>ValueError: The following `model_kwargs` are not used by the model: ['bad_word_ids'] (note: typos in the generate arguments will also show up in this list)エラーがでたのでbad_word_ids引数は使っていない

from transformers import T5Tokenizer, AutoModelForCausalLM
import torch

tokenizer = T5Tokenizer.from_pretrained("models/rinna/japanese-gpt2-medium")
tokenizer.do_lower_case = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 学習したモデルを読み込む
model = AutoModelForCausalLM.from_pretrained("output/")
model.to(device)
model.eval()

# 初めの文章
prompt = "九州の玄界灘に浮かぶ、島民のほとんどが漁業に従事している有人島、恵美ヶ島"
# 生成する文章の数
num = 1 

input_ids = tokenizer.encode(prompt, return_tensors="pt",add_special_tokens=False).to(device)
with torch.no_grad():
    output = model.generate(
        input_ids,
        max_length=100, # 最長の文章長
        min_length=100, # 最短の文章長
        do_sample=True,
        top_k=500, # 上位{top_k}個の文章を保持
        top_p=0.95, # 上位{top_p}%の単語から選択する。例)上位95%の単語から選んでくる
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        num_return_sequences=num # 生成する文章の数
    )
decoded = tokenizer.batch_decode(output,skip_special_tokens=True)
for i in range(num):
  print(decoded[i])

⑦スクリプト実行

python train.py

>九州の玄界灘に浮かぶ、島民のほとんどが漁業に従事している有人島、恵美ヶ島。 島の環境は「秘境の真珠」とも呼ばれ、透明度の 高い美しい海が漂う美しい島だった。 島は、本土との往来が途絶しているうえ、数年前までは海賊の潜伏地などとしても恐れられて いたことから、島民の一部は独自の文化を築いていた。 島に漂着した人間は、島の文化に誇りを持ち、人前に出て話をしたり、他の

おお!!!それっぽいです!!学習されているかはいまいちわかりませんが動いたのでとりあえず満足です!!!!

【追記】
再び仮想環境に簡単に入れるcmdファイル作っておきました。
venv.cmdを実行すると簡単に仮想環境に入れます。
"C:\convogpt\venv.cmd"

cmd /k env\Scripts\activate


いいなと思ったら応援しよう!