LoRAとプロンプトでAIキャラを作る。rinna3.6BにLoRAとプロンプトで性格付けをする。
ローカルのLLMでAIキャラを動かすときには性格付けが必要です。前回の記事でサブカルにも強そうという結果であった
japanese-gpt-neox-3.6b-instruction-ppo
を用いて好みのAIキャラを作れるように、プロンプトを細かく調整したり、LoRAの効果を見る、会話ヒストリによる短期記憶の調整、重要な生成パラメータの調整ができるweb-uiを作成しました。モデルをロードしたまま、様々なプロンプトのテストが出来るので、AIキャラ作成も一気に進みます。
rinnaの各モデルにはChatGPTやLlama系にあるsystemのプロンプトがありません。性格付けを行うためには、LoRAによるファインチューニングか、プロンプトで行うしか方法はありません。過去記事ではプロンプトで性格を変えるwebuiの紹介をしましたが、色々と制限があって不十分でした。そこで、今回は当面のAIキャラのベースモデルをrinnaの3.6b-instruction-ppoに絞りましたので、より細かいプロンプトの設定とLoRAファインチューニングの効果、及び両者同時に設定して更に幅広い性格付を確かめることができるように工夫しました。改良点は以下のとおり。
web-uiの改良点
コードを直接変更して設定できる項目。
-学習済みLoRAモデルの指定
-LoRAの有効・無効設定
-デフォルトプロンプトの指定
web-uiで変更できる項目
-デフォルトプロンプトの使用の有無
-Few-shotプロンプトの設定
-会話ヒストリターン数の設定
-会話ヒストリを使わない設定
-会話ヒストリのクリア
-ZERO-shotプロンプトの利用
動作環境
ベースは前回の記事と同じです。
japanese-gpt-neox-3.6b-instruction-ppoの動作環境はGPT2と同じです。
あるいは、以下の記事で作成した環境をベースにします。
LoRAを利用をするために以下を追加でインストールします。npakaさんの記事から拝借しています。参考に記事へのURLを置きました。
# PEFTのインストール
!pip install -Uqq git+https://github.com/huggingface/peft.git
!pip install -Uqq transformers datasets accelerate bitsandbytes
!pip install sentencepiece
LoRA学習済みモデルのフォルダーをweb-ui実行フォルダーにコピーしておきます。LoRAはnpakaさんの記事の学習コードを実行してもいいですし、TexstGeneration-webuiのトレーニングでも大丈夫です。
学習用データセットは以下のような書式のプレーンテキストです。
ユーザー:あなたの名前は何ですか?
システム:わたしの名前はめぐだよ。
ユーザー:めぐはどんな性格ですか?
システム:めぐは、賢くて、おちゃめで、少しボーイッシュ、天真爛漫で好奇心旺盛な女子高生だよ。
ユーザー:めぐはどこで生まれたの?
システム:品川区の目黒川の近くで生まれたんだ。
ユーザー:めぐはどんな話し方をするの?
システム:いつもタメ口を使っているよ。
コードの説明
初期化部分
# 初期設定ー>model_name,lora_model_path,LoRA,default_promptの設定
でモデル名、LoRAホルダ名、LoRAの有効無効、デフォルトプロンプトの設定を行います。一般的にLLMではデータの保持にjeson形式が使われているのですが、今回は操作に慣れているリスト形式でデータを保持しています。続けてモデルの読み込みですが、今回は8bit量子化を指定しています。その後#LoRAモデルの準備で LoRA=Trueだった場合にはLoRAを読み込みます。ヒストリ用のリストtalk_log,histry_log,log_lenを設定して初期化を完了します。
1−初期化部分
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import gradio as gr
# 初期設定ー>model_name,lora_model_path,LoRA,default_promptの設定
model_name ="rinna/japanese-gpt-neox-3.6b-instruction-ppo"
lora_model_path="megu_instructon_all-EP1"
LoRA=True
default_prompt =["ユーザー: あなたの名前は何ですか?","システム: わたしは女子高校生の「めぐ」だよ。"]
#モデルの読み込み
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name,
load_in_8bit=True,
torch_dtype=torch.float16,
device_map="auto",)
# pad_token_id を設定
model.config.pad_token_id = tokenizer.eos_token_id
#LoRAモデルの準備
if LoRA:
model = PeftModel.from_pretrained(
model,
lora_model_path, #学習済みLoRAのフォルダ
device_map="auto"
)
else:
lora_model_path ="LoRA無効"
talk_log =[]
histry_log =[]
log_len=10
会話の生成
generate()がプロンプトからLLMの出力を得る関数です。入力パラメータはgradioで入力された値になります。各変数の意味は以下の通りです。
system Few-shotプロンプト用の会話入力です。
user ユーザー入力
log_len 会話ヒストリ-のターン数
temperature_val temperature、数値が大きいほど変化に富む出力になる
mx_n_token 出力トークンの最大値
2−会話の生成
def generate(system, user, log_len, temperature_val, mx_n_token):
global talk_log
temperature_set=float(temperature_val)
max_new_tokens_set=int(mx_n_token)
print("temperature=",temperature_set, "max_new_tokens=",max_new_tokens_set)
if system=="":
system_prompt = default_prompt
#systemの文字列からプロンプトを組み立て
else:
system = system.replace(" ","") #スペースを削除
system_prompt = system.split("\n") #改行で1文頃にリスト化
#会話ヒストリ作成。プロンプトに追加する。
log_prompt =[]
log_len = int(log_len)
if log_len>0:
if len(talk_log )>log_len:
talk_log = talk_log [1:] #ヒストリターンが指定回数を超えたら先頭(=一番古い)の会話を削除
for log_p in talk_log:
log_prompt = log_prompt + ["ユーザー: "+log_p[0]]+["システム: "+log_p[1]]
#プロンプトの準備。
if user=="":
prompt = system_prompt + log_prompt
else:
prompt = system_prompt + log_prompt + ["ユーザー: "+ user]
prompt = ( "<NL>".join(prompt) + "<NL>" + "システム: ")
#Tokenizer準備
input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.cuda()
with torch.no_grad(): #generate
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens_set,
do_sample=True,
temperature=temperature_set,
top_p=0.75,
top_k=40,
no_repeat_ngram_size=2,
)
#システムの返事部分を取り出し
outputs =tokenizer.decode(outputs[0].tolist())
output = outputs.replace("<NL>", "\n").replace("</s>", "")
output = output.split("システム: ")[-1:][0]
#ヒストリ記録
talk_log.append([user,output])
talk_p ="ユーザ:"+user+" \n"+"システム:"+output
print( talk_p)
histry_log.append(talk_p)
talk_histry='\n'.join(histry_log) #gradio表示用にリストから文字列に変換
return output, talk_histry, prompt
コードの説明
コメントト通りなので特に説明はいらないと思いますが、system_promptとlog_prompt の各リストが重要な役割をしています。system_promptにはFew-shotプロンプト用の会話リストを格納します。以下の部分です
system = system.replace(" ","") #スペースを削除
system_prompt = system.split("\n") #改行で1文頃にリスト化
またlog_prompt には会話ヒストリのリストが記録されています。
for log_p in talk_log:
log_prompt = log_prompt + ["ユーザー: "+log_p[0]]+["システム: "+log_p[1]]
これにユーザー入力を["ユーザー: "+ user] としてリスト化し、リストの結合を行なっています。
prompt = system_prompt + log_prompt + ["ユーザー: "+ user]
最終的にrinna公式のプロンプトの要求に合うようにリストから文字列に変換しています。rinna公式のプロンプトはHugingHaceで公開されているコードにpromptを出力するprint文を追加すれば確認できるはずです。
prompt = ( "<NL>".join(prompt) + "<NL>" + "システム: ")
with torch.no_grad(): #generate 以下で文章が生成されるので、システムの回答を以下で取り出しています
#システムの返事部分を取り出し
outputs =tokenizer.decode(outputs[0].tolist())
output = outputs.replace("<NL>", "\n").replace("</s>", "")
output = output.split("システム: ")[-1:][0]
最後に、#ヒストリ記録でヒストリ・リストを作成し、gradioの表示に合うように形式を変えています。
3-ヒストリ履歴クリア
gradioのClear histryボタンから呼び出される関数です。会話ヒストリ関係のリストをクリアしています。
def clear_histry():
global talk_log
global histry_log
talk_log = []
histry_log=[]
4-UI部分
gradioで記述しています。今回はやや変数が多めなのでBlosksを使用して記述しています。
with gr.Blocks() as webui:
with gr.Row():
with gr.Column():
with gr.Row():
gr.Markdown("model name: ")
gr.Markdown(model_name)
with gr.Row():
gr.Markdown("LoRA path: ")
gr.Markdown(lora_model_path,scale=1)
sys_prompt = gr.Textbox(label="固定プロンプト(Few-shot)",lines=10, placeholder=" システム(Few-shot)プロンプト入力してください")
user_promp = gr.Textbox(label="ユーザー入力プロンプト",lines=3, placeholder="質問・会話を入力してください")
with gr.Row():
log_len = gr.Number(10, label="History Prompt記憶ターン数")
temperature = gr.Number(0.7, label="temperature")
mx_n_token = gr.Number(200, label="max_new_tokens")
with gr.Row():
prompt_inpu = gr.Button("Submit prompt",variant="primary")
clear_log = gr.Button("Clear histry")
clear_log.click(clear_histry)
with gr.Column():
out_data=[gr.Textbox(label="システム"),
gr.Textbox(label="会話ヒストリ"),
gr.Textbox(label="プロンプト")]
prompt_inpu.click(generate, inputs=[sys_prompt, user_promp, log_len, temperature, mx_n_token], outputs=out_data )
webui.launch()
操作について
起動すると以下の画面になります。これは操作中の画面です。最初は何も入力や出力はありません。
左上にロードされているモデルとLoRAが有効な時はLoRAフォルダへのパスが表示されます。LoRA無効時には”LoRA無効”が表示されます。
ゼロショットプロンプトを使用するとき
History Prompt記憶ターン数を0にします。
ユーザー入力プロンプトは使わずに固定プロンプト(Few-shot)部分に入力します。
デフォルトプロンプトを使用しない
固定プロンプト(Few-shot)に会話を記述するとデフォルトプロンプトは使用されません
会話ヒストリを使用しない場合
History Prompt記憶ターン数を0にします。入力はユーザー入力プロンプトで行います。
Few- shotプロンプトを使う場合。
固定プロンプト(Few-shot)に以下のフォーマットで記述します。LoRA学習と同じフォーマットです。
ユーザー:あなたの名前は何ですか?
システム:わたしの名前はめぐだよ。
ユーザー:めぐはどんな性格ですか?
システム:めぐは、賢くて、おちゃめで、好奇心旺盛な女子高生だよ。
出力について
システム
システムからの返事が表示されています。
会話ヒストリ
過去の会話ヒストリです。
プロンプト
実際にLLMへ入力されたプロンプトが表示されています。公式のプロンプトと差異がないか、ヒストリの追加状況は正しいかなどが確認できます。
注意事項
History Prompt記憶ターン数に数字以外を入力するとエラーになります。
LoRA有効・無効はコードを変更してください。
全てのコード
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import gradio as gr
# 初期設定ー>model_name,lora_model_path,LoRA,default_promptの設定
model_name ="rinna/japanese-gpt-neox-3.6b-instruction-ppo"
lora_model_path="megu_instructon_all-EP1"
LoRA=True
default_prompt =["ユーザー: あなたの名前は何ですか?","システム: わたしは女子高校生の「めぐ」だよ。"]
#モデルの読み込み
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name,
load_in_8bit=True,
torch_dtype=torch.float16,
device_map="auto",)
# pad_token_id を設定
model.config.pad_token_id = tokenizer.eos_token_id
#LoRAモデルの準備
if LoRA:
model = PeftModel.from_pretrained(
model,
lora_model_path, #学習済みLoRAのフォルダ
device_map="auto"
)
else:
lora_model_path ="LoRA無効"
talk_log =[]
histry_log =[]
log_len=10
def generate(system, user, log_len, temperature_val, mx_n_token):
global talk_log
temperature_set=float(temperature_val)
max_new_tokens_set=int(mx_n_token)
print("temperature=",temperature_set, "max_new_tokens=",max_new_tokens_set)
if system=="":
system_prompt = default_prompt
#systemの文字列からプロンプトを組み立て
else:
system = system.replace(" ","") #スペースを削除
system_prompt = system.split("\n") #改行で1文頃にリスト化
#会話ヒストリ作成。プロンプトに追加する。
log_prompt =[]
log_len = int(log_len)
if log_len>0:
if len(talk_log )>log_len:
talk_log = talk_log [1:] #ヒストリターンが指定回数を超えたら先頭(=一番古い)の会話を削除
for log_p in talk_log:
log_prompt = log_prompt + ["ユーザー: "+log_p[0]]+["システム: "+log_p[1]]
#プロンプトの準備。
if user=="":
prompt = system_prompt + log_prompt
else:
prompt = system_prompt + log_prompt + ["ユーザー: "+ user]
prompt = ( "<NL>".join(prompt) + "<NL>" + "システム: ")
#Tokenizer準備
input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.cuda()
with torch.no_grad(): #generate
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens_set,
do_sample=True,
temperature=temperature_set,
top_p=0.75,
top_k=40,
no_repeat_ngram_size=2,
)
#システムの返事部分を取り出し
outputs =tokenizer.decode(outputs[0].tolist())
output = outputs.replace("<NL>", "\n").replace("</s>", "")
output = output.split("システム: ")[-1:][0]
#ヒストリ記録
talk_log.append([user,output])
talk_p ="ユーザ:"+user+" \n"+"システム:"+output
print( talk_p)
histry_log.append(talk_p)
talk_histry='\n'.join(histry_log) #gradio表示用にリストから文字列に変換
return output, talk_histry, prompt
def clear_histry():
global talk_log
global histry_log
talk_log = []
histry_log=[]
with gr.Blocks() as webui:
with gr.Row():
with gr.Column():
with gr.Row():
gr.Markdown("model name: ")
gr.Markdown(model_name)
with gr.Row():
gr.Markdown("LoRA path: ")
gr.Markdown(lora_model_path,scale=1)
sys_prompt = gr.Textbox(label="固定プロンプト(Few-shot)",lines=10, placeholder=" システム(Few-shot)プロンプト入力してください")
user_promp = gr.Textbox(label="ユーザー入力プロンプト",lines=3, placeholder="質問・会話を入力してください")
with gr.Row():
log_len = gr.Number(10, label="History Prompt記憶ターン数")
temperature = gr.Number(0.7, label="temperature")
mx_n_token = gr.Number(200, label="max_new_tokens")
with gr.Row():
prompt_inpu = gr.Button("Submit prompt",variant="primary")
clear_log = gr.Button("Clear histry")
clear_log.click(clear_histry)
with gr.Column():
out_data=[gr.Textbox(label="システム"),
gr.Textbox(label="会話ヒストリ"),
gr.Textbox(label="プロンプト")]
prompt_inpu.click(generate, inputs=[sys_prompt, user_promp, log_len, temperature, mx_n_token], outputs=out_data )
webui.launch()
まとめ
今回の大改修によりプロンプトによる出力の調整が大変やりやすくなりました。残念ながらrinna専用ですが、プロンプト構築部分を変更すれば他のモデルにも対応できると思います。また筆者の力不足と時間節約、コード短縮化のためにモデル名などの初期値をコマンドオプションから読み込んだり、webui中で変更する仕様にできませんでした。
参考