vLLMのOpenAI APIインターフェースサーバーでバッチ推論をさせる
はじめに
vLLMはLLMを高速で推論させるためのプログラムで、非常に便利です。
特に、vLLMはバッチ推論時に真価を発揮し、非常に処理が早いです。
vLLMは標準でOpenAIのAPI互換のサーバー機能もついており、便利なのですが、こちらは1件ずつしかクエリを処理してくれない(?)という課題があります。
大量のクエリを投げて処理したいこともよくあるので、本記事では、バッチ推論用のコードを作りました。
コード
基本的にo1 Proに作ってもらいました。
apply_chat_templateの処理がイマイチだったので、自分で直しました。
ちなみに、本日公開されたo3-mini, o3-mini-highにも頼んでみましたが、vLLM周りのコードをサボってきたので、駄目でした。
サーバー
溜まったクエリをバッチで処理する仕様です。
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
import threading
import time
from queue import Queue
import asyncio
import uuid
from transformers import AutoTokenizer
# vLLM のインポート
from vllm import LLM, SamplingParams
# -----------------------------
# 1. FastAPI アプリの準備
# -----------------------------
app = FastAPI()
# リクエストの待ち行列
request_queue = Queue()
# 推論結果を格納する辞書
# { request_id: {"role": "assistant", "content": "..."} }
results = {}
# vLLM のモデルをロード (例)
model_name = "cyberagent/calm3-22b-chat"
llm = LLM(model=model_name) # 実際のモデルに合わせて指定
tokenizer = AutoTokenizer.from_pretrained(model_name)
# -----------------------------
# 2. Pydantic モデル定義
# (OpenAI ChatCompletion API 風)
# -----------------------------
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
max_tokens: Optional[int] = 128
temperature: Optional[float] = 0.7
top_p: Optional[float] = 0.9
# -----------------------------
# 3. メッセージを一つのプロンプトにまとめるヘルパー
# -----------------------------
def messages_to_prompt(messages: List[ChatMessage]) -> str:
"""
簡易的に system / user / assistant の role を文字列にまとめる。
必要に応じてフォーマットを変更してください。
"""
prompt_str = tokenizer.apply_chat_template(messages, tokenize=False,
add_generation_prompt=True)
return prompt_str
# -----------------------------
# 4. バッチ推論を行うスレッド
# -----------------------------
def batch_inference_worker():
"""一定間隔で Queue の中身をバッチ処理し、vLLM にまとめて推論させる。"""
while True:
requests_to_process = []
# キューに溜まっている全リクエストを一旦取り出す
while not request_queue.empty():
data = request_queue.get()
requests_to_process.append(data)
if requests_to_process:
# バッチで処理する「プロンプト」をまとめる
prompts = [r["prompt"] for r in requests_to_process]
# ここでは「最初のリクエストのパラメータ」を代表として使う例。
# 本来はリクエストごとにパラメータが違うかもしれないので要設計
first_req = requests_to_process[0]
sampling_params = SamplingParams(
max_tokens=first_req["max_tokens"],
temperature=first_req["temperature"],
top_p=first_req["top_p"]
)
# vLLM にバッチ推論
outputs = llm.generate(prompts, sampling_params)
# 結果を results に格納
for req_data, output in zip(requests_to_process, outputs):
generated_text = output.outputs[0].text # 1件目のみ取り出す例
results[req_data["request_id"]] = {
"role": "assistant",
"content": generated_text
}
# 次のバッチ処理まで待機 (例: 0.2 秒)
time.sleep(0.2)
# バックグラウンドワーカー起動
threading.Thread(target=batch_inference_worker, daemon=True).start()
# -----------------------------
# 5. エンドポイント: /v1/chat/completions
# OpenAI ChatCompletion 互換 (簡易版)
# -----------------------------
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
# リクエストID を発行
request_id = str(uuid.uuid4())
# messages を一つのプロンプトにまとめる
prompt = messages_to_prompt(request.messages)
# キューに格納するデータを作成
request_data = {
"request_id": request_id,
"prompt": prompt,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"top_p": request.top_p,
}
request_queue.put(request_data)
# 結果が生成されるまで待つ (タイムアウト等は本来考慮必須)
while request_id not in results:
await asyncio.sleep(0.05)
# 生成されたアシスタントメッセージを取り出す
assistant_msg = results.pop(request_id)
# OpenAI 互換風のレスポンス
return {
"id": request_id,
"object": "chat.completion",
"created": int(time.time()),
"model": request.model,
"choices": [
{
"index": 0,
"message": assistant_msg,
"finish_reason": "stop",
}
]
}
# -----------------------------
# 6. アプリ起動 (開発時)
# -----------------------------
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
client
from openai import OpenAI
api_key = "dummy_key"
client = OpenAI(
base_url="http://localhost:8000/v1/",
api_key=api_key,
)
import concurrent.futures
def get_completion(i):
completion = client.chat.completions.create(
model="cyberagent/calm3-22b-chat",
messages=[
{"role": "system", "content": "あなたはアシスタントです"},
{"role": "user", "content": f"{i}+{i}は?"},
],
)
return completion.choices[0].message.content
with concurrent.futures.ThreadPoolExecutor() as executor:
results = list(executor.map(get_completion, range(10)))
for result in results:
print(result)
結果
数学的な文脈では、0+0は0です。これは基本的な算術のルールに従った結果です。ただし、文脈によっては異なる解釈や結果が生じることもあります。
例えば、物理学や統計学などの分野では、0+0が特定の状況や近似によって異なる値を取る場合もありますが、標準的な数学的ルールでは0です。
もちろん、1 + 1 は 2 です。
答えは4です。
答えは6です。
答えは、8です。
答えは10です。
答えは12です。
答えは14です。
答えは16です。
答えは18です。

まとめ
vllmをOpenAI APIのインターフェースのサーバーで動かす際に、バッチ推論したかったので、コードをchatGPTに作ってもらいました。
このタスクでのコード生成能力は、o1 pro >> o3-mini系 でした。
生成されたコードを少し修正して、実装しました。
細かな動作検証はしていませんが、それっぽく動いてくれました。