Pydanticを用いたOpenAI Assistant API内における Function Callingの型安全な利用
初投稿です.
この記事では,先週のOpenAI devdayで発表されたAssistant内の機能である「Function calling」を,型安全な関数呼び出し方法を提供するPydanticを使用して実装したコードを共有します.
Pydanticを用いることでコードの可読性,保守性が向上し,特にFunctionの管理が容易となります.
本記事の作成にあたっては以下の記事を参考にさせていただきました.Pydanticに関する詳しい内容はそちらを参照してください.
実装
必要なパッケージのインストール
pip install openai pydantic
Function(クラス)の用意
今回は公式のAPI reference にある天気APIの例を使用します.
Functionを複数持ちたい場合はFunctionごとにクラスを作成し,CLASS_MAPに加えることで実装できます.
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
"現在の天気を取得する関数"
location: str = Field(..., description="場所 (例:東京)")
unit: str = Field(..., description="温度の単位 (c:摂氏, f:華氏)")
# ダミーの関数を用意
def execute(self):
return f"{self.location}の天気は22{self.unit}です。"
# マッピングをする
CLASS_MAP = {
"GetWeather": GetWeather,
}
toolsとして用意する
FunctionをAssistantに渡すために整形します.
schemas = [v.model_json_schema() for k, v in CLASS_MAP.items()]
function_schemas = []
# スキーマをAPIの仕様に合わせて整える
for schema in schemas:
function_schema = {
"name": schema["title"],
"description": schema["description"],
"parameters": {
"type": "object",
"properties": schema.get("properties", {}),
"required": list(schema.get("properties", {}).keys()),
},
}
function_schemas.append(function_schema)
tools = [{"type": "function", "function": f} for f in function_schemas]
前準備
アシスタントを作成しスレッドを用意します.
from openai import OpenAI
# クライアントの準備
client = OpenAI()
# アシスタントの作成
assistant = client.beta.assistants.create(
instructions="あなたは優秀なお天気ボットです。 提供されている関数を使用して質問に答えます。",
model="gpt-4-1106-preview",
tools=tools,
)
# スレッドの準備
thread = client.beta.threads.create()
会話の実行
同一スレッド内で会話を行う関数を定義しています.
functionを呼び出す必要があればfunctionを実行し,そうでない場合は通常の会話が行われます.
返り値は最新の返答となります.
def run_conversation(input: str):
# ユーザーメッセージの追加
message = client.beta.threads.messages.create(
thread_id=thread.id, role="user", content=input
)
# アシスタントにリクエスト
run = client.beta.threads.runs.create(
thread_id=thread.id,
assistant_id=assistant.id,
)
# 実行状況の確認
run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
while run.status == "in_progress":
print("waiting...")
time.sleep(1)
run = client.beta.threads.runs.retrieve(
thread_id=thread.id, run_id=run.id
)
if run.status == "requires_action":
# requires_actionのパラメータの取得
tool_id = run.required_action.submit_tool_outputs.tool_calls[0].id
tool_function_name = (
run.required_action.submit_tool_outputs.tool_calls[0].function.name
)
tool_function_arguments = json.loads(
run.required_action.submit_tool_outputs.tool_calls[
0
].function.arguments
)
print("id:", tool_id)
print("name:", tool_function_name)
print("arguments:", tool_function_arguments)
# リクエストされた関数の実行
tool_function_output = CLASS_MAP[tool_function_name](
**tool_function_arguments
).execute()
# 関数の出力を提出
run = client.beta.threads.runs.submit_tool_outputs(
thread_id=thread.id,
run_id=run.id,
tool_outputs=[
{
"tool_call_id": tool_id,
"output": tool_function_output,
}
],
)
run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
while run.status != "completed":
time.sleep(1) # サーバーへの問い合わせ間隔を1秒に設定
run = client.beta.threads.runs.retrieve(
thread_id=thread.id, run_id=run.id
) # ステータスを更新
messages = client.beta.threads.messages.list(
thread_id=thread.id, order="asc"
)
# 最新のメッセージを返す
return messages.data[-1].content[0].text.value
実行結果
Functionが呼べるか確認します.
run_conversation("東京の天気は何度?")
記憶の保持も試しました.
run_conversation("ここまでの会話をまとめてください。")
終わりに
Functionをいじったり追加する際にクラスの部分を書き換えるだけで済むのが楽ですね.
問題点,改善点やコメント等あればご気軽によろしくお願いします!