ローカルLLMでAgent構築
ローカルLLMを使ってAgentを構築して、ツールの呼び出しまでやってみようと思います。ローカルLLMは制御が難しく、Agentととして動作させるのが難しい印象でしたが、日本語でも性能の高いモデルが登場しているのでトライしてみます。
LLMの準備
vLLMでモデルをデプロイします。
from typing import List, Optional
from pydantic import BaseModel, Field
from vllm import LLM, SamplingParams
from outlines.serve.vllm import JSONLogitsProcessor
from langchain.agents import Tool
import json
model_id = 'tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.1'
vllm_model = LLM(
model=model_id, dtype="float16",
quantization="bitsandbytes", load_format="bitsandbytes",
gpu_memory_utilization=1.0, max_model_len=2048,
)
データモデルとプロンプト
LLMの制御性を高めるために形式を指定して応答できるようにデータモデルを定義します。決められたデータモデルで応答するためのプロンプトを合わせて作ります。
class ThoughtAction(BaseModel):
thought: str = Field(description="考えたこと(例:ツールにより正確な値が得られたので、最終的な答えとする)")
action: Optional[str] = Field(description="使用するツール名", default=None)
action_input: Optional[str] = Field(description="ツールへの入力", default=None)
class AgentStep(BaseModel):
step: ThoughtAction
observation: Optional[str] = Field(description="ツールからの出力結果", default=None)
class AgentResponse(BaseModel):
steps: List[AgentStep]
final_answer: str
prompt_template = """
<|start_header_id|>system<|end_header_id|>
あなたは以下のツールを使用できる指示に忠実で優秀なAIアシスタントです。
利用可能なツール:
{tools}
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
ステップバイステップで考え、質問に回答してください。
- 考えて実行した経緯は実行履歴で与えられます。
- 実行履歴から回答を探してください。
- Observationをよく観察してください。
- 今回の思考は{n}回目です。
- 必ず前回の実行履歴と別の行動してください。
- 以下の形式でJSONを生成してください。
{schema}
思考の例:
Thought: 考えたこと(例:ツールにより正確な値が得られたので、最終的な答えとする)
Action: 行うこと(例:ツールで計算する)
ActionInput: 行動に必要な情報(例:1 + 2)
Observation: 行動の結果(例:3)
{thought}
質問:
{question}
回答:
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
Agentクラス
Agentクラスを定義して、Agentのロジックを実装します。このAgentクラスからツール呼び出しができるようにしておきます。
class StructuredVLLMAgent:
def __init__(self, model, tools: List[Tool] = None):
self.vllm_model = model
self.tools = tools or []
self.tool_map = {tool.name: tool for tool in self.tools}
self.schema = ThoughtAction.model_json_schema()
json_logits_processor = JSONLogitsProcessor(self.schema, self.vllm_model)
self.sampling_params = SamplingParams(
temperature=0.1, top_p=0.95, max_tokens=1024,
repetition_penalty=1.2,
logits_processors=[json_logits_processor]
)
self.prompt_template = prompt_template
def _format_tools(self) -> str:
"""ツール情報のフォーマット"""
return "\n".join([
f"- {tool.name}: {tool.description}"
for tool in self.tools
])
def _execute_tool(self, tool_name: str, tool_input: str) -> str:
"""ツールの実行"""
if tool_name not in self.tool_map:
return f"Error: Tool '{tool_name}' not found"
try:
return str(self.tool_map[tool_name].func(tool_input))
except Exception as e:
return f"Error executing tool: {str(e)}"
def run(self, question: str, max_steps: int = 5) -> AgentResponse:
steps: List[AgentStep] = []
final_answer = None
for n in range(max_steps):
# プロンプトの生成
thought = ''
if steps:
for m, x in enumerate(steps):
thought += f"{m+1}回目の実行履歴:\n"
thought += f'\tThought:{x.step.thought}\n\tAction:{x.step.action}\n\tActionInput:{x.step.action_input}\n\tObservation:{x.observation}\n'
current_prompt = self.prompt_template.format(
tools=self._format_tools(),
question=question,
schema=self.schema,
thought=thought,
n=n+1,
)
outputs = self.vllm_model.generate(
[current_prompt],
self.sampling_params,
)
response_text = outputs[0].outputs[0].text
try:
# JSON応答のパース
step_data = json.loads(response_text)
current_step = ThoughtAction(**step_data)
# ツールの実行が必要な場合
observation = None
if current_step.action and current_step.action_input:
observation = self._execute_tool(
current_step.action,
current_step.action_input
)
else:
observation = current_step.thought
steps.append(AgentStep(
step=current_step,
observation=observation
))
# 最終回答に達したかチェック
if observation is not None:
final_answer = observation
break
except json.JSONDecodeError:
print(f"Error parsing JSON response: {response_text}")
continue
if not final_answer:
final_answer = "最大ステップ数に達しました。"
return AgentResponse(steps=steps, final_answer=final_answer)
Agentの繰り返し実行回数は5回としています。
Agentの作り方やToolはLangchainの方式に一致するように実装していますが、Agentクラスは独自定義となっています。ローカルLLMで利用可能なLangchainのエコシステムを使えると便利なので、このような構成にしています。
Toolも独自定義で実装しても手間は大きく変わらないですが、拡張性を考えるとLangchainに合わせた方が良いでしょう。
Toolの定義
Toolは下記のように計算するだけのFunctionです。
def calculator(expression: str) -> str:
try:
return f'{expression}の計算結果は、{str(eval(expression))}です。'
except Exception as e:
return f"計算エラー: {str(e)}"
tools = [
Tool(
name="calculator",
func=calculator,
description="数式の計算を実行します。入力は数式の文字列です。"
)
]
エージェントの実行
今回は4つのクエリをテストしてみました。
agent = StructuredVLLMAgent(vllm_model, tools=tools)
# テストクエリ
queries = [
"1234 * 5678 を計算してください",
"2024年は令和何年ですか?",
"原価3000円の商品に20%の利益をつけて売りました。定価はいくら?",
"128-256を計算して"
]
# テスト実行
for query in queries:
print(f"\n質問: {query}")
response = agent.run(query)
print("\n回答:")
print(response.final_answer)
print("\n履歴:")
print(json.dumps(response.model_dump(), indent=2, ensure_ascii=False))
実行結果
実行は以下の2つのLLMを使ってみました。
Llama3.1
質問: 1234 * 5678 を計算してください
回答:
1234*5678の計算結果は、7006652です。
質問: 2024年は令和何年ですか?
回答:
2024年は令和6年である
質問: 原価3000円の商品に20%の利益をつけて売りました。定価はいくら?
回答:
3000 * (1+0.2)の計算結果は、3600.0です。
質問: 128-256を計算して
回答:
128 - 256の計算結果は、-128です。
Llama3
質問: 1234 * 5678 を計算してください
回答:
1234 * 5678の計算結果は、7006652です。
質問: 2024年は令和何年ですか?
回答:
2024 - 2019の計算結果は、5です。
質問: 原価3000円の商品に20%の利益をつけて売りました。定価はいくら?
回答:
3000 * (100 / 80)の計算結果は、3750.0です。
質問: 128-256を計算して
回答:
128 - 256の計算結果は、-128です。
Llama3.1は全問正解しているようですが、Llama3の方は間違っているものがありますね。
チューニングデータも違っていると思うので、比較にあまり意味はないかもしれませんが、日本語で8BモデルでAgentが動くことが確認できました。