見出し画像

LangChainを使って自然言語でRDBからデータを取得する

はじめに

こんにちは。CRSチームの小澤です。

先日LangChainのブログで、LangChainを使って自然言語でRDBに問い合わせるチュートリアルが紹介されていました。

自然言語を使ってRDBからデータを取得できるようになると、SQLに精通していない方でもデータを手軽に扱えるようになって、より多くの人がデータを活用できるようになるかもしれません。

今回はこちらを検証していきます。

元記事ではMySQLを使って、音楽関連のデータを使っていますが、今回はSQLiteを使って化学系のデータを使用して検証してみます。
また、元記事ではgpt-3.5-turboを使っているようですが、本記事ではgpt-4を使用します。
またコードの実行はGoogle Colabで行います。

データの取得

まずはデータを取得しましょう。

PubChemから化合物を20件取得して、名前、分子量、SMILESを2つのテーブルに格納します。
2つのテーブルに分けて入れるのは後ほどテーブル結合を行えるか確認するためです。

テーブル名はcompoundsとpropertiesで、下記のような構造になります。
テスト用途のため簡易な構造にしています。

$$
\text{compoundsテーブル} \\
\begin{array}{|l|l|l|} \hline
\text{カラム名} & \text{説明} \\ \hline
\text{compound\_id} & \text{化合物のユニークなID (主キー)} \\ \hline
\text{name} & \text{化合物の名前} \\ \hline
\end{array}
$$

$$
\text{propertiesテーブル} \\
\begin{array}{|l|l|l|} \hline
\text{カラム名} & \text{説明} \\ \hline
\text{propert\_id} & \text{物性情報のユニークなID (主キー)} \\ \hline
\text{compound\_id} & \text{化合物のID (外部キー)} \\ \hline
\text{molecular\_weight} & \text{分子量} \\ \hline
\text{canonical\_smiles} & \text{SMILES表記} \\ \hline
\end{array}
$$

下記コードでPubChemから化合物のデータを取得して、テーブルに格納します。

!pip install pubchempy

import pubchempy as pcp
import csv
import sqlite3
import time

# 化合物名のリスト
compounds = ['Aspirin', 'Glucose', 'Caffeine', 'Ethanol', 'Acetaminophen', 'Ibuprofen', 'Sucrose', 'Glycerol', 'Acetic acid', 'Sodium chloride', 'Benzoic acid', 'Ascorbic acid', 'Citric acid', 'Methanol', 'Isopropyl alcohol', 'Ammonia', 'Urea', 'Phenol', 'Formaldehyde', 'Acetone']

# CSVファイルを開く
with open('compounds.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['compound_id', 'name'])
    
    # 各化合物の情報を取得し、CSVファイルに書き込む
    for compound_name in compounds:
        compound = pcp.get_compounds(compound_name, 'name')[0]
        writer.writerow([compound.cid, compound.iupac_name])
        time.sleep(0.2) # リクエストの間に0.2秒の遅延を追加

# 物性情報のCSVファイルを作成
with open('properties.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['property_id', 'compound_id', 'molecular_weight', 'canonical_smiles'])
    
    # 各化合物の物性情報を取得し、CSVファイルに書き込む
    property_id = 1
    for compound_name in compounds:
        compound = pcp.get_compounds(compound_name, 'name')[0]
        molecular_weight = compound.molecular_weight
        canonical_smiles = compound.canonical_smiles
        writer.writerow([property_id, compound.cid, molecular_weight, canonical_smiles])
        property_id += 1
        time.sleep(0.2) # リクエストの間に0.2秒の遅延を追加

# SQLiteデータベースに接続
conn = sqlite3.connect('compounds.db')
c = conn.cursor()

# テーブルを作成
c.execute('''CREATE TABLE IF NOT EXISTS compounds
             (compound_id INTEGER PRIMARY KEY,
             name TEXT)''')

c.execute('''CREATE TABLE IF NOT EXISTS properties
             (property_id INTEGER PRIMARY KEY,
             compound_id INTEGER,
             molecular_weight REAL,
             canonical_smiles TEXT,
             FOREIGN KEY (compound_id) REFERENCES compounds (compound_id))''')

# CSVファイルからデータをインポート
with open('compounds.csv', 'r') as file:
    reader = csv.reader(file)
    next(reader) # ヘッダー行をスキップ
    for row in reader:
        c.execute("INSERT INTO compounds VALUES (?, ?)", row)

with open('properties.csv', 'r') as file:
    reader = csv.reader(file)
    next(reader) # ヘッダー行をスキップ
    for row in reader:
        c.execute("INSERT INTO properties VALUES (?, ?, ?, ?)", row)

conn.commit()

データが正しく入ったか確認してみましょう

%load_ext sql

# DBに接続
%sql sqlite:///compounds.db
%sql select * from compounds;

sqlite:///compound.db
 * sqlite:///compounds.db
Done.

compound_id	name
176	acetic acid
180	propan-2-one
222	azane
243	benzoic acid
311	2-hydroxypropane-1,2,3-tricarboxylic acid
702	ethanol
712	formaldehyde
753	propane-1,2,3-triol
887	methanol
996	phenol
1176	urea
1983	N-(4-hydroxyphenyl)acetamide
2244	2-acetyloxybenzoic acid
2519	1,3,7-trimethylpurine-2,6-dione
3672	2-[4-(2-methylpropyl)phenyl]propanoic acid
3776	propan-2-ol
5234	sodium;chloride
5793	(3R,4S,5S,6R)-6-(hydroxymethyl)oxane-2,3,4,5-tetrol
5988	(2R,3R,4S,5S,6R)-2-[(2S,3S,4S,5R)-3,4-dihydroxy-2,5-bis(hydroxymethyl)oxolan-2-yl]oxy-6-(hydroxymethyl)oxane-3,4,5-triol
54670067	(2R)-2-[(1S)-1,2-dihydroxyethyl]-3,4-dihydroxy-2H-furan-5-one
%sql select * from properties

sqlite:///compound.db
 * sqlite:///compounds.db
Done.

property_id	compound_id	molecular_weight	canonical_smiles
1	2244	180.16	CC(=O)OC1=CC=CC=C1C(=O)O
2	5793	180.16	C(C1C(C(C(C(O1)O)O)O)O)O
3	2519	194.19	CN1C=NC2=C1C(=O)N(C(=O)N2C)C
4	702	46.07	CCO
5	1983	151.16	CC(=O)NC1=CC=C(C=C1)O
6	3672	206.28	CC(C)CC1=CC=C(C=C1)C(C)C(=O)O
7	5988	342.3	C(C1C(C(C(C(O1)OC2(C(C(C(O2)CO)O)O)CO)O)O)O)O
8	753	92.09	C(C(CO)O)O
9	176	60.05	CC(=O)O
10	5234	58.44	[Na+].[Cl-]
11	243	122.12	C1=CC=C(C=C1)C(=O)O
12	54670067	176.12	C(C(C1C(=C(C(=O)O1)O)O)O)O
13	311	192.12	C(C(=O)O)C(CC(=O)O)(C(=O)O)O
14	887	32.042	CO
15	3776	60.1	CC(C)O
16	222	17.031	N
17	1176	60.056	C(=O)(N)N
18	996	94.11	C1=CC=C(C=C1)O
19	712	30.026	C=O
20	180	58.08	CC(=O)C

正しく格納されているようです

LangChainを活用した自然言語問い合わせパイプラインの構築

続いてLangChainをインストールして、自然言語で問い合わせる準備をします。

from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase

template = """
Based on the table schema below, write a SQL query that would answer the user's question.
{schema}

Question: {question}
SQL Query:
"""

prompt = ChatPromptTemplate.from_template(template)

db_uri = "sqlite:///compounds.db"
db = SQLDatabase.from_uri(db_uri)

def get_schema(_):
    return db.get_table_info()

from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(
    model="gpt-4",
    deployment_name="chat-gpt-4",
    api_key=<API-KEY>,
    azure_endpoint=<AZURE-END-POINT>,
    api_version="2024-02-15-preview",
)

sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop="\nSQL Result:")
    | StrOutputParser()
)

sql_chainは自然言語の質問からSQLクエリを生成します。
get_schema関数からデータベースのスキーマ情報を取得して、質問文と合わせてpromptに埋め込んで、それをllmに渡しています。

また、llm.bind(stop="\nSQL Result:")の部分は"\nSQL Result:"をストップシーケンスとして指定することで、SQLクエリ以外の不要な文章が生成されるのを防いでいるようです。
ただ、今回はgpt-4を使ったためか、SQLクエリ以外の文章が表示されることは十数回試した感じでは無さそうでした。
ストップシーケンスがついていても影響は無さそうなので、元記事のコードのまま進めてみます。

試してみましょう。

sql_chain.invoke({"question": "化合物はいくつありますか?"})
SELECT COUNT(*) FROM compounds;

化合物の数を取得するためのSELECT COUNT(*)を使ったSQLクエリが生成されました。

続いてこのSQLでRDBへ問い合わせて、その結果を生成するように実装します。
日本語で回答を生成したいので、元記事からプロンプトを変更しています。
また、検証のために使用したSQLも回答に含めるように指示を入れています。

template = """
以下の table schema, question, sql query, sql responseに基づいて、
日本語での返答を書いてください。回答には使用したSQLクエリも含めてください:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""

prompt = ChatPromptTemplate.from_template(template)
def run_query(query):
    return db.run(query)

full_chain = (
  RunnablePassthrough.assign(query=sql_chain).assign(
    schema=get_schema,
    response= lambda vars: run_query(vars["query"])
  )
  | prompt
  | llm
  | StrOutputParser()
)

full_chainではまずsql_chainに問い合わせてSQLを生成して、それを用いてRDBに問い合わせます。
最後にその結果を用いて、自然言語での回答を生成します。

これで自然言語でRDBのデータを取得できるようになりました。
試しに先程と同じ質問をしてみます。

full_chain.invoke({"question": "化合物はいくつありますか?"})

化合物は20種類あります。
使用したSQLクエリ: SELECT COUNT(*) FROM compounds;

 回答と使用したSQLクエリが表示されています。
化合物の個数は正しく20個と表示されました。

ここまでの処理のまとめ

一旦処理の流れをまとめてみましょう。
まず、入力した自然言語からSQLへ変換するために、gpt-4へ下記のプロンプトが送られます。

Human: Based on the table schema below, write a SQL query that would answer the user's question.

CREATE TABLE compounds ( compound_id INTEGER, name TEXT, PRIMARY KEY (compound_id) )

/* 3 rows from compounds table: compound_id name 176 acetic acid 180 propan-2-one 222 azane */

CREATE TABLE properties ( property_id INTEGER, compound_id INTEGER, molecular_weight REAL, canonical_smiles TEXT, PRIMARY KEY (property_id), FOREIGN KEY(compound_id) REFERENCES compounds (compound_id) )

/* 3 rows from properties table: property_id compound_id molecular_weight canonical_smiles 1 2244 180.16 CC(=O)OC1=CC=CC=C1C(=O)O 2 5793 180.16 C(C1C(C(C(C(O1)O)O)O)O)O 3 2519 194.19 CN1C=NC2=C1C(=O)N(C(=O)N2C)C */

Question: 化合物はいくつありますか? SQL Query:

データベースの情報はスキーマだけでなく数個のサンプルデータも渡されています。LangChainのdb.get_table_info()から取得された情報です。

この質問に対する回答が以下になります。

SELECT COUNT(*) FROM compounds;

SQL文のみが返ってきました。

続いてこのSQLをLangChainがRDBに対して実行します。
そこで受け取った結果を使って再度gpt-4に下記プロンプトで質問を送ります。

Human:
以下の table schema, question, sql query, sql responseに基づいて、日本語での返答を書いてください。回答には使用したSQLクエリも含めてください:

CREATE TABLE compounds (
compound_id INTEGER,
name TEXT,
PRIMARY KEY (compound_id)
)

/*
3 rows from compounds table:
compound_id name
176 acetic acid
180 propan-2-one
222 azane
*/

CREATE TABLE properties (
property_id INTEGER,
compound_id INTEGER,
molecular_weight REAL,
canonical_smiles TEXT,
PRIMARY KEY (property_id),
FOREIGN KEY(compound_id) REFERENCES compounds (compound_id)
)

/*
3 rows from properties table:
property_id compound_id molecular_weight canonical_smiles
1 2244 180.16 CC(=O)OC1=CC=CC=C1C(=O)O
2 5793 180.16 C(C1C(C(C(C(O1)O)O)O)O)O
3 2519 194.19 CN1C=NC2=C1C(=O)N(C(=O)N2C)C
*/

Question: 化合物はいくつありますか?
SQL Query: SELECT COUNT(*) FROM compounds;
SQL Response: [(20,)]

最後のところに使用したSQLとそのレスポンスが追加されています。

この質問に対するgpt-4の回答が下記になります。

化合物は合計20個あります。

使用したSQLクエリ: SELECT COUNT(*) FROM compounds;

この回答が最終的な回答として使用されています。

GradioによるUIの追加

ここから色々と条件を変えて試して行きますが、Colabで毎回コードを書くのはやや大変なのと出力が見づらいため、UI上で実行できるようにしてみます。

今回はGradioを使います。
GradioはLLMを使ったアプリケーションを簡単に作成できるPythonライブラリで、Colab上にもUIを表示できます。

GradioをインストールしてColab上にUIを表示してみます。

!pip install gradio
import gradio as gr

def generate(question):
    return full_chain.invoke({"question": question})

app = gr.Interface(fn=generate, inputs="text", outputs="text")

app.launch()

このコードを実行するとColab上にGradioのUIが表示されました。
questionに質問を入れてSubmitボタンを押すとoutputに結果が表示されます。

かなり便利になりました。

なお、右下にあるFlagボタンを押すと入力と出力がログファイルに保存されます。
生成結果を保存しておきたいときに便利な機能です。

以降はこのUIを使って確認していきます。

自然言語による問い合わせ

条件と並び順

取得に少し条件をつけてみましょう。
文字列の部分一致と順序指定を使えるかのテストです。

question:
名前にacidが付く化合物をcidの順番で並べて教えて

answer:
SQLクエリにより、名前に'acid'が含まれる化合物が以下の順番で取得されました。
1. 化合物ID:176、名前:酢酸(acetic acid)
2. 化合物ID:243、名前:安息香酸(benzoic acid)
3. 化合物ID:311、名前:2-ヒドロキシプロパン-1,2,3-トリカルボン酸(2-hydroxypropane-1,2,3-tricarboxylic acid)
4. 化合物ID:2244、名前:2-アセチルオキシ安息香酸(2-acetyloxybenzoic acid)
5. 化合物ID:3672、名前:2-[4-(2-メチルプロピル)フェニル]プロパノ酸(2-[4-(2-methylpropyl)phenyl]propanoic acid)

使用したSQLクエリは以下の通りです:
```
SELECT c.compound_id, c.name
FROM compounds c
WHERE c.name LIKE '%acid%' ORDER BY c.compound_id;
```

SQLのLIKEもORDERも正しく生成されて、回答も問題ありません。

テーブルの結合

続いて2つのテーブルから名前とSMILESを取得する、テーブル結合が必要な質問をしてみます。

question: 化合物の名前とSMILESを教えて

answer: 以下は、化合物の名前とSMILESの一覧です:
- 2-acetyloxybenzoic acid : CC(=O)OC1=CC=CC=C1C(=O)O
- (3R,4S,5S,6R)-6-(hydroxymethyl)oxane-2,3,4,5-tetrol : C(C1C(C(C(C(O1)O)O)O)O)O
- 1,3,7-trimethylpurine-2,6-dione : CN1C=NC2=C1C(=O)N(C(=O)N2C)C
- ethanol : CCO
- N-(4-hydroxyphenyl)acetamide : CC(=O)NC1=CC=C(C=C1)O
- 2-[4-(2-methylpropyl)phenyl]propanoic acid : CC(C)CC1=CC=C(C=C1)C(C)C(=O)O
- (2R,3R,4S,5S,6R)-2-[(2S,3S,4S,5R)-3,4-dihydroxy-2,5-bis(hydroxymethyl)oxolan-2-yl]oxy-6-(hydroxymethyl)oxane-3,4,5-triol : C(C1C(C(C(C(O1)OC2(C(C(C(O2)CO)O)O)CO)O)O)O)O
- propane-1,2,3-triol : C(C(CO)O)O
- acetic acid : CC(=O)O
- sodium;chloride : [Na+].[Cl-]
- benzoic acid : C1=CC=C(C=C1)C(=O)O
- (2R)-2-[(1S)-1,2-dihydroxyethyl]-3,4-dihydroxy-2H-furan-5-one : C(C(C1C(=C(C(=O)O1)O)O)O)O
- 2-hydroxypropane-1,2,3-tricarboxylic acid : C(C(=O)O)C(CC(=O)O)(C(=O)O)O
- methanol : CO
- propan-2-ol : CC(C)O
- azane : N
- urea : C(=O)(N)N
- phenol : C1=CC=C(C=C1)O
- formaldehyde : C=O
- propan-2-one : CC(=O)C

SQL Query:
SELECT compounds.name, properties.canonical_smiles
FROM compounds
JOIN properties ON compounds.compound_id = properties.compound_id;

JOINを使ったSQLを生成できています。

サブクエリ

次は少し複雑なサブクエリが必要な質問をしてみます。
分子量がデータ内の平均値(約117.64)以上の化合物を取得してみます。

question: 分子量がこのデータの中の平均値より大きい化合物名と分子量を教えて。分子量が小さい順に並び替えて

answer:
SQLクエリ:
SELECT c.name, p.molecular_weight
FROM compounds c
JOIN properties p ON c.compound_id = p.compound_id
WHERE p.molecular_weight >
(SELECT AVG(molecular_weight) FROM properties)
ORDER BY p.molecular_weight ASC;

このクエリにより、分子量が平均より大きい化合物名とその分子量が小さい順に次のように得られました:
- 'benzoic acid'、分子量 122.12
- 'N-(4-hydroxyphenyl)acetamide'、分子量 151.16
- '(2R)-2-[(1S)-1,2-dihydroxyethyl]-3,4-dihydroxy-2H-furan-5-one'、分子量 176.12
- '2-acetyloxybenzoic acid'、分子量 180.16
- '(3R,4S,5S,6R)-6-(hydroxymethyl)oxane-2,3,4,5-tetrol'、分子量 180.16
- '2-hydroxypropane-1,2,3-tricarboxylic acid'、分子量 192.12
- '1,3,7-trimethylpurine-2,6-dione'、分子量 194.19
- '2-[4-(2-methylpropyl)phenyl]propanoic acid'、分子量 206.28
- '(2R,3R,4S,5S,6R)-2-[(2S,3S,4S,5R)-3,4-dihydroxy-2,5-bis(hydroxymethyl)oxolan-2-yl]oxy-6-(hydroxymethyl)oxane-3,4,5-triol'、分子量 342.3 

正しくWHEREでサブクエリを使えて、平均値以上の化合物を取得できています。

条件ごとの集計

続いて条件ごとの集計が必要な質問を投げてみます。
分子量が100以下の化合物の個数と、それ以外の化合物の個数を聞いてみます。

question: 化合物の分子量が100以下とそれ以外に分けてそれぞれの個数を教えて

answer: SQLクエリにより、分子量が100以下の化合物は11個、100以上の化合物は9個であることが分かりました。

使用したSQLクエリ:
```
SELECT
CASE WHEN molecular_weight <= 100 THEN '100 or less'
ELSE 'more than 100' END AS molecular_weight_group,
COUNT(*) AS number_of_compounds
FROM properties
GROUP BY molecular_weight_group;
```

CASE WHENとGROUP BYを使って正しく個数を表示できました。
素晴らしいですね!

データの削除

最後にデータの削除を試してみます。
実運用では参照系のみの権限を与えて、更新系の権限は与えないほうが安全かと思いますが、今回はテストとして実行してみます。

question: メタノールのレコードを削除して

結果は残念ながらエラーになりました。

2つのテーブルからデータを削除するため、2つのDELETE文が生成され、複数のSQLが実行できずにエラーとなったようです。

"methanolの関連データをpropertiesテーブルから削除して"
とテーブルを指定して削除を依頼すると、今度は問題なく削除できました。

不整合レコードの取得

1つのテーブルからのみレコードを削除して整合性が失われた状態になりました。
折角なので整合性が取れていないデータがあるかを聞いてみます。

question: 分子量のデータが無い化合物を教えて

answer:
SQLクエリによると、分子量のデータが無い化合物は「メタノール」です。

使用したSQLクエリ:
```
SELECT c.name
FROM compounds c
LEFT JOIN properties p ON c.compound_id = p.compound_id
WHERE p.molecular_weight IS NULL;
```

LEFT JOINを使ってpropertiesテーブルに無いデータを取得できています。
不整合チェックも簡単にできますね!

まとめ

今回はLangChainを使った自然言語でのRDBの操作について検証してみました。
ほとんどのケースで正常にデータが取れたのは驚きでした。

今回のケースはかなり単純化したスキーマで試しましたが、実際のデータでも十分使える可能性はありそうです。
実際の業務への適用を考えると、参照系のクエリのみに限定して使うのが賢明だと思われます。
また、得られた結果が100%正しいとは限らないことも念頭に置く必要があります。

とはいえ、SQLを知らないユーザーでも簡単にデータにアクセスできるのは非常に魅力的です。ユーザーインターフェースと組み合わせることで、新しいデータ活用の形が生まれるかもしれません。


CRSチームでは今後も化学の研究活動のサポートに向けて調査・開発を進めて行きます。もしお困りのことがありましたらお気軽に下記までお問い合わせください!