Pydanticを利用した型安全なFunction calling

Pydanticを活用すれば、型安全かつスキーマ定義も容易に実現ができそうということが分かりましたので共有です。この記事の内容は、以下のツイートに着想を受けました。色々な手法を模索されていて必見です。


本記事でのプログラムは、自分なりに少しアレンジをして実装しています。コードは以下で確認ができます。


実装例

百聞は一見にしかずなので、実際にコードを見ていきます。

from pydantic import BaseModel, Field
import openai
from enum import Enum
import json

# 単位フィールド用のEnumを定義
class TemperatureUnit(str, Enum):
    celsius = "celsius"
    fahrenheit = "fahrenheit"

# 関数引数のためのPydanticモデルを定義
class WeatherSearch(BaseModel):
    "指定された場所の現在の天気を取得"
    location: str = Field(..., description="都市の名称、例:東京")
    unit: TemperatureUnit = Field(..., description="温度の単位")

    def execute(self):
        # これはダミーです。実際のAPIを呼び出し、適切にエラーを処理する必要があります
        weather_info = {
            "location": self.location,
            "temperature": "22",
            "unit": self.unit.value,
            "forecast": ["晴れ", "風が強い"],
        }
        return json.dumps(weather_info)

# 会話を実行
def run_conversation(message: str):
    # 関数のスキーマを定義
    schema = WeatherSearch.schema()
    print(f"[スキーマ]\n{schema}\n")
    function_schema = {
        "name": schema["title"],
        "description": schema["description"],
        "parameters": schema
    }

    # APIへの最初の呼び出し
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-0613",
        messages=[{"role": "user", "content": message}],
        functions=[function_schema],
        function_call="auto",
    )
    print(f"[最初のレスポンス]\n{response}\n")

    message = response["choices"][0]["message"]

    # モデルが関数を呼び出したいかどうかを確認
    if message.get("function_call"):
        function_name = message["function_call"]["name"]
        function = WeatherSearch(**json.loads(message["function_call"]["arguments"]))

        # 関数を呼び出す
        function_response = function.execute()

        # APIへの二回目の呼び出し
        second_response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo-0613",
            messages=[
                {"role": "user", "content": "東京の天気はどうですか?"},
                message,
                {
                    "role": "function",
                    "name": function_name,
                    "content": function_response,
                },
            ],
        )
        print(f"[二回目のレスポンス]\n{second_response}\n")
        return second_response["choices"][0]["message"]["content"]
    return message["content"]
run_conversation("東京の天気はどうですか?")

[スキーマ]
{'title': 'WeatherSearch', 'description': '指定された場所の現在の天気を取得', 'type': 'object', 'properties': {'location': {'title': 'Location', 'description': '都市の名称、例:東京', 'type': 'string'}, 'unit': {'description': '温度の単位', 'allOf': [{'$ref': '#/definitions/TemperatureUnit'}]}}, 'required': ['location', 'unit'], 'definitions': {'TemperatureUnit': {'title': 'TemperatureUnit', 'description': 'An enumeration.', 'enum': ['celsius', 'fahrenheit'], 'type': 'string'}}}

[最初のレスポンス]
{ "id": "chatcmpl-7RsC1rhFNGSW4pdISarnBiHRuJw0D", "object": "chat.completion", "created": 1686877529, "model": "gpt-3.5-turbo-0613", "choices": [ { "index": 0, "message": { "role": "assistant", "content": null, "function_call": { "name": "WeatherSearch", "arguments": "{\n \"location\": \"\u6771\u4eac\",\n \"unit\": \"celsius\"\n}" } }, "finish_reason": "function_call" } ], "usage": { "prompt_tokens": 103, "completion_tokens": 26, "total_tokens": 129 } }

[二回目のレスポンス]
{ "id": "chatcmpl-7RsC2JBC7hX6FAYokjpQRikiIPWAr", "object": "chat.completion", "created": 1686877530, "model": "gpt-3.5-turbo-0613", "choices": [ { "index": 0, "message": { "role": "assistant", "content": "\u6771\u4eac\u306e\u5929\u6c17\u306f22\u2103\u3067\u3001\u6674\u308c\u3067\u98a8\u304c\u5f37\u3044\u3067\u3059\u3002" }, "finish_reason": "stop" } ], "usage": { "prompt_tokens": 106, "completion_tokens": 24, "total_tokens": 130 } }

'東京の天気は22℃で、晴れで風が強いです。'

出力結果
run_conversation("面白い回文を教えてください")

[スキーマ]
{'title': 'WeatherSearch', 'description': '指定された場所の現在の天気を取得', 'type': 'object', 'properties': {'location': {'title': 'Location', 'description': '都市の名称、例:東京', 'type': 'string'}, 'unit': {'description': '温度の単位', 'allOf': [{'$ref': '#/definitions/TemperatureUnit'}]}}, 'required': ['location', 'unit'], 'definitions': {'TemperatureUnit': {'title': 'TemperatureUnit', 'description': 'An enumeration.', 'enum': ['celsius', 'fahrenheit'], 'type': 'string'}}}

[最初のレスポンス]
{ "id": "chatcmpl-7RsC327nJwwTKbAAuRkGwoYViX73V", "object": "chat.completion", "created": 1686877531, "model": "gpt-3.5-turbo-0613", "choices": [ { "index": 0, "message": { "role": "assistant", "content": "\u300c\u305f\u3051\u3084\u3076\u3084\u3051\u305f\u300d" }, "finish_reason": "stop" } ], "usage": { "prompt_tokens": 103, "completion_tokens": 11, "total_tokens": 114 } }

'「たけやぶやけた」'

出力結果


解説

Pydanticで関数のパラメータを定義

以下のように定義をすることで、型安全に関数を実行できます。またWeatherArgsのメソッドとして定義することで、パラメータを受け取りながらインスタンスを立てて、そのまま実行ができる形にしています。

# 単位フィールド用のEnumを定義
class TemperatureUnit(str, Enum):
    celsius = "celsius"
    fahrenheit = "fahrenheit"

# 関数引数のためのPydanticモデルを定義
class WeatherSearch(BaseModel):
    "指定された場所の現在の天気を取得"
    location: str = Field(..., description="都市の名称、例:東京")
    unit: TemperatureUnit = Field(..., description="温度の単位")

    def execute(self):
        # これはダミーです。実際のAPIを呼び出し、適切にエラーを処理する必要があります
        weather_info = {
            "location": self.location,
            "temperature": "22",
            "unit": self.unit.value,
            "forecast": ["晴れ", "風が強い"],
        }
        return json.dumps(weather_info)


Pydanticでスキーマ生成

BaseModel.schema()で、function callingに必要な情報を生成することができます。

  • name: schema["title"]

    • クラス名を生成

  • description: schema["description"]

    • クラス名の直下に書かれたコメントがdescriptionとなる

  • parameters: schema

    • OpenAI指定の形式ではないものの各パラメータ情報は盛り込まれているため正常に挙動する

def run_conversation(message: str):
    # 関数のスキーマを定義
    schema = WeatherArgs.schema()
    print(f"[スキーマ]\n{schema}\n")
    function_schema = {
        "name": schema["title"],
        "description": schema["description"],
        "parameters": schema
    }

shemaは以下のようになっています。

{
	"title": "WeatherSearch",
	"description": "指定された場所の現在の天気を取得",
	"type": "object",
	"properties": {
		"location": {
			"title": "Location",
			"description": "都市の名称、例:東京",
			"type": "string"
		},
		"unit": {
			"description": "温度の単位",
			"allOf": [
				{
					"$ref": "#/definitions/TemperatureUnit"
				}
			]
		}
	},
	"required": [
		"location",
		"unit"
	],
	"definitions": {
		"TemperatureUnit": {
			"title": "TemperatureUnit",
			"description": "An enumeration.",
			"enum": [
				"celsius",
				"fahrenheit"
			],
			"type": "string"
		}
	}
}


関数の実行

まずPydanticで定義したクラスのインスタンスを立てて、そこからexecuteメソッドを実行して関数の実行結果を取得します。

    message = response["choices"][0]["message"]

    # モデルが関数を呼び出したいかどうかを確認
    if message.get("function_call"):
        function_name = message["function_call"]["name"]
        function = WeatherSearch(**json.loads(message["function_call"]["arguments"]))

        # 関数を呼び出す
        function_response = function.execute()


おわりに

FastAPIと組み合わせるとやり方次第では、SwaggerUIにもうまく反映ができるので良さそうだなと思いました。簡易的に実装ができるのでお気に入りです。ぜひ試してみてください!


いいなと思ったら応援しよう!