![見出し画像](https://assets.st-note.com/production/uploads/images/122770089/rectangle_large_type_2_08c1df244629f117304e94277b3a4424.png?width=1200)
MacBook Proでcalm2-7b-chatを試してみる
以前、WSL2で試してみたのですが、今回はMacBook Proで試してみましょう。
試したマシンは、MacBook Pro M3 Proチップ、メモリ18GBです。
準備
python3 -m venv calm2
cd $_
source bin/activate
pip install。
pip install -U pip
pip install torch
pip install transformers
pip install accelerate
pip list
% pip list
Package Version
------------------ ----------
accelerate 0.24.1
certifi 2023.11.17
charset-normalizer 3.3.2
filelock 3.13.1
fsspec 2023.10.0
huggingface-hub 0.19.4
idna 3.4
Jinja2 3.1.2
MarkupSafe 2.1.3
mpmath 1.3.0
networkx 3.2.1
numpy 1.26.2
packaging 23.2
pip 23.3.1
psutil 5.9.6
PyYAML 6.0.1
regex 2023.10.3
requests 2.31.0
safetensors 0.4.0
setuptools 58.0.4
sympy 1.12
tokenizers 0.15.0
torch 2.1.1
tqdm 4.66.1
transformers 4.35.2
typing_extensions 4.8.0
urllib3 2.1.0
試してみよう
前回とほぼ同じ。M3 Proなので、cudaとあるコードは実行してもアレなので削除しました。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import time
llm = "cyberagent/calm2-7b-chat"
# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(llm)
model = AutoModelForCausalLM.from_pretrained(
llm,
# device_map="auto",
torch_dtype="auto"
)
# if torch.cuda.is_available():
# model = model.to("cuda")
RTX 4090と比較するとちょっと引っかかりますね。
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:24<00:00, 12.44s/it]
>>>
どんどん流し込みます。
streamer = TextStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
def build_prompt(user_query, chat_history):
prompt = "USER: " + user_query + "\n"
prompt += "ASSISTANT: "
if chat_history:
prompt = chat_history + "<|endoftext|>\n" + prompt
return prompt
def q(user_query, chat_history):
start = time.process_time()
# 推論の実行
prompt = build_prompt(user_query, chat_history)
input_ids = tokenizer.encode(
prompt,
return_tensors="pt"
)
output_ids = model.generate(
input_ids.to(device=model.device),
max_new_tokens=12000,
do_sample=True,
temperature=0.8,
streamer=streamer,
)
output = tokenizer.decode(
output_ids[0][input_ids.size(1) :],
skip_special_tokens=True
)
# print(output)
chat_history = prompt + output
end = time.process_time()
print(end - start)
return chat_history
では、チャットしましょう。
chat_history = ""
chat_history = q("小学生にでもわかる言葉で教えてください。ドラえもんとはなにか", chat_history)
ドラえもんは、日本のマンガ、アニメのキャラクターで、とても頭がいいロボットとして描かれています。彼は未来からやってきて、未来について知ることができます。彼にはどんなロボットでもポケットから出せる「四次元ポケット」があります。彼の能力はとても強く、また、彼自身もとてもかわいいです。
125.58681
約2分…。ちょっと、いや、かなりゆっくりです。
続きを聞いてみます。
chat_history = q("続きを教えてください", chat_history)
彼は、困っている人を助けたり、友達を助けたりすることが大好きです。彼は、未来についてたくさんのことを教えてくれ、私たちが未来をより良いものにするためのヒントをくれます。また、彼にはさまざまなキャラクターが登場します。例えば、「のび太くん」という男の子が彼の友達で、いつも彼と一緒に冒険します。ドラえもんは、とても人気のあるキャラクターで、世界中で愛されています。
252.80882599999998
ちょっと遅い…。
リソース
メモリの使用量は、14GB前後でした。
![](https://assets.st-note.com/img/1700995775708-eJaRLtHdAK.png?width=1200)
CPUは100%越えとなっていました。
![](https://assets.st-note.com/img/1700996489472-XtzSIydCxR.png?width=1200)