mlx-lm のバージョンアップ:stablelm-2-zephyr-1_6b 向け対話型terminalスクリプト
Mistral-Mediumのもとになったモデルがmiqu-70Bとしてリークされたということで話題ですが、mlxのほうは着実にバージョンがアップされています。それで mlx-lmもバージョンがあがっています。
特徴的なのは、generateに、"--colorize" の引数がついて、生成された単語の確率によって単語に色がつくことになったことでしょう。
この確率については理解がよくできてませんが、自分にとって問題だったのは、以前に作ったinteractive terminalにエラーがでてきたことです。
使っていた generate_step関数にもprobability が出力されるようになったため、ちょっと手直しが必要になりました。
参考にしたのは、mlx-uiのコードです。
以前の記事で書いた stabilityai/stablelm-2-zephyr-1_6b に特化したスクリプトは以下のとおりです。 2023.2.2間違い部分をコメントアウト
import time
import datetime
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step
# 翻訳
from easynmt import EasyNMT
model_name = "stabilityai/stablelm-2-zephyr-1_6b"
# 翻訳モデルの指定 引数でモデルを指定。予め用意されたモデルが色々あるので、easynmtのgithub参照
translation_model = EasyNMT('mbart50_m2m')
# LLMのTemperatureと生成されるテキストのmax_tokensの設定 dolphin 0.7 nous-helmes 0.8 が推奨値
temperature=0.7
max_tokens=750
# 最初のシステムメッセージを設定:
system_message = "You are a helpful assistant."
# 対話の発言回数を設定 とりあえず10回(5回の対話やりとり)の発言のみを保持
number_memory = 10
# 初期プロンプトを設定(システムメッセージと最初のユーザー入力のプレースホルダー){{ }}でプレースホルダーの予約
initial_prompt = f"<|endoftext|>system\n{system_message}<|endoftext|>\n<|endoftext|>user\n{{first_user_input}}<|endoftext|>\n"
# 最初のユーザー入力を保持するためのフラグと変数
is_first_input = True
first_user_input = ""
# 初期状態の翻訳フラグ (Falseは翻訳がオフ、Trueは翻訳がオン:初期値はFalse)
translation_enabled = False
# 会話の履歴を保持するリスト (ユーザーとアシスタントの会話)
conversation_history = []
conversation_all = []
# モデルとトークナイザーをロード trust_remote_codeをTrueのパターン
model, tokenizer = load(model_name,tokenizer_config={"trust_remote_code": True})
# 対話記録のための関数
def update_conversation_history(user_prompt, assistant_response):
global conversation_history # この行を追加
# conversation_historyを更新するための関数 最新のユーザーの入力とaiの応答を追加する
conversation_history.append(f"<|endoftext|>user\n{user_prompt}<|endoftext|>\n<|endoftext|>assistant\n")
conversation_history.append(f"{assistant_response}<|endoftext|>\n")
# conversation_historyが最新の10個の要素のみを保持するようにする
conversation_history = conversation_history[-number_memory:]
def add_to_conversation_all(user_prompt, assistant_response):
global conversation_all # この行を追加
# conversation_allに対話を追加するための関数
conversation_all.append(f"<|endoftext|>user\n{user_prompt}<|endoftext|>\n")
conversation_all.append(f"<|endoftext|>assistant\n{assistant_response}<|endoftext|>\n")
# テキスト生成のための関数
def produce_text(the_prompt, the_model):
tokens = []
skip = 0
REPLACEMENT_CHAR = "\ufffd"
# generate_step関数の結果をイテレートするためのforループ versionがあがったようなので、token, _ から (token, prob), n
for (token, prob),n in zip(generate_step(mx.array(tokenizer.encode(the_prompt)), the_model, temperature),
range(max_tokens)):
# EOS tokenが出現したときにループを抜ける
if token == tokenizer.eos_token_id:
break
tokens.append(token.item())
# 生成されたトークンからテキストをデコード
generated_text = tokenizer.decode(tokens)
# 置換文字である壊れた文字があればを取り除く
generated_text = generated_text.replace(REPLACEMENT_CHAR,'')
# 以前に出力されたテキストをスキップして新しいテキストをyield
# npaka さんのnoteにしたがってif をいれてみる
# if REPLACEMENT_CHAR not in generated_text[skip:]:
# if文がはいったのでインデントする → やっぱりやめた方がいい感じ
yield generated_text[skip:]
skip = len(generated_text) # スキップする文字数を更新
# 生成したテキストを表示する関数と、produce_textに渡すfull_promptを作成と、会話履歴直近を保存
def show_chat(user_input):
global conversation_history, conversation_all, initial_prompt, is_first_input, first_user_input, system_message, translation_enabled
# 上は、これによってグローバル変数として扱う
# 最初のユーザー入力を確認し、保持
if is_first_input:
if user_input in ('/show', '/clear','Continue from your last line.', '/save','/reset'):
print('No initial prompt, yet')
return
else:
first_user_input = user_input
is_first_input = False # 最初の入力が処理されたのでフラグを更新
# プロントに用いる会話履歴をクリアーにするコマンドの実行部分
if user_input == "/clear":
conversation_history = []
print("===! Conversation history cleared! ===")
return
# initial promptからすべての会話履歴をクリアーにするコマンドの実行部分
if user_input == "/reset":
conversation_all = []
conversation_history = []
first_user_input =""
is_first_input = True
print("===! ALL Conversation history and Initial Prompt cleared! ===")
return
# 会話履歴を保存するコマンドの実行部分
if user_input == "/save":
# 現在の日付と時刻を取得し、ファイル名にするための文字列を作成
current_time = datetime.datetime.now()
timestamp_str = current_time.strftime("%Y%m%d_%H%M%S") # 例: '20240126_153305'
filename = f"conversation_all_{timestamp_str}.txt" # ファイル名の生成
with open(filename, 'w', encoding='utf-8') as file: # ファイルに書き込む
file.write('\n'.join(str(item) for item in conversation_all))
print(f"=== Conversation history saved as {filename}! ===")
return
# システムプロンプトとイニシャルプロンプトを表示する実行
if user_input == "/show":
print("=== System Prompt and Initial Prompt ===")
print("System: ", system_message)
print("Initial: ", first_user_input)
print("")
return
# 会話履歴を更新し、プロンプトを構築
conversation_history.append(f"<|endoftext|>user\n{user_input}<|endoftext|>\n<|endoftext|>assistant\n")
# 初期プロンプトのプレースホルダーを最初のユーザー入力で置換 ここでは置き換え前のフレーズは一重の { }となる。
full_prompt = initial_prompt.replace("{first_user_input}", first_user_input) + "".join(conversation_history)
# systemおよびiitialを保存する
# formatted_initial_prompt = initial_prompt.replace("{first_user_input}", first_user_input)
# conversation_all.append(formatted_initial_prompt) 毎回付け加わってしまう
print("\nStableLM: ", end="", flush=True)
full_response = ""
for chunk in produce_text(full_prompt, model): #produce_text関数を呼び出す
full_response += chunk # 生成されたテキスト全文を収納しておくため、あとで使うかも
if translation_enabled == False:
print(chunk, end="", flush=True) # chunkテキストを出力
time.sleep(0.1) # 生成中のタイピング効果をシミュレートするための遅延。適当な値にするか、必要ならコメントアウト。
if translation_enabled == True:
translated = translation_model.translate(full_response,target_lang="ja",max_length=1000)
print(translated) # 翻訳されたchunkテキストを出力
print("\n")
# conversation_allとconversation_historyを更新する処理
add_to_conversation_all(user_input, full_response)
update_conversation_history(user_input, full_response)
# ユーザーからの入力を処理し、整形されたテキストを返す関数
def get_user_input():
user_input = ""
multi_line = False
while True:
line = input("User: ")
if line.startswith('"""'):
if multi_line:
multi_line = False
if line.endswith('"""'):
user_input += line[:-3] # 末尾の引用符を除去する
else:
user_input += line + "\n"
break
else:
multi_line = True
if line.endswith('"""') and len(line) > 3:
user_input += line[3:-3] + "\n" # 先頭と末尾の引用符を除去する
break
else:
user_input += line[3:] + "\n" # 先頭の引用符を除去する
else:
if multi_line:
if line.endswith('"""'):
user_input += line[:-3] # 末尾の引用符を除去する
break
else:
user_input += line + "\n"
else:
user_input += line
break
return user_input.strip()
# メイン部分
def main():
global translation_enabled # 翻訳モードのフラグを他でも使うため
print("\n⭐️⭐️⭐️ MLX Language Model Interactive Terminal ⭐️⭐️⭐️\n")
print("Model: ", model_name)
print("-" * 70)
print("Available Commands:")
print(
" type `c`: Continue from the last response.\n"
" type `q`: Quit the application.\n"
" type `/history`: View recent conversation history.\n"
" type `/save`: Save all conversation history to a file.\n"
" type `/reset`: Restart with a new initial prompt.\n"
" type `/show`: Display system and initial prompt.\n"
" type `/translate`: Toggle JA-EN/EN-JA translation.")
print("-" * 70)
print('For multi-line input, enclose your text with triple quotes (""") ')
print('at the beginning and the end.')
print("=" * 70 + "\n")
while True:
formatted_input = get_user_input()
# ユーザーが終了コマンドを入力した場合
if formatted_input.lower() == 'q':
print("\n" + "=" * 70 + "\n")
print("Exiting the interactive terminal.")
print("\n")
break # スクリプトが終了する
# ユーザー入力が何もなかった場合に警告を出して次の入力を待つ
if formatted_input == '':
print("Warning: Empty or null user input detected. Please try again.")
continue # Whileのループの頭に戻る
# ユーザーが会話履歴を表示するコマンドを入力した場合
if formatted_input == '/history':
print("\n===== Recent Conversation History =====\n")
print("".join(conversation_history).strip())
continue # 会話履歴を表示し、次の入力を待つ
# /translate コマンドが呼び出されたとき、フラグをトグルする
if formatted_input == "/translate":
translation_enabled = not translation_enabled
# 状態を英語で表示する
print(f"Translation is now {'on' if translation_enabled else 'off'}.")
continue # 次の入力を待つ
# 続きを促すショートカットがあったら、続きを促す文章にする
if formatted_input == 'c':
formatted_input = 'Continue from your last line.'
# 実質、作業をする関数に入力された文章を引き渡し
if translation_enabled:
translated = translation_model.translate(formatted_input, target_lang="en", max_length=1000)
else:
translated = formatted_input
show_chat(translated.strip())
# このスクリプトを走らせるためのコマンド
if __name__ == "__main__":
main()
下記の行が修正後の文です。(引用でインデントが崩れてますので注意。)
あと、切れ目はすべて、<|endoftext|> だということなので、それに入れ替えています。この辺りについてはさらに学びたいなと思っていますが、なかなか理解が進んでいません。
system_message は適当に書いていますので、適宜書き換えてください。
1.6bのモデルとしては、対話が成立するので面白いです。ただし、事実については作話してるので、事実確認には用いない方がいいでしょう。知識が薄っぺらだけど話し上手なAIとして対話を楽しめるのではないでしょうか?
ちなみに参考にしているmlx-uiの選択モデルがふえています。
話題の CodeLlama-70b-hf-4bit-MLX も入っていますので、お手軽に試したい人はmlx-uiを使ってみるといいかと思います。
Pythonの自学学習とterminalで完結したいだけで、作っているのでこだわりがない人は、mlx-uiを使った方がお手軽かと思います。