WSL2でlocal-gemmaを試してみる
27Bモデルのgemma 2がRTX 4090(24GB)でもロードできる?と噂のlocal-gemmaを試してみます。
使用するPCはドスパラさんの「GALLERIA UL9C-R49」。スペックは
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU(16GB)・GPU: NVIDIA® GeForce RTX™ 4090 (24GB)
・OS: Ubuntu22.04 on WSL2(Windows 11)
です。
1. 準備
python3 -m venv local-gemma
cd $_
source bin/activate
パッケージのインストール。
pip install local-gemma"[cuda]"
2. 流し込むコード
以下の内容を/path/to/query4local-gemma.pyとして保存します。
モデルロード時のオプションとして
--preset: メモリ最適オプション。[auto | memory_extreme | memory | exact] から一つ選択。デフォルトは auto
を指定できるようにしています。
import sys
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from local_gemma import LocalGemma2ForCausalLM
from typing import List, Dict
import time
# argv
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default=None)
parser.add_argument("--tokenizer-path", type=str, default=None)
parser.add_argument("--no-chat", action='store_true')
#parser.add_argument("--no-use-system-prompt", action='store_true')
parser.add_argument("--max-tokens", type=int, default=256)
parser.add_argument("--preset", type=str, default="auto")
args = parser.parse_args(sys.argv[1:])
model_id = args.model_path
if model_id == None:
exit
is_chat = not args.no_chat
#use_system_prompt = not args.no_use_system_prompt
use_system_prompt = False
max_new_tokens = args.max_tokens
preset = args.preset
tokenizer_id = model_id
if args.tokenizer_path:
tokenizer_id = args.tokenizer_path
# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_id,
#trust_remote_code=True
)
model = LocalGemma2ForCausalLM.from_pretrained(
model_id,
#torch_dtype="auto",
torch_dtype=torch.bfloat16,
#device_map="cuda",
device_map="auto",
low_cpu_mem_usage=True,
#trust_remote_code=True,
preset=preset
)
#if torch.cuda.is_available():
# model = model.to("cuda")
streamer = TextStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
DEFAULT_SYSTEM_PROMPT = "あなたは誠実で優秀な日本人のアシスタントです。"
def q(
user_query: str,
history: List[Dict[str, str]]=None
) -> List[Dict[str, str]]:
# generation params
generation_params = {
"do_sample": True,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 40,
"max_new_tokens": max_new_tokens,
"repetition_penalty": 1.1,
}
#
start = time.process_time()
# messages
messages = ""
if is_chat:
messages = []
if use_system_prompt:
messages = [
{"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
]
user_messages = [
{"role": "user", "content": user_query}
]
else:
user_messages = user_query
if history:
user_messages = history + user_messages
messages += user_messages
# generation prompts
if is_chat:
prompt = tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=True,
tokenize=False
)
else:
prompt = messages
input_ids = tokenizer.encode(
prompt,
add_special_tokens=True,
return_tensors="pt"
)
print("--- prompt")
print(prompt)
print("--- output")
# 推論
output_ids = model.generate(
input_ids.to(model.device),
streamer=streamer,
**generation_params
)
output = tokenizer.decode(
output_ids[0][input_ids.size(1) :],
skip_special_tokens=True
)
if is_chat:
user_messages.append(
{"role": "assistant", "content": output}
)
else:
user_messages += output
end = time.process_time()
##
input_tokens = len(input_ids[0])
output_tokens = len(output_ids[0][input_ids.size(1) :])
total_time = end - start
tps = output_tokens / total_time
print(f"prompt tokens = {input_tokens:.7g}")
print(f"output tokens = {output_tokens:.7g} ({tps:f} [tps])")
print(f" total time = {total_time:f} [s]")
return user_messages
3. 試してみる
gemma 2の27bモデルを指定して、実行します。
CUDA_VISIBLE_DEVICES=0 python -i ~/scripts/query4local-gemma.py --model-path google/gemma-2-27b-it
推論結果はgemma 2と変わらず。
>>> history = q("ドラえもんとはなにか")
--- prompt
<bos><start_of_turn>user
ドラえもんとはなにか<end_of_turn>
<start_of_turn>model
--- output
ドラえもんは、藤子・F・不二雄による日本の漫画作品です。
**概要:**
* **ジャンル:** SFコメディ漫画
* **主人公:** ドラえもん - 未来からやってきた青い猫型ロボット
* **舞台:** 現代の日本(主に東京郊外)
* **あらすじ:** のび太という少年に未来からやってきたドラえもんが、不思議な道具を使ってのび太を助ける物語。
**特徴:**
* **未来のガジェット:** ドラえもんは四次元ポケットから様々な未来の道具を取り出して、のび太や仲間たちの冒険を手助けします。
* **ユーモアと感動:** ドタバタとした笑いと、友情や家族愛など温かいテーマが描かれています。
* **人気:** 長年愛され続けている国民的漫画で、アニメ化、映画化もされています。
**代表的な道具:**
* **どこでもドア:** ドアを開くと、 anywhere に行ける!
* **タケコプター:** 背中に装着すると、空を飛べる!
* **タイムマシン:** 時間旅行ができる!
* **翻訳こんにゃく:** 話す言葉が全て翻訳できる!
**影響:**
ドラえもんは、世代を超えて
prompt tokens = 15
output tokens = 256 (13.102794 [tps])
total time = 19.537817 [s]
>>>
メモリ最適オプション別に評価した結果が以下。
memory_extreme: ロード直後は16.6GB程。推論すると17.0GB程へ。12.7トークン/秒
memory: memory_extremeと変わらず
exact: 20.0GB。オリジナルと精度合わせるため?か、いくら待てども応答がなく… ctrl + c
27bモデル、RTX 4090(24GB)でロードできました。推論速度も12.7トークン/秒とスムースです。
local-gemmaのコードと初回起動時の動きを見るに、その場でINT4に変換している模様。なるほどね。
この記事が気に入ったらサポートをしてみませんか?