見出し画像

ローカルLLMでAgent構築


ローカルLLMを使ってAgentを構築して、ツールの呼び出しまでやってみようと思います。ローカルLLMは制御が難しく、Agentととして動作させるのが難しい印象でしたが、日本語でも性能の高いモデルが登場しているのでトライしてみます。

LLMの準備

vLLMでモデルをデプロイします。

from typing import List, Optional
from pydantic import BaseModel, Field
from vllm import LLM, SamplingParams
from outlines.serve.vllm import JSONLogitsProcessor
from langchain.agents import Tool
import json

model_id = 'tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.1'
vllm_model = LLM(
    model=model_id, dtype="float16",
    quantization="bitsandbytes", load_format="bitsandbytes",
    gpu_memory_utilization=1.0, max_model_len=2048,
)

データモデルとプロンプト

LLMの制御性を高めるために形式を指定して応答できるようにデータモデルを定義します。決められたデータモデルで応答するためのプロンプトを合わせて作ります。

class ThoughtAction(BaseModel):
    thought: str = Field(description="考えたこと(例:ツールにより正確な値が得られたので、最終的な答えとする)")
    action: Optional[str] = Field(description="使用するツール名", default=None)
    action_input: Optional[str] = Field(description="ツールへの入力", default=None)

class AgentStep(BaseModel):
    step: ThoughtAction
    observation: Optional[str] = Field(description="ツールからの出力結果", default=None)

class AgentResponse(BaseModel):
    steps: List[AgentStep]
    final_answer: str

prompt_template = """
<|start_header_id|>system<|end_header_id|>
あなたは以下のツールを使用できる指示に忠実で優秀なAIアシスタントです。
利用可能なツール:
{tools}
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
ステップバイステップで考え、質問に回答してください。
- 考えて実行した経緯は実行履歴で与えられます。
- 実行履歴から回答を探してください。
- Observationをよく観察してください。
- 今回の思考は{n}回目です。
- 必ず前回の実行履歴と別の行動してください。
- 以下の形式でJSONを生成してください。
{schema}

思考の例:
    Thought: 考えたこと(例:ツールにより正確な値が得られたので、最終的な答えとする)
    Action: 行うこと(例:ツールで計算する)
    ActionInput: 行動に必要な情報(例:1 + 2)
    Observation: 行動の結果(例:3)

{thought}

質問: 
{question}

回答:
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

Agentクラス

Agentクラスを定義して、Agentのロジックを実装します。このAgentクラスからツール呼び出しができるようにしておきます。

class StructuredVLLMAgent:
    def __init__(self, model, tools: List[Tool] = None):
        self.vllm_model = model
        self.tools = tools or []
        self.tool_map = {tool.name: tool for tool in self.tools}        
        self.schema = ThoughtAction.model_json_schema()
        json_logits_processor = JSONLogitsProcessor(self.schema, self.vllm_model)
        self.sampling_params = SamplingParams(
            temperature=0.1, top_p=0.95, max_tokens=1024,
            repetition_penalty=1.2,
            logits_processors=[json_logits_processor]
        )
        self.prompt_template = prompt_template

    def _format_tools(self) -> str:
        """ツール情報のフォーマット"""
        return "\n".join([
            f"- {tool.name}: {tool.description}"
            for tool in self.tools
        ])

    def _execute_tool(self, tool_name: str, tool_input: str) -> str:
        """ツールの実行"""
        if tool_name not in self.tool_map:
            return f"Error: Tool '{tool_name}' not found"
        try:
            return str(self.tool_map[tool_name].func(tool_input))
        except Exception as e:
            return f"Error executing tool: {str(e)}"

    def run(self, question: str, max_steps: int = 5) -> AgentResponse:
        steps: List[AgentStep] = []
        final_answer = None
        for n in range(max_steps):
            # プロンプトの生成
            thought = ''
            if steps:
                for m, x in enumerate(steps):
                    thought += f"{m+1}回目の実行履歴:\n"
                    thought += f'\tThought:{x.step.thought}\n\tAction:{x.step.action}\n\tActionInput:{x.step.action_input}\n\tObservation:{x.observation}\n'
            current_prompt = self.prompt_template.format(
                tools=self._format_tools(),
                question=question,
                schema=self.schema,
                thought=thought,
                n=n+1,
            )
                        
            outputs = self.vllm_model.generate(
                [current_prompt],
                self.sampling_params,
            )
            response_text = outputs[0].outputs[0].text
            
            try:
                # JSON応答のパース
                step_data = json.loads(response_text)
                current_step = ThoughtAction(**step_data)
                
                # ツールの実行が必要な場合
                observation = None
                if current_step.action and current_step.action_input:
                    observation = self._execute_tool(
                        current_step.action,
                        current_step.action_input
                    )
                else:
                    observation = current_step.thought
                
                steps.append(AgentStep(
                    step=current_step,
                    observation=observation
                ))
                
                # 最終回答に達したかチェック
                if observation is not None:
                    final_answer = observation
                    break
                    
            except json.JSONDecodeError:
                print(f"Error parsing JSON response: {response_text}")
                continue
        
        if not final_answer:
            final_answer = "最大ステップ数に達しました。"
        
        return AgentResponse(steps=steps, final_answer=final_answer)

Agentの繰り返し実行回数は5回としています。
Agentの作り方やToolはLangchainの方式に一致するように実装していますが、Agentクラスは独自定義となっています。ローカルLLMで利用可能なLangchainのエコシステムを使えると便利なので、このような構成にしています。
Toolも独自定義で実装しても手間は大きく変わらないですが、拡張性を考えるとLangchainに合わせた方が良いでしょう。

Toolの定義

Toolは下記のように計算するだけのFunctionです。

def calculator(expression: str) -> str:
    try:
        return f'{expression}の計算結果は、{str(eval(expression))}です。'
    except Exception as e:
        return f"計算エラー: {str(e)}"

tools = [
    Tool(
        name="calculator",
        func=calculator,
        description="数式の計算を実行します。入力は数式の文字列です。"
    )
]

エージェントの実行

今回は4つのクエリをテストしてみました。

agent = StructuredVLLMAgent(vllm_model, tools=tools)

# テストクエリ
queries = [
    "1234 * 5678 を計算してください",
    "2024年は令和何年ですか?",
    "原価3000円の商品に20%の利益をつけて売りました。定価はいくら?",
    "128-256を計算して"
]

# テスト実行
for query in queries:
    print(f"\n質問: {query}")
    response = agent.run(query)
    print("\n回答:")
    print(response.final_answer)
    print("\n履歴:")
    print(json.dumps(response.model_dump(), indent=2, ensure_ascii=False))

実行結果

実行は以下の2つのLLMを使ってみました。

  • Llama3.1

質問: 1234 * 5678 を計算してください
回答:
1234*5678の計算結果は、7006652です。


質問: 2024年は令和何年ですか?
回答:
2024年は令和6年である

質問: 原価3000円の商品に20%の利益をつけて売りました。定価はいくら?
回答:
3000 * (1+0.2)の計算結果は、3600.0です。

質問: 128-256を計算して
回答:
128 - 256の計算結果は、-128です。
  • Llama3

質問: 1234 * 5678 を計算してください
回答:
1234 * 5678の計算結果は、7006652です。

質問: 2024年は令和何年ですか?
回答:
2024 - 2019の計算結果は、5です。

質問: 原価3000円の商品に20%の利益をつけて売りました。定価はいくら?
回答:
3000 * (100 / 80)の計算結果は、3750.0です。

質問: 128-256を計算して
回答:
128 - 256の計算結果は、-128です。

Llama3.1は全問正解しているようですが、Llama3の方は間違っているものがありますね。
チューニングデータも違っていると思うので、比較にあまり意味はないかもしれませんが、日本語で8BモデルでAgentが動くことが確認できました。


いいなと思ったら応援しよう!