WSL2でjapanese-stablelm-instruct-beta-70BのGPTQを試してみる
Stability AI Japanからリリースされた日本語LLM「Japanese Stable LM Beta」のjapanese-stablelm-instruct-beta-70Bを試してみます。
70Bはメモリに乗りきらないのでは…と思いつつ、やってみなきゃわからない!ということでやってみました。
使用するPCは、GALLERIA UL9C-R49(RTX 4090 laptop 16GB)、Windows 11+WSL2、PCメモリは64GBです。
準備
GPTQを使用するので、auto-gptqとoptinumが追加となっています。
python3 -m venv stablelm
cd $_
source bin/activate
pip install torch transformers
pip install sentencepiece protobuf accelerate
pip install auto-gptq optimum
pythonを実行して…
python
>>>
ちょこちょこっと修正したcodeを流し込みます。
反応が返ってくるまでの所要時間を見たかったのでtimeパッケージをimportして、問合せの関数の最初と最後に時刻取得の処理を追加しています。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
llm = "TheBloke/japanese-stablelm-instruct-beta-70B-GPTQ"
# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(
llm,
use_fast=True
)
model = AutoModelForCausalLM.from_pretrained(
llm,
device_map="auto",
trust_remote_code=False,
revision="main"
)
if torch.cuda.is_available():
model = model.to("cuda")
さて、メモリにロードできるから、とタスクマネージャを見ていたところ、右肩上がりにGPUのメモリが使用されていく。専用メモリ16GBに直ぐに使い果たしたので「こりゃ、落ちてしまう」と思ったら、共有メモリもガンガン使用していき、ちょうど20.0GBで停止して、合計35.7GBほどでプロンプトが返ってきました。
これならば、応答は遅いだろうけれども動いてしまうのでは? と期待をしつつ、続きを流し込みます。
def build_prompt(user_query):
sys_msg = "<s>[INST] <<SYS>>\nあなたは役立つアシスタントです。<<SYS>>\n\n"
prompt = sys_msg + user_query + " [/INST] "
return prompt
# this is for reproducibility.
# feel free to change to get different result
seed = 23
torch.manual_seed(seed)
def q(user_query):
start = time.process_time()
# 推論の実行
prompt = build_prompt(user_query)
input_ids = tokenizer.encode(
prompt,
add_special_tokens=False,
return_tensors="pt"
)
output_ids = model.generate(
input_ids.to(device=model.device),
max_new_tokens=128,
temperature=0.99,
top_p=0.95,
do_sample=True,
)
output = tokenizer.decode(
output_ids[0][input_ids.size(1) :],
skip_special_tokens=True
)
print(output)
end = time.process_time()
print(end-start)
return output
聞いてみる
では、さっそく聞いてみましょう。
>>> output = q("小学生にでもわかる言葉で教えてください。 ドラえもんとはなにか")
ドラえもんは、「ドラえもん」「のび太」「しずか」「ジャイアン」「スネ夫」といった人物が登場する漫画のタイトルです。
617.6010456
>>>
617秒かかって回答きました。文字数(トークン数ではない)でみると、57文字だから1文字あたり10.8秒。お、おう・・・
再チャレンジ。
>>> output = q("小学生にでもわかる言葉で教えてください。 ドラえもんとはなにか")
ドラえもんは、藤子不二雄による大人気マンガ、そしてアニメの主人公です。ドラえもんは擬人化されたネコで、未来の世界からタイムマシンでやってきて、ドラミチャン、ドラ美ちゃん、小池さん、おののののののののの猫の猫の猫の猫の猫の相棒に
896.6318621000001
>>>
896秒…。
うーん、触れてはいけない秘密を話そうとしてしまって、何かが起きた?ような回答に読める。そういうことか!(ちがう
40GBか80GBあるGPUであれは、普通に動くのかしら。