![見出し画像](https://assets.st-note.com/production/uploads/images/112375111/rectangle_large_type_2_488b3e3a17a62d22b31514648a9e45e8.png?width=1200)
無料版Colabでrinna/bilingual-gpt-neox-4b-instruction-ppoを動かす
サンプルコードのままだとメモリ(RAM)が足りなくて無料版のGoogle Colabだとクラッシュして動かなかったので、8bitで読み込んで動かしました。
ランタイムのタイプはGPUを選びます。
![](https://assets.st-note.com/img/1690990614680-j0vssmrT4C.png?width=1200)
pipでbitsandbytesとaccelerateを読み込んでおきます。
!pip install transformers sentencepiece bitsandbytes accelerate
プロンプトの形式は同じにしておきます。
prompt = [
{
"speaker": "ユーザー",
"text": "Hello, you are an assistant that helps me learn Japanese."
},
{
"speaker": "システム",
"text": "Sure, what can I do for you?"
},
{
"speaker": "ユーザー",
"text": "VTuberの魅力について教えてください。"
}
]
prompt = [
f"{uttr['speaker']}: {uttr['text']}"
for uttr in prompt
]
prompt = "\n".join(prompt)
prompt = (
prompt
+ "\n"
+ "システム: "
)
print(prompt)
モデルを読み込むところで、load_in_8bit=Trueを付けます。
bitandbytesを使っているとmode.to("cuda")は使えないよ! とエラーが出るので、そこは削除しておきます。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("rinna/bilingual-gpt-neox-4b-instruction-ppo", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/bilingual-gpt-neox-4b-instruction-ppo", loa d_in_8bit=True)
# if torch.cuda.is_available():
# model = model.to("cuda")
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
output_ids = model.generate(
token_ids.to(model.device),
max_new_tokens=512,
do_sample=True,
temperature=1.0,
top_p=0.85,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
)
output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
print(output)
8bitの回答
VTuberは、主に若いバーチャルな存在としてバーチャルリアリティを利用して、ゲームや音楽、トークなどの活動を行う日本のサブカルチャーです。</s>
ちなみに、load_in_8bitの部分をload_in_4bit=Trueにするとより省VRAMで動きますが、なんだか気持ち、回答精度がポンコツになった印象を受けます(ちゃんとベンチマークとってみたさあります)。
4bitの回答
VTuberは、バーチャルキャラクターによるライブストリーミングプラットフォームです。 VTuberは、自分自身の個人的な経験をソーシャルメディアに投稿して、他のVTuberファンと共有することができます。</s>
補足
なお、直接この内容とは関係ないのですが、rinnaの3.6bや4bのモデルを読み込むときは、tokenizerのところでuse_fast=Falseを付けないとエラーが出ることがあります。
tokenizer = AutoTokenizer.from_pretrained("rinna/bilingual-gpt-neox-4b-instruction-ppo", use_fast=False)
他のモデルだと動くのにrinnaだと動かない、みたいなコードがあるときは、ここを疑うと良いです。