見出し画像

ローカルマルチモーダルを簡単に使えるAPIを公開。LLaVA-Next(旧1.6)でAPIサーバを構築


始めに

OpenAIやGoogleのAIサービスはマルチモーダル対応が当たり前のようにできます。ローカルLLMでもいくつかマルチモーダルに対応したモデルがありました。めぐチャンネルでも過去にMinigpt-4によるAPIサーバの構築を試しましたが、実用的に使えるかと言うと若干の疑問があったのも事実です。今回は1月末に公開されたLLaVA-NEXT(旧-1.6)で実用に耐えるローカルマルチモーダルLLMのAPIを独自に開発しました。LLaVAは以前から高い評価を得ていましたし、日本語も使えて便利ですが非商用でした。v1.6(NEXTに名称変更)からは商用利用も可能になりました。APIサーバはOllamaやllama.cppもLLaVA対応ができるとされているので、あえて独自APIを開発する必要は無いのですが、ドキュメントを見ながらAPIを解析してテストをする手間と独自に開発する手間では大して差は無いですし、今回は時間もなかったことから慣れている手法で画像アップロードとチャットエンドポイントを実装してAPI化しました。

GitHubのデモは面倒

いくつもサーバ機能を立ち上げながら動かします。注意深く行えばそれほど困難ではありません。以下の記事にデモの立ち上げ方を記載しています。

簡単に使えるAPIサーバがほしい

前述のように、時間も無いということで、画像アップロードとチャット機能だけに縛った簡単なAPIサーバを実装しています。LLaVAのオリジナルコードにはChatの過去ログ機能もあるので有効に活用します。

LLaVA-NEXTの導入

GiyHubからクローンします。


git clone https://github.com/haotian-liu/LLaVA.git
cd LLaVA

環境に合わせて構築

Install Packageに従えば簡単に環境は構築できるはずです。トレーニングはしないのでadditional packagesは不要です。

conda create -n llava python=3.10 -y
conda activate llava
pip install --upgrade pip  # enable PEP 660 support
pip install -e .

モデルのダウンロード

最新のLLaVA-NEXTは以下のモデルが準備されています。

liuhaotian/llava-v1.6-vicuna-7b
liuhaotian/llava-v1.6-vicuna-13b
liuhaotian/llava-v1.6-mistral-7b
liuhaotian/llava-v1.6-34b
v1.5のモデルも動きます。
liuhaotian/llava-v1.5-7b
liuhaotian/llava-v1.5-13b
liuhaotian/llava-v1.5-7b-lora
liuhaotian/llava-v1.5-13b-lora

LLaVa-1.5の性能

cliのテスト

デモを動かすのは大変ですが、cli版は簡単です。

python -m llava.serve.cli \
    --model-path liuhaotian/llava-v1.5-7b \
    --image-file "https://llava-vl.github.io/static/images/view.jpg" \
    --load-4bit

このコードは4bit量子化もオプションでつけているのでGPUは小さくても動きます。7Bで8GByte以下です。

cli.pyを改造してAPIサーバ化

cli.pyはシンプルな構造をしています。主要な部分は以下の通り
・コマンドラインの引数を処理する
・モデルをロードする(タイプごとにやり方が違います)
・イメージをアップロードする
・推論(チャット)を行う
cli.pyはクローンしたリポジトリの
LLaVA/llava/serveディレクトリにあります。

cli.pyを改造

なるべく簡単なコードにしたかったので以下を省きました。
・コマンドラインの引数処理
・モデルを固定、ハードコード
cli.pyをコピーしてファイル名をapi_server.pyに変更しFastAPIでラッピングしています。

エンドポンとは2種類のみ

画像アップロード
@app.post("/api/upload_file")

推論(チャット)
@app.post("/api/chatx")

コード

user_dicで複数のクライアントの要求にも答えられるよう、画像とチャット履歴を管理しています。一度登録すると消去の機能は無いので異なるuser_idで何度も使ったあとは再起動でクリアしてください。
モデルはコード内でモデル名を指定しています。変更する場合はコードを修正してください。

import argparse
import torch
import pprint

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

from PIL import Image

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer


from fastapi import FastAPI,Form,File, UploadFile 
from fastapi.responses import HTMLResponse,JSONResponse
from pydantic import BaseModel

app = FastAPI()

user_name="test"

user_dic={user_name:{
            "cov":"",
           "image":"",
           }}


model_path="liuhaotian/llava-v1.5-7b"
#model_path= "SakanaAI/EvoVLM-JP-v1-7B"
#model_path="liuhaotian/llava-v1.6-mistral-7b"

model_base=None
load_8bit=False
load_4bit=True
device="cuda"

    
disable_torch_init()

model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=True, device=device)

conv_mode = "chatml_direct" #v1.6-34b
if "llama-2" in model_name.lower():
    conv_mode = "llava_llama_2"
elif "mistral" in model_name.lower():
    conv_mode = "mistral_instruct"
elif "v1" in model_name.lower():
    conv_mode = "llava_v1"
elif "v1.6-34b" in model_name.lower():
    conv_mode = "chatml_direct"
else:
    conv_mode = "llava_v0"
    
#conv_mode = "llava_v1" #v1
image_file="pose2.png"
temperature=0.2
max_new_tokens=512

conv = conv_templates[conv_mode].copy()
roles = conv.roles

def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    user_dic[user_name]["image"]=image
    return image

async def generate_response(inp: str,conv,image,image_size,image_tensor):
    print("user_dic1=",user_dic)

    if image is not None:
            # first message
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            conv.append_message(conv.roles[0], inp)
            image = None
    else:
            # later messages
            conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    user_dic[user_name]["conv"]=conv
    
    prompt = conv.get_prompt()
    print("prompt =",prompt)
    # トークナイザーでpromptをトークン化
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    # モデルで応答を生成
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,  # 画像処理の結果
            image_sizes=image_size,  # 画像サイズ
            do_sample=True if temperature > 0 else False,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            use_cache=True
        )

    # 生成されたトークンをデコードしてテキストに変換
    outputs = tokenizer.decode(output_ids[0]).strip()
    # 応答をconvオブジェクトに追加(任意)
    conv.messages[-1][-1] = outputs
    user_dic[user_name]["conv"]=conv
    print("conv.messages=",conv.messages)

    return outputs

@app.post("/api/upload_file")
async def upload_file(image: UploadFile = File(...), user_name: str = Form(...)):
    image_data =image.file.read()
    image = Image.open(BytesIO(image_data))  # バイナリデータをPIL形式に変換
    image = image.convert("RGB")
    user_dic[user_name]["image"]=image
    image.show()
    return JSONResponse(content={'message': "OK"})

@app.post("/api/chatx")
async def chatx(prompt:str = Form(...), mode: str=  Form(...), user_name: str = Form(...)):
    # 画像を読み込む処理(適宜実装)
    user_input=prompt

    image=user_dic[user_name]["image"]
 
    image_size = image.size
    image_tensor = process_images([image], image_processor, model.config)
    if type(image_tensor) is list:
        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)
    if mode=="new":
        conv = conv_templates[conv_mode].copy()
        user_dic[user_name]["conv"]=conv
    else:
        conv =user_dic[user_name]["conv"]
        image=None
    # 応答を生成 2回目はimage=Noneにすればいいはず
 
    assistant_response = await generate_response(user_input,conv,image,image_size,image_tensor)
    
    # 応答をJSON形式で返す
    return {"assistant_response": assistant_response}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8011)

サーバの起動

簡略化したので簡単です。引数はありません。LLaVAディレクトリで以下のコマンドで起動できます。

python -m llava.serve.api_server

クライアント側

テスト用に簡単なテスト用クライアントアプリも作成しました。

コード

LLaVAディレクトリにあるpose2.pngを使用する例です。
前半で画像をアップロード
後半で質問を投げかけて答えを得ています。
"mode":"new"で過去ログがクリアされます。

import requests
from io import BytesIO


print("+++++i2i  TEST")
image_file_path="pose2.png"
file_data = open(image_file_path, "rb").read()

files={"image": ("img.png", BytesIO(file_data), "image/png"),}
data= {"user_name":"test",}
# POSTリクエストを送信
url = 'http://0.0.0.0:8011/api/upload_file'
response = requests.post(url, data=data ,files=files)
# レスポンスを表示 
if response.status_code == 200:
    result = response.json()
    print("サーバーからの応答message:", result.get("message"))
          

else:
    print("リクエストが失敗しました。ステータスコード:", response.status_code)

url = "http://0.0.0.0:8011/api/chatx"

prompt="植木鉢の数はいくつ?"
data= {"prompt":prompt,
       "mode":"new",
       "user_name":"test",}
response = requests.post(url, data=data)
if response.status_code == 200:
    result = response.json()
    print("assistant_response:", result.get("assistant_response"))

prompt="空は何色?"
data= {"prompt":prompt,
       "mode":"continue",
       "user_name":"test",}
response = requests.post(url, data=data)
if response.status_code == 200:
    result = response.json()
    print("assistant_response:", result.get("assistant_response"))

pose2.png

まとめ

日本語で普通に会話できるマルチモーダルLLMは用途も広くて便利です。応答速度も申し分なく、精度も上々です。ローカルLLMでマルチモーダルがも利用ができる実用段階になったな、と感じます。