LangChain Agentsを使ってテストコードから「テストの通るコード」を自動生成するプログラムを書いてみた
先日以下のような記事を書いてみたものの、いちいち結果をChatGPTに手でコピペしながら検証するのはダサいなと思っていました。
そういうわけでRSpecが通るまで愚直に検証&生成を繰り返すようなコードを書いてみたものの、修正履歴までChatGPTのコンテキストに持たせようとすると、すぐに最大トークン数を超えてしまい失敗してしまいます。
最大トークン数を超えないように頭の良いコンテキストを持たせるような実装も可能だとは思いますが、結構複雑な実装になってしまいそうです。
そんな中で出会ったのがこのツイートでした。
確かにLangChain Agentsを使えば、LLMで生成したコードを実行して確認するといった一連の行動を自動化できそうです。
時間が無い人向けの動画
こんな感じに動作するものをつくりました。
LangChain Agentsとは?
LangChain Chains
LangChainではChainsという指定した順序でLLMを連続実行できる仕組みがあります。LangChainの名の通り象徴的な機能です。AgentsはChainsの発展系なので、まずはChainsの動きから。
以下はドキュメントのサンプルにある、「colorful socks」を製造するような会社の名前として良さそうなものを推論させた上で、そんな名前の会社がキャッチフレーズにしそうな言葉を推論させるコードです。
from langchain.chat_models import ChatOpenAI
from langchain.chains import SimpleSequentialChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
# 会社名を推論
human_message_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(
template="What is a good name for a company that makes {product}?",
input_variables=["product"],
)
)
chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
chat = ChatOpenAI(temperature=0.9)
chain = LLMChain(llm=chat, prompt=chat_prompt_template)
# キャッチフレーズを推論
second_prompt = PromptTemplate(
input_variables=["company_name"],
template="Write a catchphrase for the following company: {company_name}",
)
chain_two = LLMChain(llm=llm, prompt=second_prompt)
# 推論をチェーンで組み合わせる
overall_chain = SimpleSequentialChain(chains=[chain, chain_two], verbose=True)
# 実行
catchphrase = overall_chain.run("colorful socks")
print(catchphrase)
結果は以下のようになります。
連続して推論が実行されていることが分かります。
> Entering new SimpleSequentialChain chain...
Cheerful Toes.
"Spread smiles from your toes!"
> Finished SimpleSequentialChain chain.
"Spread smiles from your toes!"
LangChain Agents
対してAgentsは、利用して良いツールをエージェントに与えておくと、指定されたプロンプトを解決するためにLLMが自律的に道具を選定し、問題解決をしてくれる仕組みです。
例えば以下はGoogle検索を行うAPIと計算を行うツールを組み合わせた例です。
from langchain.agents import load_tools
from langchain.agents import initialize_agent
from langchain.llms import OpenAIChat
llm = OpenAIChat(temperature=0)
tools = load_tools(["serpapi", "llm-math"], llm=llm)
agent = initialize_agent(
tools,
llm,
agent="zero-shot-react-description",
verbose=True
)
agent.run("日本で一番長い川の長さを調べ、その長さを三乗してください")
実行結果は以下の通りです。
信濃川を検索した結果を踏まえて計算機で計算していることが分かります。
> Entering new AgentExecutor chain...
I don't speak Japanese, so I should use a search engine to translate the question and find the answer.
Action: Search
Action Input: "What is the length of the longest river in Japan and what is its cubed length?"
Observation: The 11,900 square kilometres (4,600 sq mi) basin of the Shinano-Chikuma River system is the third largest in Japan and at 367 kilometres (228 mi), is the longest river in the country.
Thought:Now that I have the answer, I need to use a calculator to find the cubed length.
Action: Calculator
Action Input: 367^3
Observation: Answer: 49430863
Thought:I now know the final answer
Final Answer: 49430863
> Finished chain.
'49430863'
全く同じプロンプトをChatGPTに与えてみると、信濃川の長さは学習データから導き出せているものの、計算結果が誤っています。
日本で一番長い川は信濃川で、長さは367キロメートルです。 したがって、信濃川の長さを三乗すると、約48,541,463キロメートルになります。
計算式: (367 km)^3 = 367 km × 367 km × 367 km = 48,541,463 km
生成したコードを自動検証する仕組みをAgentsで作成する
前回は点字生成プログラムを題材にしていましたが、今回はもう少し仕様をシンプルにして、逆ポーランド記法で与えられた計算式を計算する「逆ポーランド計算機」を題材にしてみようと思います。
逆ポーランド記法とは以下のような表記法のことです↓
「a+b」なら「ab+」と表記する
「a*(b+c)」なら「abc+*」と表記する
「(a+b)/(c-d)」なら「ab+cd-/」と表記する
この表記法のメリットは、カッコによる演算順の優先度の考慮なしに左から順に数字か記号をスタックに積んで行けば、スタックをポップし続けるだけで計算できることにあります。
テストコードを生成する
そんなわけでまずはテストコードを生成します。
今回はテストコードをChatGPTで生成した上で、このテストコードをパスするまでLangChainにコードを生成し続けてもらうことにします。
あなたは優秀なRubyプログラマーです。
逆ポーランド記法で表現された文字列を受け取り、計算を実行するプログラムのテストコードをRSpecで作成しなさい。
出力はRubyコードのみとし、それ以外の余計な文字は出力しないでください。
テストには正常系と異常系を含めるようにしてください。
以下が生成されたRSpecのコードです。
require_relative 'rpn_calculator'
describe RPNCalculator do
describe "#evaluate" do
let(:calculator) { RPNCalculator.new }
context "with a valid expression" do
it "evaluates addition correctly" do
expect(calculator.evaluate("5 3 +")).to eq(8)
end
it "evaluates subtraction correctly" do
expect(calculator.evaluate("5 3 -")).to eq(2)
end
it "evaluates multiplication correctly" do
expect(calculator.evaluate("5 3 *")).to eq(15)
end
it "evaluates division correctly" do
expect(calculator.evaluate("10 2 /")).to eq(5)
end
it "evaluates complex expressions correctly" do
expect(calculator.evaluate("5 3 + 4 * 7 -")).to eq(25)
end
end
context "with an invalid expression" do
it "raises an error for division by zero" do
expect { calculator.evaluate("5 0 /") }.to raise_error(ZeroDivisionError)
end
it "raises an error if the input does not follow the rules" do
expect { calculator.evaluate("5 + 3") }.to raise_error(ArgumentError)
end
end
end
end
RSpecを実行して生成コードを検証するカスタムツールを作成する
コード生成結果に対してRSpecを実行する仕組みはLangChainに用意されていないので、独自にツールを定義します。
from langchain.agents.tools import Tool
import traceback
import subprocess
class_name = "rpn_calculator"
class RSpecExecutor:
def __init__(self):
pass
def run(self, ruby_code):
try:
with open(f"{class_name}.rb", 'w') as f:
f.write(trim(ruby_code))
result = subprocess.check_output(['rspec', f"{class_name}_spec.rb"])
output = result.decode('utf-8')
except subprocess.CalledProcessError as e:
output = e.output.decode('utf-8')
except Exception:
output = traceback.format_exc()
return output
rspec_exec = Tool(
"RSpec Executor",
RSpecExecutor().run,
"A tool to test generated ruby code. Input should be a valid generated ruby code."
)
本筋とはズレますが、どうプロンプトを工夫しても空の改行や「```」で囲まれた結果を返すことが多かったため、LLMが返した結果をトリムする関数を以下のように定義しています。
def trim(string):
lines = string.split('\n')
while lines and (lines[0].strip() == '' or lines[0].startswith('`')):
lines = lines[1:]
while lines and (lines[-1].strip() == '' or lines[-1].startswith('`')):
lines = lines[:-1]
return '\n'.join(lines)
エージェントはToolの第三引数に指定する解説を理解して道具を選定するため、どういうときに使うべきツールなのかエージェントが理解できるように明確に書いておく必要があります。
LangChainで実行するコードを作成する
あとはAgentsを利用して実行するコードを書くだけです。
コードは下記の通りで、とてもシンプルです。
from langchain.llms import OpenAIChat
from langchain.agents import initialize_agent
llm = OpenAIChat(temperature=0.0)
tools = [rspec_exec]
agent = initialize_agent(tools, llm, agent="zero-shot-react-description", verbose=True)
with open(f"{class_name}_spec.rb", 'r') as f:
rspec_code = f.read()
prompt = f"""
You are an excellent professional Ruby programmer.
Create a Ruby code that meets the following test code specifications
code that meets the following test code specifications.
- Do not modify the test code.
- No test code should be output.
- When creating ruby code, please make sure to output only ruby code.
- Make sure that there are no syntax errors in the generated ruby code.
- To check whether a given input is a processable input, create a separate validate function and check it.
- Validate function must be written in the class.
# test code:
{rspec_code}
"""
agent.run(prompt)
実のところ、シンプルなプロンプトだけでは実装が明後日の方向へ行ってしまうことが多かったため、「クラス内にバリデーションのための関数を作れ」などと、実装の細かいところまで口出しするプロンプトを書いてしまっています。
また、エージェントは日本語を理解できないので(日本語を与えると英語に翻訳した上で処理を継続します)、プロンプトは英語で指定するようにしています。
実行結果
コードを実行してみると、まずは1回目の生成結果をもとにRSpec Executorを使ってテストを実行しようとします。
> Entering new AgentExecutor chain...
I need to create a Ruby code that meets the specifications of the test code.
Action: RSpec Executor
Action Input:
```
class RPNCalculator
def validate(input)
input_array = input.split(" ")
operators = ["+", "-", "*", "/"]
input_array.each do |char|
if !operators.include?(char) && !char.match?(/\d+/)
return false
end
end
return true
end
def evaluate(input)
raise ArgumentError unless validate(input)
stack = []
input.split.each do |char|
case char
when "+"
stack << stack.pop(2).reduce(:+)
when "-"
stack << stack.pop(2).reduce(:-)
when "*"
stack << stack.pop(2).reduce(:*)
when "/"
stack << stack.pop(2).reduce(:/)
else
stack << char.to_i
end
end
stack.first
end
end
```
Observation: ......F
Failures:
1) RPNCalculator#evaluate with an invalid expression raises an error if the input does not follow the rules
Failure/Error: expect { calculator.evaluate("5 + 3") }.to raise_error(ArgumentError)
expected ArgumentError but nothing was raised
# ./rpn_calculator_spec.rb:35:in `block (4 levels) in <top (required)>'
Finished in 0.06175 seconds (files took 0.48972 seconds to load)
7 examples, 1 failure
Failed examples:
rspec ./rpn_calculator_spec.rb:34 # RPNCalculator#evaluate with an invalid expression raises an error if the input does not follow the rules
ですが「5 + 3」を入力したときの実行結果がテストコードと食い違ってしまっており、テストが失敗します。
するとテスト結果を踏まえて以下のように
Thought:The code is failing the test for an invalid expression. I need to check the validate function to see if it's correctly identifying invalid expressions.
validate関数を改善する必要があると考察し、改善案を提示します。
Action: RSpec Executor
Action Input:
```
class RPNCalculator
def validate(input)
input_array = input.split(" ")
operators = ["+", "-", "*", "/"]
input_array.each do |char|
if !operators.include?(char) && !char.match?(/\d+/)
return false
end
end
return true
end
def evaluate(input)
raise ArgumentError unless validate(input)
stack = []
input.split.each do |char|
case char
when "+"
stack << stack.pop(2).reduce(:+)
when "-"
stack << stack.pop(2).reduce(:-)
when "*"
stack << stack.pop(2).reduce(:*)
when "/"
stack << stack.pop(2).reduce(:/)
else
stack << char.to_i
end
end
stack.first
end
end
describe RPNCalculator do
describe "#validate" do
let(:calculator) { RPNCalculator.new }
context "with a valid expression" do
it "returns true" do
expect(calculator.validate("5 3 +")).to eq(true)
end
end
context "with an invalid expression" do
it "returns false for invalid characters" do
expect(calculator.validate("5 + 3")).to eq(false)
end
it "returns false for too many operators" do
expect(calculator.validate("5 3 + +")).to eq(false)
end
it "returns false for too many operands" do
expect(calculator.validate("5 3 2 +")).to eq(false)
end
end
end
end
```
生成したコードを更に確認するよう、独自にテストコードも追加しているようです。
Observation: .FFF......F
Failures:
1) RPNCalculator#validate with an invalid expression returns false for invalid characters
Failure/Error: DEFAULT_FAILURE_NOTIFIER = lambda { |failure, _opts| raise failure }
expected: false
got: true
(compared using ==)
Diff:
@@ -1 +1 @@
-false
+true
# ./rpn_calculator.rb:47:in `block (4 levels) in <top (required)>'
2) RPNCalculator#validate with an invalid expression returns false for too many operators
Failure/Error: DEFAULT_FAILURE_NOTIFIER = lambda { |failure, _opts| raise failure }
expected: false
got: true
(compared using ==)
Diff:
@@ -1 +1 @@
-false
+true
# ./rpn_calculator.rb:51:in `block (4 levels) in <top (required)>'
3) RPNCalculator#validate with an invalid expression returns false for too many operands
Failure/Error: DEFAULT_FAILURE_NOTIFIER = lambda { |failure, _opts| raise failure }
expected: false
got: true
(compared using ==)
Diff:
@@ -1 +1 @@
-false
+true
# ./rpn_calculator.rb:55:in `block (4 levels) in <top (required)>'
4) RPNCalculator#evaluate with an invalid expression raises an error if the input does not follow the rules
Failure/Error: expect { calculator.evaluate("5 + 3") }.to raise_error(ArgumentError)
expected ArgumentError but nothing was raised
# ./rpn_calculator_spec.rb:35:in `block (4 levels) in <top (required)>'
Finished in 0.06958 seconds (files took 0.34411 seconds to load)
11 examples, 4 failures
Failed examples:
rspec './rpn_calculator.rb[1:1:2:1]' # RPNCalculator#validate with an invalid expression returns false for invalid characters
rspec './rpn_calculator.rb[1:1:2:2]' # RPNCalculator#validate with an invalid expression returns false for too many operators
rspec './rpn_calculator.rb[1:1:2:3]' # RPNCalculator#validate with an invalid expression returns false for too many operands
rspec ./rpn_calculator_spec.rb:34 # RPNCalculator#evaluate with an invalid expression raises an error if the input does not follow the rules
今度は異常系が軒並み失敗してしまいます。
この結果を踏まえてLLMが更にコードを変更したところ・・・
Thought:The validate function is not correctly identifying invalid expressions. I need to update the function to correctly identify invalid expressions.
Action: RSpec Executor
Action Input:
```
class RPNCalculator
def validate(input)
input_array = input.split(" ")
operators = ["+", "-", "*", "/"]
operand_count = 0
operator_count = 0
input_array.each do |char|
if !operators.include?(char) && !char.match?(/\d+/)
return false
elsif operators.include?(char)
operator_count += 1
return false if operator_count > operand_count - 1
else
operand_count += 1
end
end
return false if operand_count - operator_count != 1
return true
end
def evaluate(input)
raise ArgumentError unless validate(input)
stack = []
input.split.each do |char|
case char
when "+"
stack << stack.pop(2).reduce(:+)
when "-"
stack << stack.pop(2).reduce(:-)
when "*"
stack << stack.pop(2).reduce(:*)
when "/"
stack << stack.pop(2).reduce(:/)
else
stack << char.to_i
end
end
stack.first
end
end
describe RPNCalculator do
describe "#validate" do
let(:calculator) { RPNCalculator.new }
context "with a valid expression" do
it "returns true" do
expect(calculator.validate("5 3 +")).to eq(true)
end
end
context "with an invalid expression" do
it "returns false for invalid characters" do
expect(calculator.validate("5 + 3")).to eq(false)
end
it "returns false for too many operators" do
expect(calculator.validate("5 3 + +")).to eq(false)
end
it "returns false for too many operands" do
expect(calculator.validate("5 3 2 +")).to eq(false)
end
it "returns false for too few operands" do
expect(calculator.validate("5 +")).to eq(false)
end
end
end
end
```
無事テストが通りました!🎉
Observation: ............
Finished in 0.01266 seconds (files took 0.27404 seconds to load)
12 examples, 0 failures
Thought:The code is passing all tests now. I can submit the final code as the solution to the problem.
Final Answer:
```
class RPNCalculator
def validate(input)
input_array = input.split(" ")
operators = ["+", "-", "*", "/"]
operand_count = 0
operator_count = 0
input_array.each do |char|
if !operators.include?(char) && !char.match?(/\d+/)
return false
elsif operators.include?(char)
operator_count += 1
return false if operator_count > operand_count - 1
else
operand_count += 1
end
end
return false if operand_count - operator_count != 1
return true
end
def evaluate(input)
raise ArgumentError unless validate(input)
stack = []
input.split.each do |char|
case char
when "+"
stack << stack.pop(2).reduce(:+)
when "-"
stack << stack.pop(2).reduce(:-)
when "*"
stack << stack.pop(2).reduce(:*)
when "/"
stack << stack.pop(2).reduce(:/)
else
stack << char.to_i
end
end
stack.first
end
end
```
> Finished chain.
最終的には独自に追加したテストコードを除いたバージョンを提出してくれました。独自にテストコードを追加して動作検証する動きをするとは、なかなか頭が良い・・・。
考察
テストコードは通ったものの、生成されたコード自体は綺麗なものとは言えないし脆弱である。
大きめのテストコードを与えると最大トークン数をすぐオーバーしてしまうので実用性は良くない。
運が悪いと延々とテストの通らないコードを生成し続けてしまう(最大トークン数制限オーバーで強制終了する)。
生成されたコードが汚い件については、リファクタリング原則をコンテキストとして与えた上で、テストが通るように改善し続けるよう指示することで自動化できるかも知れません。
とはいえ前回は結局テストを通すようなコードを生成できなかったのに対して、今回はテストが通るところまで自動化できたので、少しは進歩があったかなーと思います。
現場からは以上です。