見出し画像

AutoGenをガッツリ試してみる:SQLを利用してみよう

概要

こんなアドベントカレンダーを執筆していたら時間が経ってしまいましたが、引き続きAutoGenを使ってみます。これまでは、テキストで会話しつつ分析をPythonにやらせる流れでしたが、もう少し違うツールも使ってみたくなります。そこで本日はAgentでSQLの仕組みを使ってみようと思います。

原理

AutoGenから直接SQLを叩く仕組みを作るためには、保持しているテーブルをチェックしたり、SQLのクエリを生成して実行し、最後にテキストやグラフ・テーブルなど必要な形式で出力するなど、複数のステップに分かれるため、少し大変です。今回はLangChainの仕組みを使うことで簡易的にAutoGenからSQLを扱えるようにしてみましょう。

より具体的には、Function Callingとして、SQLを利用する仕組みを呼べるようにしておくことで実現します。SQLを利用する仕組みは、LangChainの中にあるagent toolkitを利用できます。ちなみに、DBを用意するのが楽だったのでSQLiteを利用していますが、LangChainで利用できるものであればMySQLでもPostgreSQLでも利用できると思います。

準備

いつものようにライブラリをインストールしていきましょう。今回は以下の他にも使うライブラリがありそうで、必要に応じてインストールしてください。

pip install pyautogen==0.2.2
pip install langchain==0.0.352

はじめに、Agentから利用するデータベースを作成しましょう。とりあえず簡易的なものを作っていきます。生徒テーブル、科目テーブル、テストの点数のテーブルを作っておきます。

import sqlite3

connection = sqlite3.connect('sample.db')

cursor = connection.cursor()

create_student_table = """CREATE TABLE students (
  id INTEGER,
  name TEXT NOT NULL,
  age INTEGER NOT NULL
);"""

cursor.execute(create_student_table)

create_courses_table = """CREATE TABLE courses (
  id INTEGER,
  course TEXT NOT NULL
);"""

cursor.execute(create_courses_table)

create_grades_table = """CREATE TABLE grades (
  name INTEGER NOT NULL,
  course TEXT NOT NULL,
  grade FLOAT NOT NULL
);"""

cursor.execute(create_grades_table)

students = [(1, "Taro", 20), (2, "Hanako", 19), (3, "Jiro", 21)]
courses = [(1, "Math"), (2, "English"), (3, "Physics")]
grades = [
    ("Taro", "Math", 85.0),
    ("Taro", "English", 90.0),
    ("Taro", "Physics", 88.0),
    ("Hanako", "Math", 75.0),
    ("Hanako", "English", 80.0),
    ("Hanako", "Physics", 82.0),
    ("Jiro", "Math", 92.0),
    ("Jiro", "English", 95.0),
    ("Jiro", "Physics", 89.0),
]

insert_students = """INSERT INTO students (id, name, age) VALUES (?, ?, ?);"""
cursor.executemany(insert_students, students)

insert_courses = """INSERT INTO courses (id, course) VALUES (?, ?);"""
cursor.executemany(insert_courses, courses)

insert_grades = """INSERT INTO grades (name, course, grade) VALUES (?, ?, ?);"""
cursor.executemany(insert_grades, grades)

connection.commit()
cursor.close()

Functionsの作成

はじめにLangChainの仕組みを使いデータベースを読み込みます。LangChainの提供するSQLDatabaseToolkitを利用することで、自然言語でのSQLの実行がきます。

from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.utilities.sql_database import SQLDatabase

os.environ["OPENAI_API_KEY"] = "<API Key>"

sql = SQLDatabase.from_uri('sqlite:///sample.db')
llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
toolkit = SQLDatabaseToolkit(db=sql, llm=llm)

以下のようにして、作成したtoolkitを呼び出すためのFunction Callingのスキーマを作ってみましょう。toolkitのパラメータをgenerate_llm_config関数で変換しtool_schemaに保存します。また、LangChainのtoolkitを呼び出すのはtool._runであり、これをfunction_mapに登録します。

def generate_llm_config(tool):
    # LangChainのAgentツールからFunction schemaを作成
    function_schema = {
        "name": tool.name.lower().replace (' ', '_'),
        "description": tool.description,
        "parameters": {
            "type": "object",
            "properties": {},
            "required": [],
        },
    }

    if tool.args is not None:
        function_schema["parameters"]["properties"] = tool.args

    return function_schema

# Now use AutoGen with Langchain Tool Bridgre
tools = []
function_map = {}

for tool in toolkit.get_tools(): #debug_toolkit if you want to use tools directly
    tool_schema = generate_llm_config(tool)
    tools.append(tool_schema)
    function_map[tool.name] = tool._run

Agentの作成

いつも通りAgentを作っていきましょう。複数の会話が連続する中で、タスクを終了した時にOpenAIのAgentがTERMINATEを返すのを読み込んで、UserProxyは動作を停止します。また、先ほど作成したfunction_mapを登録します。

import autogen

config_list = [{
    'model': 'gpt-4-1106-preview', 
    'api_key': '<API Key>'
}]

# Construct the llm_config
llm_config = {
  "functions": tools,
  "config_list": config_list,  # Assuming you have this defined elsewhere
  "timeout": 120,
}

user_proxy = autogen.UserProxyAgent(
    name="user_proxy",
    is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
    human_input_mode="NEVER",
    max_consecutive_auto_reply=10,
    code_execution_config={"work_dir": "tmp", "use_docker": False},
)

# Register the tool and start the conversation
user_proxy.register_function(
    function_map = function_map
)

chatbot = autogen.AssistantAgent(
    name="chatbot",
    system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
    llm_config=llm_config,
)

実行結果(テーブルの概要)

まずはテーブルの内容の説明をしてもらいました。長いので結果のみ出力します。なお、SQLの場合利用できるツールは、sql_db_query、sql_db_schema、sql_db_list_tables、sql_db_query_checkerがありますが、今回はsql_db_list_tablesを利用してテーブルリストを抽出してからsql_db_schemaでテーブルの詳細を見ているようで、多段階での作業ができることを確認できました。

Q: どんなデータを持っていますか
A: 以下のスキーマとサンプルデータを持つ3つのテーブルがデータベースに存在します:

1. `courses` テーブル:
   - `id`: コースID(整数)
   - `course`: コース名(テキスト、null不可)

   サンプルデータ:
   ```
   | id | course  |
   |----|---------|
   | 1  | Math    |
   | 2  | English |
   | 3  | Physics |
   ```

2. `grades` テーブル:
   - `name`: 学生名(整数、null不可)
   - `course`: コース名(テキスト、null不可)
   - `grade`: 成績(浮動小数点数、null不可)

   サンプルデータ:
   ```
   | name  | course  | grade |
   |-------|---------|-------|
   | Taro  | Math    | 85.0  |
   | Taro  | English | 90.0  |
   | Taro  | Physics | 88.0  |
   ```

3. `students` テーブル:
   - `id`: 学生ID(整数)
   - `name`: 学生名(テキスト、null不可)
   - `age`: 年齢(整数、null不可)

   サンプルデータ:
   ```
   | id | name   | age |
   |----|--------|-----|
   | 1  | Taro   | 20  |
   | 2  | Hanako | 19  |
   | 3  | Jiro   | 21  |
   ```

これらの情報を使って、さらにクエリを作成したり、データを分析することができます。何か特定の情報が必要であれば、質問してください。

実行結果(データベースへの数値の質問)

データベースに対して、生徒は何人ですか?と質問を投げてみましょう。

user_proxy.initiate_chat(
    chatbot,
    message="生徒は何人ですか?",
    llm_config=llm_config,
)

結果は以下になりました。

現在、生徒は3人です。

続いて、微妙に複雑なクエリを投げてみましょう。

Q: 英語の点数が一番高い生徒は
A: The student with the highest English score is Taro.

終わりに

このようにFunction Callingを併用することで、LangChainの資産の流用ができます。今回の実験では、その仕組を利用してSQLを扱えるようにしてみました。DBは様々なシーンで使われている割に、特に非エンジニアの場合、クエリを書くのが苦手な人も多いので、こういった機能は有用ではないでしょうか。

いいなと思ったら応援しよう!