短期会話ログ機能を追加-japanese-stablelm-instruct-gamma-7bのGUI版
StabilityAIから新しいモデルが公開されています。先にも記事にしましたが、会話ログが保持できる仕様へ修正しました。会話の継続性を維持するためにはこの機能がどうしても必要で、StabilityAIのモデルでは今まで出来ていませんでした。
新モデルの-instruct版、他にBASEモデルなどもあります。
stabilityai/japanese-stablelm-instruct-gamma-7b
stabilityai/japanese-stablelm-3b-4e1t-instruct
このコードはV2も使用できます。V2を使用するメリットはVRAMが小さくてすむことだと思います。3b-4e1tモデルでV2と同様の性能になるようなのでV2を使用する機会はほとんど無いかと思います。
改良点は会話ログをプロンプトに追加する形で会話の連続性を維持させたことです。キャラ付けをするのは、コツが必要なのでここでは説明しません。GUIはいつもどおりgradioです。API化するときに変更をできる限り小さくするために、会話ログの保持をgradioで行っています。API化したときはクライアント側で保持することになります。この方法でサーバ化したときに複数のクリアントで共有できるようになります。あるいは、同じクライアントでもプロンプトを変えて異なる用途で使えるようになります。例えば会話用の時であったり、翻訳用であったりです。
環境
憂鬱な環境設定ですが、過去記事の通りです。以下参照。
CUDA関連
nvcc-Vで調べた環境です。11.8以上で大丈夫なはずです。
nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Jun_13_19:16:58_PDT_2023
Cuda compilation tools, release 12.2, V12.2.91
Build cuda_12.2.r12.2/compiler.32965470_0
その他
Transformerとgradioが必要です。
pip install transformers accelerate bitsandbytes
pip install sentencepiece einops
pip install gradio
コード
モデルとtokenizer
V2用、7B用、3B用が選べます。コメントアウトで選択してください。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
#-- -- japanese-stablelm-instruct-alpha-7b-v2のときは以下---
#from transformers import LlamaTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("stabilityai/japanese-stablelm-instruct-gamma-7b")
model = AutoModelForCausalLM.from_pretrained(
"stabilityai/japanese-stablelm-instruct-gamma-7b",
torch_dtype="auto",
)
#tokenizer = AutoTokenizer.from_pretrained("stabilityai/japanese-stablelm-3b-4e1t-instruct")
#model = AutoModelForCausalLM.from_pretrained(
# "stabilityai/japanese-stablelm-3b-4e1t-instruct",
# trust_remote_code=True,
# torch_dtype="auto",
#)
#-- -- japanese-stablelm-instruct-alpha-7b-v2のときは以下---
#tokenizer = LlamaTokenizer.from_pretrained(
# "novelai/nerdstash-tokenizer-v1", additional_special_tokens=["▁▁"], legacy=False)
#model = AutoModelForCausalLM.from_pretrained(
# "stabilityai/japanese-stablelm-instruct-alpha-7b-v2",
# trust_remote_code=True,
# torch_dtype=torch.float16,
# variant="fp16",
# )
model.eval()
if torch.cuda.is_available():
model = model.to("cuda")
生成部関数
コード中間ぐらいのコメント
#会話ヒストリ作成。プロンプトに追加する。
から
print("prompt=",prompt)
までが記憶保持のコードです。必要なければ削除出来ます。後半の
if len(talk_log_list)>log_len: talk_log_list=talk_log_list[2:] #ヒストリターンが指定回数を超えたら先頭(=一番古い)の会話(入力と応答)を削除 talk_log_list.append("\n" +"###"+ "応答:"+"\n" + out .replace("\n" ,""))
も記憶保持用のコードです。
def genereate(sys_msg, user_query,user,max_token,get_temperature, talk_log_list ,log_f,log_len ):
max_token=int(max_token)
get_temperature=float(get_temperature)
user_inputs = {
"user_query": user_query,
"inputs": user,
}
prompt = build_prompt(sys_msg ,**user_inputs)
#会話ヒストリ作成。プロンプトに追加する。
log_len = int(log_len)
if log_f==True and log_len >0: # 履歴がTrueでログ数がゼロでなければtalk_log_listを作成
sys_prompt=prompt.split("### 入力:")[0]
talk_log_list.append( " \n\n"+ "### 入力:"+ " \n" + user+ " \n" )
new_prompt=""
for n in range(len(talk_log_list)):
new_prompt=new_prompt + talk_log_list[n]
prompt= sys_prompt + new_prompt+" \n \n"+ "### 応答:"+" \n"
print("prompt=",prompt)
input_ids = tokenizer.encode(
prompt,
add_special_tokens=False,
return_tensors="pt"
)
# パッドトークンIDの設定
pad_token_id = tokenizer.eos_token_id # パディングトークンIDをeos_token_idに設定
tokens = model.generate(
input_ids.to(device=model.device),
max_new_tokens=max_token,
temperature=get_temperature,
top_p=0.95,
do_sample=True,
pad_token_id= pad_token_id,
)
all_out = tokenizer.decode(tokens[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
out=all_out.split("###")[0]
if len(talk_log_list)>log_len:
talk_log_list=talk_log_list[2:] #ヒストリターンが指定回数を超えたら先頭(=一番古い)の会話(入力と応答)を削除
talk_log_list.append("\n" +"###"+ "応答:"+"\n" + out .replace("\n" ,""))
return out, all_out, prompt ,talk_log_list
GUI
gradioによる記述です。
# GradioのUIを定義します
import gradio as gr
with gr.Blocks() as webui:
gr.Markdown("japanese-stablelm-instruct-alpha-7b-v2 prompt test")
with gr.Row():
with gr.Column():
sys_msg = gr.Textbox(label="sys_msg", placeholder=" システムプロンプト")
user_query = gr.Textbox(label="user_query", placeholder="命令を入力してください")
user = gr.Textbox(label="入力", placeholder="ユーザーの会話を入力してください")
with gr.Row():
log_len = gr.Number(10, label="履歴ターン数")
log_f = gr.Checkbox(True, label="履歴有効・無効")
with gr.Row():
max_token = gr.Number(100, label="max out token")
temperature = gr.Number(0.7, label="temperature,")
with gr.Row():
prompt_input = gr.Button("Submit prompt",variant="primary")
log_clr = gr.Button("ログクリア",variant="secondary")
with gr.Column():
out_data=[gr.Textbox(label="システム"),
gr.Textbox(label="tokenizer全文"),
gr.Textbox(label="プロンプト"),
gr.Textbox(label="会話ログリスト")]
prompt_input.click(gradio_genereate, inputs=[sys_msg, user_query, user, max_token, temperature,log_f,log_len], outputs=out_data )
log_clr .click(gradio_clr)
# Gradioアプリケーションを起動します
webui.launch()
起動すると以下のようなGUIになります。
設定できるパラメータは
履歴ターン数
履歴有効・無効
max out token
temperature
になります。生成の終了を正確に行いたかったのですが、方法がよくわかりません。max out tokenで調整は出来ますが小さくしすぎると文が途中でと切れますし、大きくすると生成時間が長くなります。会話の場合は100前後で良さそうです。履歴については長くなるとプロンプトが大きくなって生成に時間がかかったり、VRAMが消費されます。5〜10程度で良さそうです。会話をしないのであれば、履歴有効・無効を無効にした方が過去に引きずられることなく毎回新規に回答作成してくれます。
3種類のプロンプト
プロンプトは
sys_msg
user_query
user
の3種類です。
プロンプトでは、sys_msgは最初に直接挿入されています。
user_queryは「指示」となっていてuserに対する回答の仕方を記述します。
userは「入力」となっており、文字通りユーザーの会話入力になります。
sys_msgはAIとのやり取りの場面を教える役目のようです。実際の回答に対するAIの反応については指示であるuser_queryに記載します。
StabilityAIのコードでは
sys_msg 以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
user_query 与えられたことわざの意味を小学生でも分かるように教えてください。
inputs 情けは人のためならず
となっています。キャタ付けはもう少し使わけの工夫が必要ですが、かなり効果的にプロンプトのみでできそうです。
コード全体
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
#-- -- japanese-stablelm-instruct-alpha-7b-v2のときは以下---
#from transformers import LlamaTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("stabilityai/japanese-stablelm-instruct-gamma-7b")
model = AutoModelForCausalLM.from_pretrained(
"stabilityai/japanese-stablelm-instruct-gamma-7b",
torch_dtype="auto",
)
#tokenizer = AutoTokenizer.from_pretrained("stabilityai/japanese-stablelm-3b-4e1t-instruct")
#model = AutoModelForCausalLM.from_pretrained(
# "stabilityai/japanese-stablelm-3b-4e1t-instruct",
# trust_remote_code=True,
# torch_dtype="auto",
#)
#-- -- japanese-stablelm-instruct-alpha-7b-v2のときは以下---
#tokenizer = LlamaTokenizer.from_pretrained(
# "novelai/nerdstash-tokenizer-v1", additional_special_tokens=["▁▁"], legacy=False)
#model = AutoModelForCausalLM.from_pretrained(
# "stabilityai/japanese-stablelm-instruct-alpha-7b-v2",
# trust_remote_code=True,
# torch_dtype=torch.float16,
# variant="fp16",
# )
model.eval()
if torch.cuda.is_available():
model = model.to("cuda")
talk_log_list=[] # gradioで保持するための初期化
def genereate(sys_msg, user_query,user,max_token,get_temperature, talk_log_list ,log_f,log_len ):
max_token=int(max_token)
get_temperature=float(get_temperature)
user_inputs = {
"user_query": user_query,
"inputs": user,
}
prompt = build_prompt(sys_msg ,**user_inputs)
#会話ヒストリ作成。プロンプトに追加する。
log_len = int(log_len)
if log_f==True and log_len >0: # 履歴がTrueでログ数がゼロでなければtalk_log_listを作成
sys_prompt=prompt.split("### 入力:")[0]
talk_log_list.append( " \n\n"+ "### 入力:"+ " \n" + user+ " \n" )
new_prompt=""
for n in range(len(talk_log_list)):
new_prompt=new_prompt + talk_log_list[n]
prompt= sys_prompt + new_prompt+" \n \n"+ "### 応答:"+" \n"
print("prompt=",prompt)
input_ids = tokenizer.encode(
prompt,
add_special_tokens=False,
return_tensors="pt"
)
# パッドトークンIDの設定
pad_token_id = tokenizer.eos_token_id # パディングトークンIDをeos_token_idに設定
tokens = model.generate(
input_ids.to(device=model.device),
max_new_tokens=max_token,
temperature=get_temperature,
top_p=0.95,
do_sample=True,
pad_token_id= pad_token_id,
)
all_out = tokenizer.decode(tokens[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
out=all_out.split("###")[0]
if len(talk_log_list)>log_len:
talk_log_list=talk_log_list[2:] #ヒストリターンが指定回数を超えたら先頭(=一番古い)の会話(入力と応答)を削除
talk_log_list.append("\n" +"###"+ "応答:"+"\n" + out .replace("\n" ,""))
return out, all_out, prompt ,talk_log_list
def build_prompt(sys_msg ,user_query, inputs="", sep="\n\n### "):
p = sys_msg
roles = ["指示", "応答"]
msgs = [": \n" + user_query, ": \n"]
if inputs:
roles.insert(1, "入力")
msgs.insert(1, ": \n" + inputs)
for role, msg in zip(roles, msgs):
p += sep + role + msg
return p
# Gradioからアクセスするときの関数、talk_log_listを保持したりクリアするため
def gradio_genereate(sys_msg, user_query,user,max_token,get_temperature, log_f,log_len ):
global talk_log_list
out, all_out, prompt ,talk_log_list=genereate(sys_msg, user_query,user,max_token,get_temperature, talk_log_list,log_f,log_len )
return out, all_out, prompt,talk_log_list
def gradio_clr():
global talk_log_list
talk_log_list=[]
# GradioのUIを定義します
import gradio as gr
with gr.Blocks() as webui:
gr.Markdown("japanese-stablelm-instruct-alpha-7b-v2 prompt test")
with gr.Row():
with gr.Column():
sys_msg = gr.Textbox(label="sys_msg", placeholder=" システムプロンプト")
user_query =gr.Textbox(label="user_query", placeholder="命令を入力してください")
user =gr.Textbox(label="入力", placeholder="ユーザーの会話を入力してください")
with gr.Row():
log_len =gr.Number(10, label="履歴ターン数")
log_f =gr.Checkbox(True, label="履歴有効・無効")
with gr.Row():
max_token = gr.Number(100, label="max out token")
temperature = gr.Number(0.7, label="temperature,")
with gr.Row():
prompt_input = gr.Button("Submit prompt",variant="primary")
log_clr = gr.Button("ログクリア",variant="secondary")
with gr.Column():
out_data=[gr.Textbox(label="システム"),
gr.Textbox(label="tokenizer全文"),
gr.Textbox(label="プロンプト"),
gr.Textbox(label="会話ログリスト")]
prompt_input.click(gradio_genereate, inputs=[sys_msg, user_query, user, max_token, temperature,log_f,log_len], outputs=out_data )
log_clr .click(gradio_clr)
# Gradioアプリケーションを起動します
webui.launch()
キャラ付け
色々とプロンプトを試して見ました。プロンプトはよく効くので、上手く記述するとできそうです。ただし、結構反抗的でもありますね。そこもキャラとしては面白いところです。