AutoGenをガッツリ試してみる:SQLを利用してみよう
概要
こんなアドベントカレンダーを執筆していたら時間が経ってしまいましたが、引き続きAutoGenを使ってみます。これまでは、テキストで会話しつつ分析をPythonにやらせる流れでしたが、もう少し違うツールも使ってみたくなります。そこで本日はAgentでSQLの仕組みを使ってみようと思います。
原理
AutoGenから直接SQLを叩く仕組みを作るためには、保持しているテーブルをチェックしたり、SQLのクエリを生成して実行し、最後にテキストやグラフ・テーブルなど必要な形式で出力するなど、複数のステップに分かれるため、少し大変です。今回はLangChainの仕組みを使うことで簡易的にAutoGenからSQLを扱えるようにしてみましょう。
より具体的には、Function Callingとして、SQLを利用する仕組みを呼べるようにしておくことで実現します。SQLを利用する仕組みは、LangChainの中にあるagent toolkitを利用できます。ちなみに、DBを用意するのが楽だったのでSQLiteを利用していますが、LangChainで利用できるものであればMySQLでもPostgreSQLでも利用できると思います。
準備
いつものようにライブラリをインストールしていきましょう。今回は以下の他にも使うライブラリがありそうで、必要に応じてインストールしてください。
pip install pyautogen==0.2.2
pip install langchain==0.0.352
はじめに、Agentから利用するデータベースを作成しましょう。とりあえず簡易的なものを作っていきます。生徒テーブル、科目テーブル、テストの点数のテーブルを作っておきます。
import sqlite3
connection = sqlite3.connect('sample.db')
cursor = connection.cursor()
create_student_table = """CREATE TABLE students (
id INTEGER,
name TEXT NOT NULL,
age INTEGER NOT NULL
);"""
cursor.execute(create_student_table)
create_courses_table = """CREATE TABLE courses (
id INTEGER,
course TEXT NOT NULL
);"""
cursor.execute(create_courses_table)
create_grades_table = """CREATE TABLE grades (
name INTEGER NOT NULL,
course TEXT NOT NULL,
grade FLOAT NOT NULL
);"""
cursor.execute(create_grades_table)
students = [(1, "Taro", 20), (2, "Hanako", 19), (3, "Jiro", 21)]
courses = [(1, "Math"), (2, "English"), (3, "Physics")]
grades = [
("Taro", "Math", 85.0),
("Taro", "English", 90.0),
("Taro", "Physics", 88.0),
("Hanako", "Math", 75.0),
("Hanako", "English", 80.0),
("Hanako", "Physics", 82.0),
("Jiro", "Math", 92.0),
("Jiro", "English", 95.0),
("Jiro", "Physics", 89.0),
]
insert_students = """INSERT INTO students (id, name, age) VALUES (?, ?, ?);"""
cursor.executemany(insert_students, students)
insert_courses = """INSERT INTO courses (id, course) VALUES (?, ?);"""
cursor.executemany(insert_courses, courses)
insert_grades = """INSERT INTO grades (name, course, grade) VALUES (?, ?, ?);"""
cursor.executemany(insert_grades, grades)
connection.commit()
cursor.close()
Functionsの作成
はじめにLangChainの仕組みを使いデータベースを読み込みます。LangChainの提供するSQLDatabaseToolkitを利用することで、自然言語でのSQLの実行がきます。
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.utilities.sql_database import SQLDatabase
os.environ["OPENAI_API_KEY"] = "<API Key>"
sql = SQLDatabase.from_uri('sqlite:///sample.db')
llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
toolkit = SQLDatabaseToolkit(db=sql, llm=llm)
以下のようにして、作成したtoolkitを呼び出すためのFunction Callingのスキーマを作ってみましょう。toolkitのパラメータをgenerate_llm_config関数で変換しtool_schemaに保存します。また、LangChainのtoolkitを呼び出すのはtool._runであり、これをfunction_mapに登録します。
def generate_llm_config(tool):
# LangChainのAgentツールからFunction schemaを作成
function_schema = {
"name": tool.name.lower().replace (' ', '_'),
"description": tool.description,
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
}
if tool.args is not None:
function_schema["parameters"]["properties"] = tool.args
return function_schema
# Now use AutoGen with Langchain Tool Bridgre
tools = []
function_map = {}
for tool in toolkit.get_tools(): #debug_toolkit if you want to use tools directly
tool_schema = generate_llm_config(tool)
tools.append(tool_schema)
function_map[tool.name] = tool._run
Agentの作成
いつも通りAgentを作っていきましょう。複数の会話が連続する中で、タスクを終了した時にOpenAIのAgentがTERMINATEを返すのを読み込んで、UserProxyは動作を停止します。また、先ほど作成したfunction_mapを登録します。
import autogen
config_list = [{
'model': 'gpt-4-1106-preview',
'api_key': '<API Key>'
}]
# Construct the llm_config
llm_config = {
"functions": tools,
"config_list": config_list, # Assuming you have this defined elsewhere
"timeout": 120,
}
user_proxy = autogen.UserProxyAgent(
name="user_proxy",
is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
human_input_mode="NEVER",
max_consecutive_auto_reply=10,
code_execution_config={"work_dir": "tmp", "use_docker": False},
)
# Register the tool and start the conversation
user_proxy.register_function(
function_map = function_map
)
chatbot = autogen.AssistantAgent(
name="chatbot",
system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
llm_config=llm_config,
)
実行結果(テーブルの概要)
まずはテーブルの内容の説明をしてもらいました。長いので結果のみ出力します。なお、SQLの場合利用できるツールは、sql_db_query、sql_db_schema、sql_db_list_tables、sql_db_query_checkerがありますが、今回はsql_db_list_tablesを利用してテーブルリストを抽出してからsql_db_schemaでテーブルの詳細を見ているようで、多段階での作業ができることを確認できました。
Q: どんなデータを持っていますか
A: 以下のスキーマとサンプルデータを持つ3つのテーブルがデータベースに存在します:
1. `courses` テーブル:
- `id`: コースID(整数)
- `course`: コース名(テキスト、null不可)
サンプルデータ:
```
| id | course |
|----|---------|
| 1 | Math |
| 2 | English |
| 3 | Physics |
```
2. `grades` テーブル:
- `name`: 学生名(整数、null不可)
- `course`: コース名(テキスト、null不可)
- `grade`: 成績(浮動小数点数、null不可)
サンプルデータ:
```
| name | course | grade |
|-------|---------|-------|
| Taro | Math | 85.0 |
| Taro | English | 90.0 |
| Taro | Physics | 88.0 |
```
3. `students` テーブル:
- `id`: 学生ID(整数)
- `name`: 学生名(テキスト、null不可)
- `age`: 年齢(整数、null不可)
サンプルデータ:
```
| id | name | age |
|----|--------|-----|
| 1 | Taro | 20 |
| 2 | Hanako | 19 |
| 3 | Jiro | 21 |
```
これらの情報を使って、さらにクエリを作成したり、データを分析することができます。何か特定の情報が必要であれば、質問してください。
実行結果(データベースへの数値の質問)
データベースに対して、生徒は何人ですか?と質問を投げてみましょう。
user_proxy.initiate_chat(
chatbot,
message="生徒は何人ですか?",
llm_config=llm_config,
)
結果は以下になりました。
現在、生徒は3人です。
続いて、微妙に複雑なクエリを投げてみましょう。
Q: 英語の点数が一番高い生徒は
A: The student with the highest English score is Taro.
終わりに
このようにFunction Callingを併用することで、LangChainの資産の流用ができます。今回の実験では、その仕組を利用してSQLを扱えるようにしてみました。DBは様々なシーンで使われている割に、特に非エンジニアの場合、クエリを書くのが苦手な人も多いので、こういった機能は有用ではないでしょうか。