見出し画像

WSL2でTanuki-8x8B-dpo-v1.0を試してみる

「モデルの開発過程や学習データなどを全てオープンにしたコミュニティで構築した大規模言語モデル」らしいTanuki-8x8Bモデルを試してみます。

vLLMによる推論が最推奨とのことなので、手持ちのGPUを2枚使って、AWQのモデルで試します。

使用するPCはドスパラさんの「GALLERIA UL9C-R49」。スペックは
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU(16GB)
・GPU: NVIDIA® GeForce RTX™ 4090 (24GB)
・OS: Ubuntu22.04 on WSL2(Windows 11)
です。


1. 準備

仮想環境

python3 -m venv tanuki-8x8b
cd $_
source bin/activate

つづいて、パッケージのインストール。
vLLMです。

この元モデルは独自アーキテクチャを持つモデルであるため、変換のためにAutoAWQライブラリの一部改変が必要となりました。本モデルの変換に必要な改変を施したAutoAWQをこちらで公開しています。

また、同様の理由で本モデルをvLLMを使って推論する際にはvLLMライブラリの一部改変が必要となります。本モデルの推論に必要な改変を施したvLLMをこちらで公開しています。
これを以下のようにソースからビルドし推論に利用してください。

https://huggingface.co/team-hatakeyama-phase2/Tanuki-8x8B-dpo-v1.0-AWQ

とのことなので、こちらのvLLMを使用します。

# vLLM
git clone https://github.com/team-hatakeyama-phase2/vllm.git
cd vllm
LD_LIBRARY_PATH="" MAX_JOBS=16 pip install -e .
cd ..

vLLMのインストール時にtorchやら依存関係必要なものは勝手にインストールされます。

2. 流し込むコード

いつもの vllm向けのコードです。/path/to/query4vllm.pyとして保存します。

import sys
import argparse
from vllm import LLM, SamplingParams
from typing import List, Dict
import time

# argv
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default=None)
parser.add_argument("--no-chat", action='store_true')
parser.add_argument("--no-use-system-prompt", action='store_true')
parser.add_argument("--max-model-len", type=int, default=32768)
parser.add_argument("--tensor-parallel-size", type=int, default=1)
parser.add_argument("--gpu-memory-utilization", type=float, default=0.2)
parser.add_argument("--max-tokens", type=int, default=4096)

args = parser.parse_args(sys.argv[1:])

model_id = args.model_path
if model_id == None:
    exit

is_chat = not args.no_chat
use_system_prompt = not args.no_use_system_prompt
max_new_tokens = args.max_tokens
tensor_parallel_size = args.tensor_parallel_size
max_model_len = args.max_model_len
gpu_memory_utilization = args.gpu_memory_utilization

# トークナイザーとモデルの準備
model = LLM(
    model=model_id,
    dtype="auto",
    trust_remote_code=True,
    tensor_parallel_size=tensor_parallel_size,
    max_model_len=max_model_len,
    #quantization="awq",
    gpu_memory_utilization=gpu_memory_utilization
)
tokenizer = model.get_tokenizer()

DEFAULT_SYSTEM_PROMPT = "あなたは誠実で優秀な日本人のアシスタントです。"

def q(
    user_query: str,
    history: List[Dict[str, str]]=None
) -> List[Dict[str, str]]:
    # generation params
    generation_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        top_k=40,
        max_tokens=max_new_tokens,
        repetition_penalty=1.1
    )
    #
    start = time.process_time()
    # messages
    messages = ""
    if is_chat:
        messages = []
        if use_system_prompt:
            messages = [
                {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
            ]
        user_messages = [
            {"role": "user", "content": user_query}
        ]
    else:
        user_messages = user_query
    if history:
        user_messages = history + user_messages
    messages += user_messages
    # generation prompts
    if is_chat:
        prompt = tokenizer.apply_chat_template(
            conversation=messages,
            add_generation_prompt=True,
            tokenize=False
        )
    else:
        prompt = messages
    input_ids = tokenizer.encode(
        prompt,
        add_special_tokens=True,
    )
    print("--- prompt")
    print(prompt)
    print("--- output")
    # 推論
    outputs = model.generate(
        sampling_params=generation_params,
        prompt_token_ids=[input_ids],
    )
    print(outputs)
    output = outputs[0]
    print(output.outputs[0].text)
    if is_chat:
        user_messages.append(
            {"role": "assistant", "content": output.outputs[0].text}
        )
    else:
        user_messages += output.outputs[0].text
    end = time.process_time()
    ##
    input_tokens = len(output.prompt_token_ids)
    output_tokens = len(output.outputs[0].token_ids)
    total_time = end - start
    tps = output_tokens / total_time
    print(f"prompt tokens = {input_tokens:.7g}")
    print(f"output tokens = {output_tokens:.7g} ({tps:f} [tps])")
    print(f"   total time = {total_time:f} [s]")
    return user_messages

3. 試してみる

AWQのファイルサイズをみると、24GBを超えていますので、tensor-parallel-sizeに2を指定してGPU2枚使用します。CUDA_VISIBLE_DEVICES変数の指定も忘れずに。

では、起動します。

CUDA_VISIBLE_DEVICES=0,1 python -i ~/scripts/query4vllm.py \
    --model-path team-hatakeyama-phase2/Tanuki-8x8B-dpo-v1.0-AWQ \
    --max-model-len 4096 \
    --tensor-parallel-size 2 \
    --gpu-memory-utilization 0.90

聞いてみましょう。

>>> history = q("ドラえもんとはなにか")
--- prompt
<s>あなたは誠実で優秀な日本人のアシスタントです。

### 指示:
ドラえもんとはなにか

### 応答:

--- output

「ドラえもん」は、藤子・F・不二雄によって創作された日本の漫画およびアニメシリーズに登場する架空の猫型ロボットです。以下に、「ドラえもん」について詳しく説明します。

1. 基本情報:
- 原作者: 藤子・F・不二雄(藤本弘)
- 初出: 1969年5月24日号の『小学一年生』から連載が開始されました。
- ジャンル: 児童向けSFコメディ、ファミリー向け

2. キャラクター紹介:
- 主人公: ドラえもん
- 外見: 猫のような形をしたロボット。青色のボディに赤い鼻、四次元ポケットを持っています。
- 特徴: 未来から来たネコ型ロボットで、22世紀の技術を使って様々な道具を提供します。
- 性格: 非常に優しく、困っている人を助けることを喜びます。また、少しおっちょこちょいな一面もあります。

3. ストーリー概要:
- 物語の中心は、小学生の野比のび太とドラえもんの友情や冒険です。
- のび太は勉強もスポーツも苦手な少年ですが、ドラえもんの助けを借りてさまざまな問題を解決していきます。
- 二人は一緒に日常生活の中で起こるトラブルや不思議な出来事に対処しながら成長していく姿が描かれます。

4. テーマとメッセージ:
- 友情の大切さ: ドラえもんとのび太の間にある深い絆や信頼関係が描かれています。
- 自己成長: のび太が困難を乗り越えることで成長する過程を通じて、努力と勇気の重要性が伝えられます。
- 夢と希望: 未来の技術や冒険を通じて、夢を追い求めることの素晴らしさが強調されています。

5. メディア展開:
- アニメ化: 1973年に初めてテレビアニメ化され、その後何度もリメイクされています。
- 映画化: 多くの劇場版アニメーション映画が制作されており、子供から大人まで幅広い層に愛されています。
- その他のメディア: ゲーム、グッズなど、多岐にわたる商品展開が行われています。

「ドラえもん」は単なるエンターテインメント作品にとどまらず、教育的な要素や社会的なメッセージも含んでおり、多くの人々に影響を与え続けています。

team-hatakeyama-phase2/Tanuki-8x8B-dpo-v1.0-AWQより
prompt tokens = 34
output tokens = 579 (29.738660 [tps])
   total time = 19.469606 [s]

これは、とてもいい!

VRAM使用量はこんな感じ。29.9GB(15.1 + (15.5 - 0.7))ほどでした。

1枚目。RTX 4090(24GB)


2枚目。RTX 4090 Laptop GPU(16GB)。0.7GBはビデオで利用

関連

8Bモデルを試してみるは、こちら。

この記事が気に入ったらサポートをしてみませんか?