LLMの応答形式を制御する
LLMの応答形式を制御して、確実に決まったJSON形式で応答できるようにする方法を紹介します。
ローカルLLMではプロンプトに回答形式を記述しても、その形式で返答されない場合が多く、そのままではLLMをシステムに組み込むことが難しい。
そこで、outlinesのようなLLMの生成処理をガイドするようなライブラリを用いて制御したいと思います。
今回は、ReActっぽいプロンプトの出力をoutlinesでJSON形式に制御する方法を試してみます。使用するモデルはElyza Llama3で、モデルはvLLMで推論を行います。また、JSON形式はPydanticで定義することにします。
ReActとElyza Llama3については以下を参照。
モデルロードと回答形式の定義
以下のコードでモデルをロードして、回答形式を定義しています。
回答形式は、thought, action, action_inputの3つのキーを持つものとします。
from IPython.display import display, Markdown
from pydantic import BaseModel
from vllm import LLM, SamplingParams
from outlines.serve.vllm import JSONLogitsProcessor
from pydantic import BaseModel, Field
from typing import List
import pprint
import json
model_id = 'elyza/Llama-3-ELYZA-JP-8B'
llm = LLM(
model=model_id,
dtype='half',
quantization="bitsandbytes",
load_format="bitsandbytes",
gpu_memory_utilization=0.8
)
class ReAct(BaseModel):
thought: str
action: str
action_input: str
LLMの出力の制御
outlinesでは、LLMが出力した結果を回答形式に合うようにガイドしながら生成を行う制御を行います。vLLMでは、そのガイド処理がlogit_processorsとして受け取れるようになったいます。
schema = ReAct.model_json_schema()
json_logits_processor = JSONLogitsProcessor(schema, llm)
def generate(prompts):
sampling_params = SamplingParams(
temperature=0.3, top_p=0.95, max_tokens=1024,
logits_processors=[json_logits_processor]
)
ret = llm.generate(prompts, sampling_params=sampling_params)
return [r.outputs[0].text for r in ret]
def display_header(text):
display(Markdown(f'**{text}**\n'))
def display_content(text):
display(Markdown(f'```\n{text}\n```'))
上記コードでは、定義した回答形式をJSONLogitsProcessorとしてガイド処理を構築します。その後、テキスト生成時にそのガイド処理を使うようにlogit_processorsとしてLLMに渡します。
プロンプトの定義
「日本の缶コーヒーの市場規模はいくらかを推定する」方法をLLMに検討させることにします。また、プロンプトのテンプレートはLlama3の形式に従って構築します。
DEFAULT_SYSTEM_PROMPT = """
あなたは親切で、礼儀正しく、正直なアシスタントです。
すべて日本語で返答してください。
"""
def get_prompt(message: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
return f'''<|start_header_id|>system<|end_header_id|>
{system_prompt}
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
{message}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>'''
question = '''日本の缶コーヒーの市場規模はいくらかを推定するために下記の形式で考えてください。
Thought: 次に何をすべきか常に考える
Action: 次に何をすべきか決定する
Action Input: 行動の入力
以下のjson スキーマを使用して回答する必要があります。
'''
question_with_schema = f'{question}{ReAct.schema_json()}'
prompt = get_prompt(question_with_schema)
display_header("Prompt:")
display_content(prompt)
構築したプロンプト
Prompt:
<|start_header_id|>system<|end_header_id|>
あなたは親切で、礼儀正しく、正直なアシスタントです。
すべて日本語で返答してください。
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
日本の缶コーヒーの市場規模はいくらかを推定するために下記の形式で考えてください。
Thought: 次に何をすべきか常に考える
Action: 次に何をすべきか決定する
Action Input: 行動の入力
以下のjson スキーマを使用して回答する必要があります。
{"properties": {"thought": {"title": "Thought", "type": "string"}, "action": {"title": "Action", "type": "string"}, "action_input": {"title": "Action Input", "type": "string"}}, "required": ["thought", "action", "action_input"], "title": "ReAct", "type": "object"}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
生成の実行
ここまで構築したプロンプトを使って生成処理を実行します。
display_header("Answer, With json schema enforcing:")
result = generate(prompt)
print()
pprint.pprint(json.loads(result[0]))
生成結果
Answer, With json schema enforcing:
Processed prompts: 100%|██████████| 1/1 [00:09<00:00, 9.94s/it, est. speed input: 20.23 toks/s, output: 12.38 toks/s]
{'action': '市場規模の定義を明確化し、調査の方向性を決める',
'action_input': '日本の缶コーヒー市場規模の推定に必要なデータを集めるため、缶コーヒー関連の統計データや市場調査レポートを探す',
'thought': '日本の缶コーヒー市場規模を推定するためには、まず市場規模の定義を明確化し、調査の方向性を決める必要がある。'}
まとめ
outlinesでJSON形式の回答に制御することができることを確認しました。outlinesはvLLMだけでなくtransformersでも適用することができますが、今回はプロダクトとしてローカルLLMを提供する際によく選択されるvLLMを使った方法を紹介しました。