見出し画像

⭐️⭐️⭐️ MLX Language Model Interactive Terminal ⭐️⭐️⭐️ さらにバージョンアップ stablelm-2-zephyr-1_6b用

MLX 対話 Terminalのバージョンアップをまたおこないました。

  • h で最初の画面 helpです。

  • r で regenerationします。

  • /tokens で生成するtoken数を途中で変えれます。100~1000の間にしてます。

import time
import datetime
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step
# 翻訳
from easynmt import EasyNMT

################ ここから必要に応じて設定 ######################

model_name = "stabilityai/stablelm-2-zephyr-1_6b"

# 翻訳モデルの指定 引数でモデルを指定。予め用意されたモデルが色々あるので、easynmtのgithub参照
translation_model = EasyNMT('mbart50_m2m')

# LLMのTemperatureと生成されるテキストのmax_tokensの設定 dolphin 0.7 nous-helmes 0.8 が推奨値
temperature=0.7
max_tokens=750
ai_name = "StableLM"  # 好きな表示の名前を設定

# 最初のシステムメッセージを設定
system_message = "This is a system prompt, please behave and help the user."  # Stable Beluga 2 での例

# 対話の発言回数を設定 とりあえず10回(5回の対話やりとり)の発言のみを保持
number_memory = 10

#################### ここまでの上を設定 ################

# 初期プロンプトを設定(システムメッセージと最初のユーザー入力のプレースホルダー){{ }}でプレースホルダーの予約
initial_prompt = f"<|endoftext|>system\n{system_message}<|endoftext|>\n<|endoftext|>user\n{{first_user_input}}<|endoftext|>\n"

# 最初のユーザー入力を保持するためのフラグと変数
is_first_input = True
first_user_input = ""
last_input = ""

# 会話の履歴を保持するリスト (ユーザーとアシスタントの会話)
conversation_history = []
conversation_all = []

# モデルとトークナイザーをロード trust_remote_codeをTrueのパターン
model, tokenizer = load(model_name,tokenizer_config={"trust_remote_code": True})

############# ここから関数いろいろ ######################

# max_tokensを設定する関数
def set_max_tokens():
    global max_tokens
    while True:
        try:
            max_tokens_input = input("Enter max tokens (100-1000): ")
            max_tokens = int(max_tokens_input)
            if 100 <= max_tokens <= 1000:
                print(f"Max tokens set to {max_tokens}.")
                return
            else:
                print("Invalid input. Please enter a number between 100 and 1000.")
        except ValueError:
            print("Invalid input. Please enter a valid integer.") # ここでループの先頭に戻り、再入力を促す


# 対話記録のための関数
def record_conversation_history(user_prompt, assistant_response):
    global conversation_history, conversation_all # 変更を加えるため宣言が必要
    # conversation_historyを更新するための関数  最新のユーザーの入力とaiの応答を追加する
    conversation_history.append(f"<|endoftext|>user\n{user_prompt}<|endoftext|>\n")
    conversation_history.append(f"<|endoftext|>assistant\n{assistant_response}<|endoftext|>\n")
    # conversation_historyが最新の10個の要素のみを保持するようにする
    conversation_history = conversation_history[-number_memory:]
    # conversation_allに対話を追加するための関数
    conversation_all.append(f"<|endoftext|>user\n{user_prompt}<|endoftext|>\n")
    conversation_all.append(f"<|endoftext|>assistant\n{assistant_response}<|endoftext|>\n")


def save_conversation(conversation_all):
    """会話履歴をファイルに保存する関数"""
    # 現在の日付と時刻を取得し、ファイル名にするための文字列を作成
    current_time = datetime.datetime.now()
    timestamp_str = current_time.strftime("%Y%m%d_%H%M")
    filename = f"conversation_all_{timestamp_str}.txt"
    try:
        with open(filename, 'w', encoding='utf-8') as file:
            file.write('\n'.join(str(item) for item in conversation_all))
        print(f"=== Conversation history saved as {filename}! ===")
    except IOError as e:
        print(f"Error while saving conversation history: {e}")


# テキスト生成のための関数 modelとtokenizerは読み込むだけだからglobalはいらない。
def produce_text(the_prompt):
    tokens = []
    skip = 0
    REPLACEMENT_CHAR = "\ufffd"
     # generate_step関数の結果をイテレートするためのforループ versionがあがったようなので、token, _ から (token, prob), n
    for (token, prob),n in zip(generate_step(mx.array(tokenizer.encode(the_prompt)), model, temperature),
                        range(max_tokens)):    
        # EOS tokenが出現したときにループを抜ける
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())
        # 生成されたトークンからテキストをデコード
        generated_text = tokenizer.decode(tokens)
        # 置換文字である壊れた文字があればを取り除く
        generated_text = generated_text.replace(REPLACEMENT_CHAR,'')
        # 以前に出力されたテキストをスキップして新しいテキストをyield
        yield generated_text[skip:]
        skip = len(generated_text)  # スキップする文字数を更新


# ユーザーの入力と翻訳フラグを受け取る。 produce_textに渡すfull_promptを作成、生成したテキストを表示する
def show_chat(user_input, translation_enabled):
    global conversation_history, conversation_all, initial_prompt, is_first_input, first_user_input, last_input
    full_prompt = ""
    # 上は、これによってグローバル変数として扱う
    # 最初のユーザー入力を確認し、保持するとともにsystem promptを含むinitial_promptを設定
    if is_first_input:
        if user_input in ('h', 'c', 'r', '/show', '/clear', '/history', '/save', '/reset', '/tokens'):
            print('No initial prompt, yet.')
            return
        else:
            first_user_input = user_input  # showのため グローバル変数として保存
            initial_prompt = initial_prompt.replace("{first_user_input}", user_input) # グローバル変数として維持する
            is_first_input = False  # 最初の入力が処理されたのでフラグを更新

    # プロントに用いる会話履歴をクリアーにするコマンドの実行部分
    if user_input == "/clear":
        conversation_history = [] 
        print("===! Conversation history cleared! ===")
        return
    
    # ユーザーが会話履歴を表示するコマンドを入力した場合
    if  user_input == '/history':
        print("\n===== Recent Conversation History =====\n")
        print("".join(conversation_history).strip())
        return  # 会話履歴を表示し、次の入力を待つ

    # initial promptからすべての会話履歴をクリアーにするコマンドの実行部分
    if user_input == "/reset":
        conversation_all = []
        conversation_history = []
        first_user_input =""
        is_first_input = True 
        print("===! ALL Conversation history and Initial Prompt cleared! ===")
        return
        
    if user_input == "/tokens":
        set_max_tokens()
        return

    # 会話履歴を保存する
    if user_input == "/save":
        # save関数を呼び出し
        save_conversation(conversation_all)
        return
    
    # システムプロンプトとイニシャルプロンプトを表示する実行
    if user_input == "/show":
        print("=== System Prompt and Initial Prompt ===")
        print("System: ", system_message)
        print("Initial: ", first_user_input)
        print("")
        return
    
    if user_input == "h":
        print_commands()
        return
    
     # 続きを促すショートカットがあったら、続きを促す文章にする
    if user_input == 'c':
        user_input = 'Continue from your last line.'
  
     # regemerateを表示する実行
    if user_input == "r":
        print("=== Regeneration ===")
        user_input = last_input
        conversation_history = conversation_history[:-2]  # スライス操作でリストの最後の2要素(対話)を取り除く
        conversation_all = conversation_all[:-2]
        
    # `conversation_history` は文字列のリストです。それを単一の文字列に変換するために、リストの要素を改行文字で連結します。
    conversation_history_str = "\n".join(conversation_history)
    # 連結して full_promptを設定。ただしhistoryが埋まるまで、user_inputが2回続いてしまう状態
    full_prompt =  initial_prompt + conversation_history_str + f"<|endoftext|>user\n{user_input}<|endoftext|>\n<|endoftext|>assistant\n"
        
    print(f"\n{ai_name}: ", end="", flush=True)
    
    full_response = ""
    for chunk in produce_text(full_prompt):    #produce_text関数を呼び出す
        full_response += chunk  # 生成されたテキスト全文を収納しておくため、翻訳時と会話記録更新に使う
        if translation_enabled == False:
            print(chunk, end="", flush=True) # chunkテキストを出力
            time.sleep(0.1)  # 生成中のタイピング効果をシミュレートするための遅延。適当な値にするか、必要ならコメントアウト。

    if translation_enabled == True:
        translated = translation_model.translate(full_response, target_lang="ja", max_length=1000)
        print(translated) # 翻訳されたchunkテキストを出力 
        
    print("\n")

    # conversation_allとconversation_historyを更新する関数の実行
    record_conversation_history(user_input, full_response)
    last_input = user_input


# ユーザーからの入力を処理し、整形されたテキストを返す関数
def get_user_input():
    user_input = ""
    multi_line = False

    while True:
        line = input("User: ")
        if line.startswith('"""'):
            if multi_line:
                multi_line = False
                if line.endswith('"""'):
                    user_input += line[:-3]  # 末尾の引用符を除去する
                else:
                    user_input += line + "\n"
                break
            else:
                multi_line = True
                if line.endswith('"""') and len(line) > 3:
                    user_input += line[3:-3] + "\n"  # 先頭と末尾の引用符を除去する
                    break
                else:
                    user_input += line[3:] + "\n"  # 先頭の引用符を除去する
        else:
            if multi_line:
                if line.endswith('"""'):
                    user_input += line[:-3]  # 末尾の引用符を除去する
                    break
                else:
                    user_input += line + "\n"
            else:
                user_input += line
                break

    return user_input.strip()


def print_commands():
    print("\n⭐️⭐️⭐️ MLX Language Model Interactive Terminal ⭐️⭐️⭐️\n")
    print("Model: ", model_name)
    print("-" * 70)
    print("Available Commands:")
    print(
      " type `h`: Help for display available commands.\n"
      " type `c`: Continue from the last response.\n"
      " type `r`: Regenerate another response.\n"
      " type `q`: Quit the application.\n"
      " type `/history`: View recent conversation history.\n"
      " type `/save`: Save all conversation history to a file.\n"
      " type `/reset`: Restart with a new initial prompt.\n"
      " type `/show`: Display system and initial prompt.\n"
      " type `/tokens`: Set the output's max token limit.\n"
      " type `/translate`: Toggle JA-EN/EN-JA translation.")
    print("-" * 70)
    print('For multi-line input, enclose your text with triple quotes (""") ')
    print('at the beginning and the end.')
    print("=" * 70 + "\n")


# メイン部分
def main():
    # 初期状態の翻訳フラグ (Falseは翻訳がオフ、Trueは翻訳がオン:初期値はFalse)
    translation_enabled = False
    
    print_commands()
    
    while True:
        
        formatted_input = get_user_input()
        # ユーザーが終了コマンドを入力した場合
        if formatted_input.lower() == 'q':
            print("\n" + "=" * 70 + "\n")
            print("Exiting the interactive terminal.")
            print("\n")
            break # スクリプトが終了する
        
        # ユーザー入力が何もなかった場合に警告を出して次の入力を待つ
        if not formatted_input:
            print("Please enter some text to proceed.")
            continue  # Whileのループの頭に戻る
                                
        # /translate コマンドが呼び出されたとき、フラグをトグルする
        if formatted_input == "/translate":
            translation_enabled = not translation_enabled
            # 状態を英語で表示する
            print(f"Translation is now {'on' if translation_enabled else 'off'}.")
            continue  # 次の入力を待つ

        # 実質、作業をする関数に入力された文章を引き渡し
        if translation_enabled:
            translated = translation_model.translate(formatted_input, target_lang="en", max_length=1000)
        else:
            translated = formatted_input

        show_chat(translated.strip(), translation_enabled)

# このスクリプトを走らせるためのコマンド
if __name__ == "__main__":
    main()

あとつけるとしたら、load機能ぐらいでしょうか。

Python入門者ですので、お気づきなことがあれがご教示ください。


#AI #AIとやってみる #やってみる #mlx #huggingface #LLM #ローカルLLM #大規模自然言語モデル #Python入門

この記事を最後までご覧いただき、ありがとうございます!もしも私の活動を応援していただけるなら、大変嬉しく思います。