FlaskでChatGPTのレスポンスをストリーミング表示
PythonでChatGPTのツールを作っていて、回答をストリーミング表示(ヌルヌルと文字が出てくるやつ)させたくて、その調査ログです。
Flaskでストリーミング表示
そもそもFlaskをよく知らないので、GPTに関係なくストリーミング表示させる方法をGPTに聞いてみました。
GPTに投げたお願い
Flaskでストリーミング表示を学んでいます。以下のようなサンプルを書いてください。
* socketで実現する。roomも使う。
* HTMLのボタンを押したら、ボタンを無効にしてPythonに処理が行く。
* Pythonでは、5秒ウェイとする。1秒ごとに経過時間をクライアントに伝え、Webページでそれを表示する。
* 5秒経過したら、Pythonから「終了」というメッセージがHTMLに飛ぶ
* 通信を終了する。ボタンも有効に戻す。
* そのメッセージをHTMLに表示する
結果、PythonとHTMLでそれぞれ以下のような返答が来ました。一発でエラーなく回って感動♪
PIP
pip install flask flask-socketio eventlet
app.py
from flask import Flask, render_template
from flask_socketio import SocketIO, emit, join_room
import time
app = Flask(__name__)
socketio = SocketIO(app)
@app.route('/')
def index():
return render_template('index.html')
@socketio.on('start_process', namespace='/test')
def start_process(message):
join_room('room')
for i in range(1, 6):
time.sleep(1)
emit('time_update', {'message': f'{i} seconds passed.'}, room='room')
emit('process_completed', {'message': '終了'}, room='room')
if __name__ == '__main__':
socketio.run(app, debug=True)
templates/index.html
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<title>Streaming Example</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.0/socket.io.min.js"></script>
<script type="text/javascript">
var socket = io.connect('http://' + document.domain + ':' + location.port + '/test');
socket.on('connect', function() {
socket.emit('joined', {});
});
socket.on('time_update', function(msg) {
document.getElementById('status').innerHTML = msg.message;
});
socket.on('process_completed', function(msg) {
document.getElementById('status').innerHTML = msg.message;
document.getElementById('startButton').disabled = false;
});
function startProcess() {
socket.emit('start_process', {});
document.getElementById('startButton').disabled = true;
}
</script>
</head>
<body>
<button id="startButton" onclick="startProcess()">Start</button>
<div id="status"></div>
</body>
</html>
実行結果

Langchainで生成AIを動かす
API_KEYを「.env」というファイル名で保存して、以下のプログラムを実行。

test_chatgpt.py
import os
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
llm = ChatOpenAI(
model_name = "gpt-3.5-turbo",
temperature = 0,
streaming = False,
)
response = llm(messages=[HumanMessage(content="こんにちわ")])
print(response)
これを回してターミナルにメッセージが出たら成功です。次はヌルヌル表示にチャレンジします。
GPTの返答をストリーミング表示
まずはライブラリのインストール。ここで注意点!langchainのバージョンが「0.0.142(以降?)」じゃないと、CallbackManagerなる関数が使えない模様。PIPするときに注意してください。
pip install langchain==0.0.142
まずはコールバック関数を定義します。「https://ict-worker.com/ai/langchain-stream.html」の記事が非常にわかりやすかったので、ほぼその中身を流用させていただいております。
ChatOpenAIでStreamingをTrueにすると、ChatGPTが返答トークンを発行するたびに「on_llm_new_token」が呼ばれ、自作コールバック関数(後述)が動くという仕様です。
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import CallbackManager, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class mycbhandler(BaseCallbackHandler):
streaming_handler = None
def __init__(self, jisaku_callbackfunction):
#自作のコールバック関数を登録
self.streaming_handler = jisaku_callbackfunction
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.streaming_handler(token)
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
class_name = serialized["name"]
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_tool_start(self,serialized: Dict[str, Any], input_str: str, **kwargs: Any, ) -> None:
pass
def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
print(action)
def on_tool_end(self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any) -> None:
print(output)
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
print(text)
def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
print(finish.log)
次に、自作コールバック関数を作り(単にプリントするだけ)、ChatOpenAIにそのコールバック関数を渡してあげます。
import os
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
# 自作コールバック関数(単にプリントするだけ)
def handle_token(token):
print('\033[36m' + token + '\033[0m')
llm = ChatOpenAI(
streaming = True,
callback_manager = CallbackManager([mycbhandler(handle_token)]),
verbose = True,
temperature = 0
)
response = llm(messages=[HumanMessage(content="こんにちわ")])
print(response)
これら2つを足し合わせて実行してみてください。例えば、以下のような感じです。(単に上の2つを足し合わせただけです)
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import CallbackManager, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
#--------------------------------------------------------#
class mycbhandler(BaseCallbackHandler):
streaming_handler = None
def __init__(self, jisaku_callbackfunction):
#自作のコールバック関数を登録
self.streaming_handler = jisaku_callbackfunction
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.streaming_handler(token)
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
class_name = serialized["name"]
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_tool_start(self,serialized: Dict[str, Any], input_str: str, **kwargs: Any, ) -> None:
pass
def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
print(action)
def on_tool_end(self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any) -> None:
print(output)
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
print(text)
def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
print(finish.log)
#--------------------------------------------------------#
import os
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
def handle_token(token):
print('\033[36m' + token + '\033[0m')
if __name__ == '__main__' :
llm = ChatOpenAI(
streaming = True,
callback_manager = CallbackManager([mycbhandler(handle_token)]),
verbose = True,
temperature = 0
)
response = llm(messages=[HumanMessage(content="こんにちわ")])
print(response)
ターミナルにヌルヌルと文字が表示されれば成功です。
Webページでヌルヌル実装
いよいよ実装です。
最初の「Flaskでストリーミング表示」にある`start_process`に、以下のような部分があります。
for i in range(1, 6):
time.sleep(1)
emit('time_update', {'message': f'{i} seconds passed.'}, room='room')
ここのemit部分を、自作コールバック関数に使いまわします。
自作コールバック関数
ai_message = ''
def handle_token(token):
global ai_message
ai_message = ai_message + token
emit('time_update', {'message': ai_message}, room='room')
また、上述のfor文はコメントアウトして、以下の2行を付け足します。(llmは外(global)で定義済みという前提)
global llm
response = llm(messages=[HumanMessage(content="こんにちわ")])
これでストリーミング表示されるはずです!
ソースコード全文は以下の通りです。(index.htmlは全く変えず)
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import CallbackManager, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
#--------------------------------------------------------#
class mycbhandler(BaseCallbackHandler):
streaming_handler = None
def __init__(self, jisaku_callbackfunction):
#自作のコールバック関数を登録
self.streaming_handler = jisaku_callbackfunction
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.streaming_handler(token)
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
class_name = serialized["name"]
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_tool_start(self,serialized: Dict[str, Any], input_str: str, **kwargs: Any, ) -> None:
pass
def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
print(action)
def on_tool_end(self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any) -> None:
print(output)
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
print(text)
def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
print(finish.log)
#--------------------------------------------------------#
import os
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
ai_message = ''
def handle_token(token):
global ai_message
ai_message = ai_message + token
emit('time_update', {'message': ai_message}, room='room')
llm = ChatOpenAI(
streaming=True,
callback_manager=CallbackManager([mycbhandler(handle_token)]),
verbose=True,
temperature=0
)
#--------------------------------------------------------#
from flask import Flask, render_template
from flask_socketio import SocketIO, emit, join_room
import time
app = Flask(__name__)
socketio = SocketIO(app)
@app.route('/')
def index():
return render_template('index.html')
@socketio.on('start_process', namespace='/test')
def start_process(message):
join_room('room')
# for i in range(1, 6):
# time.sleep(1)
# emit('time_update', {'message': f'{i} seconds passed.'}, room='room')
#以下の2行を付け足し
global llm
response = llm(messages=[HumanMessage(content="こんにちわ")])
emit('process_completed', {'message': response.content}, room='room')
if __name__ == '__main__':
socketio.run(app, debug=True)
おまけ:自作openaiクラス
これらのことを毎回書くのが面倒なので、自作クラスを作りました。上にも書きましたが、langchalnは0.0.142(以降?)ですのでご注意を。
pip install langchain==0.0.142
from typing import Any, Dict, List, Optional, Union
import os
from langchain.prompts.chat import (
ChatPromptTemplate ,
SystemMessagePromptTemplate ,
MessagesPlaceholder ,
HumanMessagePromptTemplate ,
)
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationChain
from langchain.callbacks.base import CallbackManager, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage
#--- コールバッククラス -----------------------------------#
class mycbhandler(BaseCallbackHandler):
streaming_handler = None
def __init__(self, jisaku_callbackfunction):
#自作のコールバック関数を登録
self.streaming_handler = jisaku_callbackfunction
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.streaming_handler(token)
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
class_name = serialized["name"]
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_tool_start(self,serialized: Dict[str, Any], input_str: str, **kwargs: Any, ) -> None:
pass
def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
print(action)
def on_tool_end(self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any) -> None:
print(output)
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
pass
def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
print(text)
def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
print(finish.log)
#--------------------------------------------------------#
#--- メインクラス -----------------------------------#
class myopenai :
template = None
mycallbackfunc = None
def __init__(self) :
pass
def set_prompt(self, txt:str) :
self.prompt = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template (txt) ,
MessagesPlaceholder (variable_name="history") ,
HumanMessagePromptTemplate.from_template ("{input}") ,
])
def set_mycallbackfunction(self, mycallbackfunc:Any) :
self.mycallbackfunc = mycallbackfunc
#会話の読み込みを行う関数を定義
def load_conversation(self, model:str, streaming:bool=True):
llm = ChatOpenAI(
model_name = model,
temperature = 0,
streaming = streaming,
callback_manager = CallbackManager([mycbhandler(self.mycallbackfunc)]),
verbose = True,
)
memory = ConversationBufferMemory(return_messages=True)
# print(f'---{self.prompt}---')
conversation = ConversationChain(
memory = memory,
prompt = self.prompt,
llm = llm
)
return conversation
#--------------------------------------------------------#
if __name__ == '__main__' :
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
mo = myopenai()
#まず普通に出す(ストリーミングなし)
mo.set_prompt('あなたは精神科医です。私の悩みを聞いて、適切にアドバイスをしてください。')
conv = mo.load_conversation(model='gpt-3.5-turbo', streaming=False)
ans = conv.predict(input='こんにちわ')
print(ans)
#ストリーミングあり
def handle_token(token):
print('\033[36m' + token + '\033[0m')
mo.set_mycallbackfunction(handle_token)
mo.set_prompt('あなたは精神科医です。私の悩みを聞いて、適切にアドバイスをしてください。')
conv = mo.load_conversation(model='gpt-3.5-turbo', streaming=True)
ans = conv.predict(input='お腹が痛い')
print(ans)
ans = conv.predict(input='別のアドバイスありますか?')
print(ans)