AutoGenをガッツリ試してみる:SQLを利用したDB内データの可視化
概要
前回の記事では、LangChainのAgent Toolkitを利用してSQLに問い合わせるサンプルを作ってみました。ただし、LangChainでSQLクエリを実行するため結果がテキストとして出力されます。人数を聞くなどする分には問題ないのですが、出力された結果がテキストとしてプロンプトに含まれてしまうのでSQLデータを利用したデータをプロットしたり加工するのには向きません。本記事ではDBから取得したデータからグラフを作るサンプルを作成してみます。
準備
まずは、ソースとなるデータを準備しましょう。前回のような点数データだとデータ量が少ないので、より大きなデータを利用します。ここでは複数社の株価を利用してみます。ここでは、NVDA、TSLA、AAPL、MSFT、AMZNの5社を利用してみます。結果をdatabase/stock.dbというファイルに保存します。databaseディレクトリは、後のAgentのワークディレクトリとなります。
import os
import sqlite3
import yfinance as yf
import pandas as pd
from matplotlib import pyplot as plt
# Yahoo FinanceからNVIDIAとTeslaの株価データを取得します
nvidia = yf.download("NVDA", start="2021-01-01", end="2021-12-31")
tesla = yf.download("TSLA", start="2021-01-01", end="2021-12-31")
aapl = yf.download("AAPL", start="2021-01-01", end="2021-12-31")
msft = yf.download("MSFT", start="2021-01-01", end="2021-12-31")
amzn = yf.download("AMZN", start="2021-01-01", end="2021-12-31")
nvidia["Ticker"] = "NVDA"
tesla["Ticker"] = "TSLA"
aapl["Ticker"] = "AAPL"
msft["Ticker"] = "MSFT"
amzn["Ticker"] = "AMZN"
df = pd.concat((nvidia, tesla, aapl, msft, amzn)).sort_index()
df = df.reset_index()
try:
os.remove('database/stock.db')
except FileNotFoundError:
pass
connection = sqlite3.connect('database/stock.db')
df.to_sql('stock_price', connection, if_exists='replace')
Function Callingの準備
データベースを利用するにあたりテーブルのスキーマなどを知っておく必要があります。AutoGen側でソースコードを都度作成してスキーマをチェックしても良いですが、ある程度定形のロジックはFunction Callingにしておくと動作がロバストになります。LangChainにあるツールを利用しFunction Callingの形式にしてみましょう。
LangChainのSQLのツールには、テーブルリストの取得、テーブル情報の取得、クエリの作成、実行などがあります。クエリを実行する部分はAutoGenのコードに任せることで出力を直接分析ロジックに投入する事にしてここでは利用しません。テーブルリスト・テーブル情報の取得のツールとして、InfoSQLDatabaseTool、ListSQLDatabaseToolの2つを利用しましょう。下記ソースコードはLangChainのagent_toolkits/sql/toolkit.pyから持ってきました。
from langchain.utilities.sql_database import SQLDatabase
from langchain_community.tools import InfoSQLDatabaseTool
from langchain_community.tools import ListSQLDatabaseTool
db = SQLDatabase.from_uri('sqlite:///database/stock.db')
list_sql_database_tool = ListSQLDatabaseTool(db=db)
info_sql_database_tool_description = (
"Input to this tool is a comma-separated list of tables, output is the "
"schema and sample rows for those tables. "
"Be sure that the tables actually exist by calling "
f"{list_sql_database_tool.name} first! "
"Example Input: table1, table2, table3"
)
info_sql_database_tool = InfoSQLDatabaseTool(
db=db,
description=info_sql_database_tool_description
)
LangChainのツールの準備はできたので、これをAutoGenで利用できるようにしましょう。ここのプロセスは前回と同じです。
def generate_llm_config(tool):
# LangChainのAgentツールからFunction schemaを作成
function_schema = {
"name": tool.name.lower().replace (' ', '_'),
"description": tool.description,
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
}
if tool.args is not None:
function_schema["parameters"]["properties"] = tool.args
return function_schema
tools = []
function_map = {}
for tool in [list_sql_database_tool, info_sql_database_tool]:
tool_schema = generate_llm_config(tool)
tools.append(tool_schema)
function_map[tool.name] = tool._run
Agentの作成
前回はDBの扱いはすべてLangChainが担っていましたが、今回はクエリの実行はAutoGen側で行います。work_dir内のファイルはAutoGenで利用できるため、Agentのプロンプトでファイル名を教えてあげましょう。AutoGenのデフォルトのプロンプトを少しだけ改変しています。"Additional Information: You have a database named "stock.db" that contains information about stock prices."という文言を最後に追加しました。
また、おまけ程度ですが、出力した指示にSQLのコードブロックが作られてしまうこともあったので、"Include no more than one code block in a response."という指示を追加しています。あと、終了判定の文言を少し変えて"Please conclude with "TERMINATE" once you have successfully answered the user's instruction or question."としました。
DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant.
Solve tasks using your coding and language skills.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user. Include no more than one code block in a response.
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
Please conclude with "TERMINATE" once you have successfully answered the user's instruction or question.
Additional Information: You have a database named "stock.db" that contains information about stock prices.
"""
続いてAgentを作っていきます。Function Callingの設定は前回と同じ。AssistantAgentのシステムプロンプトを上記で上書きしています。
from autogen import AssistantAgent, UserProxyAgent, config_list_from_json
code_execution_config = {
"work_dir": "database",
"use_docker": False
}
config_list = [{
'model': 'gpt-4-1106-preview', 'api_key': '<API-Key>'
}]
assistant = AssistantAgent(
"assistant",
system_message = DEFAULT_SYSTEM_MESSAGE,
llm_config={"config_list": config_list, "functions": tools}
)
user_proxy = UserProxyAgent(
"user_proxy",
human_input_mode="NEVER",
max_consecutive_auto_reply=10,
is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
code_execution_config=code_execution_config
)
# Register the tool and start the conversation
user_proxy.register_function(
function_map = function_map
)
実験
まずは、データベース内にあるティッカーのリストを出力してもらいます。
message = "ティッカーのリストを出力してください"
user_proxy.initiate_chat(assistant, message=message)
結果としては以下のようになりました。少し長いですが全文を載せておきます。きちんとFunction Calling経由でデータベースの中身をチェックしていることが確認できます。チェックしたフォーマットに基づいて、プログラムでティッカーのリストを取得できていますね。
user_proxy (to assistant):
ティッカーのリストを出力してください
--------------------------------------------------------------------------------
assistant (to user_proxy):
***** Suggested function Call: sql_db_list_tables *****
Arguments:
{}
*******************************************************
--------------------------------------------------------------------------------
>>>>>>>> EXECUTING FUNCTION sql_db_list_tables...
user_proxy (to assistant):
***** Response from calling function "sql_db_list_tables" *****
stock_price
***************************************************************
--------------------------------------------------------------------------------
assistant (to user_proxy):
***** Suggested function Call: sql_db_schema *****
Arguments:
{"table_names":"stock_price"}
**************************************************
--------------------------------------------------------------------------------
>>>>>>>> EXECUTING FUNCTION sql_db_schema...
user_proxy (to assistant):
***** Response from calling function "sql_db_schema" *****
CREATE TABLE stock_price (
"index" INTEGER,
"Date" TIMESTAMP,
"Open" REAL,
"High" REAL,
"Low" REAL,
"Close" REAL,
"Adj Close" REAL,
"Volume" INTEGER,
"Ticker" TEXT
)
/*
3 rows from stock_price table:
index Date Open High Low Close Adj Close Volume Ticker
0 2021-01-04 00:00:00 131.04249572753906 136.52499389648438 129.625 131.13499450683594 130.84011840820312 56064000 NVDA
1 2021-01-04 00:00:00 133.52000427246094 133.61000061035156 126.76000213623047 129.41000366210938 127.1641845703125 143301900 AAPL
2 2021-01-04 00:00:00 222.52999877929688 223.0 214.80999755859375 217.69000244140625 211.99661254882812 37130100 MSFT
*/
**********************************************************
--------------------------------------------------------------------------------
assistant (to user_proxy):
The `stock_price` table contains a "Ticker" column that represents the ticker symbols of stock prices. To output a list of distinct tickers, we can run a SQL query to select the unique ticker values from this `stock_price` table.
I will provide you with the Python code to run the SQL query and output the list of tickers. Please ensure you have a database connection library like `sqlite3` installed in your Python environment.
```python
# filename: list_tickers.py
import sqlite3
# Connect to the SQLite database
conn = sqlite3.connect('stock.db')
# Create a cursor object using the cursor() method
cursor = conn.cursor()
# Query to find all unique tickers
query = 'SELECT DISTINCT Ticker FROM stock_price ORDER BY Ticker'
# Execute the SQL query
cursor.execute(query)
# Fetch all the results
tickers = cursor.fetchall()
# Close the database connection
conn.close()
# Print the list of tickers
for ticker in tickers:
print(ticker[0])
```
Execute this script to get the list of distinct stock tickers from your "stock.db" database.
--------------------------------------------------------------------------------
>>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...
user_proxy (to assistant):
exitcode: 0 (execution succeeded)
Code output:
AAPL
AMZN
MSFT
NVDA
TSLA
--------------------------------------------------------------------------------
assistant (to user_proxy):
The list of distinct stock tickers from the "stock.db" database is:
- AAPL
- AMZN
- MSFT
- NVDA
- TSLA
TERMINATE
--------------------------------------------------------------------------------
上記までであれば前回の機能で良いですが、取得したデータをプログラムで利用することで今回はグラフにしていきましょう。
message = "NVDAとTSLAの株価の変化をプロットしstock.pngという名前で保存してください"
user_proxy.initiate_chat(assistant, message=message)
上記を実行したら、先程のようにテーブルの中身を確認してからプログラムを作成・実行して以下のようなプロット結果を得られました。
おわりに
データベースから得られる情報を分析に利用するために、テーブル情報の取得のみはFunction Calling、クエリの実行と結果の利用をAutoGenで生成したプログラムから行う例を作りました。単にDBをプログラムから読み込むだけではなく、定形処理は一部Function Callingを利用するなどの工夫などもしてみました。次回はAutoGenのソースコードでも読んでみようかな。