見出し画像

WSL2でTanuki-8B-dpo-v1.0をvLLMと共に試してみる

「モデルの開発過程や学習データなどを全てオープンにしたコミュニティで構築した大規模言語モデル」らしいTanuki-8Bモデルを試してみます。

使用するPCはドスパラさんの「GALLERIA UL9C-R49」。スペックは
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU(16GB)
・GPU: NVIDIA® GeForce RTX™ 4090 (24GB)
・OS: Ubuntu22.04 on WSL2(Windows 11)
です。


1. 準備

仮想環境

python3 -m venv tanuki-vllm
cd $_
source bin/activate

つづいて、パッケージのインストール。

pip install -U vllm

vLLMのインストール時にtorchやら依存関係必要なものは勝手にインストールされます。

2. 流し込むコード

こちらを/path/to/query4vllm.pyとして保存します。

import sys
import argparse
from vllm import LLM, SamplingParams
from typing import List, Dict
import time

# argv
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default=None)
parser.add_argument("--no-chat", action='store_true')
parser.add_argument("--no-use-system-prompt", action='store_true')
parser.add_argument("--max-model-len", type=int, default=32768)
parser.add_argument("--tensor-parallel-size", type=int, default=1)
parser.add_argument("--gpu-memory-utilization", type=float, default=0.2)
parser.add_argument("--max-tokens", type=int, default=4096)

args = parser.parse_args(sys.argv[1:])

model_id = args.model_path
if model_id == None:
    exit

is_chat = not args.no_chat
use_system_prompt = not args.no_use_system_prompt
max_new_tokens = args.max_tokens
tensor_parallel_size = args.tensor_parallel_size
max_model_len = args.max_model_len
gpu_memory_utilization = args.gpu_memory_utilization

# トークナイザーとモデルの準備
model = LLM(
    model=model_id,
    dtype="auto",
    trust_remote_code=True,
    tensor_parallel_size=tensor_parallel_size,
    max_model_len=max_model_len,
    #quantization="awq",
    gpu_memory_utilization=gpu_memory_utilization
)
tokenizer = model.get_tokenizer()

DEFAULT_SYSTEM_PROMPT = "あなたは誠実で優秀な日本人のアシスタントです。"

def q(
    user_query: str,
    history: List[Dict[str, str]]=None
) -> List[Dict[str, str]]:
    # generation params
    generation_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        top_k=40,
        max_tokens=max_new_tokens,
        repetition_penalty=1.1
    )
    #
    start = time.process_time()
    # messages
    messages = ""
    if is_chat:
        messages = []
        if use_system_prompt:
            messages = [
                {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
            ]
        user_messages = [
            {"role": "user", "content": user_query}
        ]
    else:
        user_messages = user_query
    if history:
        user_messages = history + user_messages
    messages += user_messages
    # generation prompts
    if is_chat:
        prompt = tokenizer.apply_chat_template(
            conversation=messages,
            add_generation_prompt=True,
            tokenize=False
        )
    else:
        prompt = messages
    input_ids = tokenizer.encode(
        prompt,
        add_special_tokens=True,
    )
    print("--- prompt")
    print(prompt)
    print("--- output")
    # 推論
    outputs = model.generate(
        sampling_params=generation_params,
        prompt_token_ids=[input_ids],
    )
    print(outputs)
    output = outputs[0]
    print(output.outputs[0].text)
    if is_chat:
        user_messages.append(
            {"role": "assistant", "content": output.outputs[0].text}
        )
    else:
        user_messages += output.outputs[0].text
    end = time.process_time()
    ##
    input_tokens = len(output.prompt_token_ids)
    output_tokens = len(output.outputs[0].token_ids)
    total_time = end - start
    tps = output_tokens / total_time
    print(f"prompt tokens = {input_tokens:.7g}")
    print(f"output tokens = {output_tokens:.7g} ({tps:f} [tps])")
    print(f"   total time = {total_time:f} [s]")
    return user_messages

3. 試してみる

config.json のパラメータ値を確認しつつ、起動します。

CUDA_VISIBLE_DEVICES=0 python -i ~/scripts/query4vllm.py \
    --model-path weblab-GENIAC/Tanuki-8B-dpo-v1.0 \
    --max-model-len 8192 \
    --gpu-memory-utilization 0.99

聞いてみましょう。

>>> history = q("ドラえもんとはなにか")
--- prompt
<s>あなたは誠実で優秀な日本人のアシスタントです。

### 指示:
ドラえもんとはなにか

### 応答:

--- output

「ドラえもん」は、藤子・F・不二雄によって創作された日本の漫画およびアニメシリーズに登場する架空のキャラクターです。正式名称は「野比のび太(のび太)」という少年ですが、彼の友人であるロボット猫「ドラえもん」が物語の中心人物として描かれています。

ドラえもんは、未来からやってきた高性能なロボットで、主人公であるのび太を助けるためにさまざまな道具を持っています。これらの道具は「ひみつ道具」と呼ばれ、日常生活や問題解決に役立つものから、夢のようなことまで多岐にわたります。例えば、タイムマシンやどこでもドア、四次元ポケットなどがあります。

ドラえもんのキャラクターは、友情や助け合い、成長といったテーマを中心に描かれており、多くの読者に愛されています。また、そのユーモラスなエピソードや感動的なストーリーは、子供から大人まで幅広い層に支持されています。テレビアニメや映画、ゲームなど、さまざまなメディアミックス展開も行われており、日本国内外で非常に人気があります。

総じて、「ドラえもん」は単なるキャラクター以上の存在であり、友情、冒険、学びといった普遍的な価値を伝える象徴的な作品です。

prompt tokens = 34
output tokens = 283 (52.891193 [tps])
   total time = 5.350607 [s]

いままで見てきたなかで一番まともです。中華系によってない。これは質問しがいがあります。

>>> history = q("ドラえもんの好きな食べ物と、その理由を教えてください。")

ドラえもんが好きな食べ物としてよく挙げられるのは「どら焼き」です。どら焼きはふんわりとした生地に甘いあんこが挟んだもので、ドラえもんののんびりとした性格や優しさにぴったりな温かみのあるスイーツです。また、どら焼きは手軽に食べられるため、忙しい日常の中でも気軽に楽しむことができる点も魅力の一つです。さらに、ドラえもんの仲間たち(特に野比のび太)がよくどら焼きを食べているシーンもあり、それが物語の中で親しみやすさを増しています。このように、どら焼きはドラえもんのキャラクター性やストーリーに深く結びついているため、多くのファンから愛されています。

すばらしい。

関連

この記事が気に入ったらサポートをしてみませんか?