
WSL2でChatRWKV (RWKV-5-World-1B5-v2-20231025-ctx4096)を試してみる
RWKV-5-World-1B5-v2-20231025-ctx4096 のモデルを試してみます。
使用するPCは、GALLERIA UL9C-R49(RTX 4090 laptop 16GB)、Windows 11+WSL2です。
事前準備
python3 -m venv rwkv5
cd $_
source bin/activate
続いて、パッケージインストール。
pip install torch pynvml rwkv Ninja gradio
pip listはこんな感じです。
$ pip list
Package Version
------------------------- ------------
aiofiles 23.2.1
aiohttp 3.8.6
aiosignal 1.3.1
altair 5.1.2
annotated-types 0.6.0
anyio 3.7.1
async-timeout 4.0.3
attrs 23.1.0
certifi 2023.7.22
charset-normalizer 3.3.2
click 8.1.7
colorama 0.4.6
contourpy 1.2.0
cycler 0.12.1
exceptiongroup 1.1.3
fastapi 0.104.1
ffmpy 0.3.1
filelock 3.13.1
fonttools 4.44.1
frozenlist 1.4.0
fsspec 2023.10.0
gradio 4.3.0
gradio_client 0.7.0
h11 0.14.0
httpcore 1.0.2
httpx 0.25.1
huggingface-hub 0.17.3
idna 3.4
importlib-resources 6.1.1
Jinja2 3.1.2
jsonschema 4.19.2
jsonschema-specifications 2023.7.1
kiwisolver 1.4.5
linkify-it-py 2.0.2
markdown-it-py 2.2.0
MarkupSafe 2.1.3
matplotlib 3.8.1
mdit-py-plugins 0.3.3
mdurl 0.1.2
mpmath 1.3.0
multidict 6.0.4
networkx 3.2.1
ninja 1.11.1.1
numpy 1.26.2
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.3.52
nvidia-nvtx-cu12 12.1.105
orjson 3.9.10
packaging 23.2
pandas 2.1.3
Pillow 10.1.0
pip 22.0.2
pydantic 2.5.0
pydantic_core 2.14.1
pydub 0.25.1
Pygments 2.16.1
pynvml 11.5.0
pyparsing 3.1.1
python-dateutil 2.8.2
python-multipart 0.0.6
pytz 2023.3.post1
PyYAML 6.0.1
referencing 0.30.2
requests 2.31.0
rich 13.6.0
rpds-py 0.12.0
rwkv 0.8.20
semantic-version 2.10.0
setuptools 59.6.0
shellingham 1.5.4
six 1.16.0
sniffio 1.3.0
starlette 0.27.0
sympy 1.12
tokenizers 0.14.1
tomlkit 0.12.0
toolz 0.12.0
torch 2.1.0
tqdm 4.66.1
triton 2.1.0
typer 0.9.0
typing_extensions 4.8.0
tzdata 2023.3
uc-micro-py 1.0.2
urllib3 2.1.0
uvicorn 0.24.0.post1
websockets 11.0.3
yarl 1.9.2
requirements.txtにgradio==3.28.1 とあったのですが、上手く動かなかったので、最新のgradio (4.3.0) にして、以下で説明するapp.pyを1行修正しています。
ソースの修正と実行 - app.py
こちらのapp.pyをベースに少し修正します。
$ diff -u app.py.orig app.py
--- app.py.orig 2023-11-15 12:25:30.805741073 +0900
+++ app.py 2023-11-15 10:06:31.649338778 +0900
@@ -123,6 +123,6 @@
clear.click(lambda: None, [], [output])
data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
-demo.queue(concurrency_count=1, max_size=10)
+#demo.queue(concurrency_count=1, max_size=10)
+demo.queue()
demo.launch(share=False)
では、実行しましょう。
python app.py
https://127.0.0.1:7860/ にブラウザでアクセスして試します。

できました。
CLIでもやってみよう
GUIがアレなわたしとしては、プログラミングしやすいようにしておきたいわけで。
こんな感じで query.pyを作ります。app.pyをベースにしてます。
import os, gc, copy, torch
from datetime import datetime
from huggingface_hub import hf_hub_download
from pynvml import *
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)
ctx_limit = 2000
title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
model = RWKV(model=model_path, strategy='cuda fp16')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
def generate_prompt(instruction, input=""):
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
if input:
return f"""Instruction: {instruction}
Input: {input}
Response:"""
else:
return f"""User: hi
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free
to ask any question and I will always answer it.
User: {instruction}
Assistant:"""
def evaluate(
ctx,
token_count=200,
temperature=1.0,
top_p=0.7,
presencePenalty = 0.1,
countPenalty = 0.1,
):
args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
alpha_frequency = countPenalty,
alpha_presence = presencePenalty,
token_ban = [], # ban the generation of some tokens
token_stop = [0]) # stop generation whenever you see any token here
ctx = ctx.strip()
all_tokens = []
out_last = 0
out_str = ''
occurrence = {}
state = None
for i in range(int(token_count)):
out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
for n in occurrence:
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
#
token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
if token in args.token_stop:
break
all_tokens += [token]
for xxx in occurrence:
occurrence[xxx] *= 0.996
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
#
tmp = pipeline.decode(all_tokens[out_last:])
if '\ufffd' not in tmp:
out_str += tmp
#yield out_str.strip()
out_last = i + 1
#
gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
del out
del state
gc.collect()
torch.cuda.empty_cache()
#yield out_str.strip()
return out_str
def query(user_query):
print(evaluate(generate_prompt(user_query)))
yield部分をコメントアウトせずに、
def query(user_query):
for out_str evaluate(generate_prompt(user_query)):
print(out_str)
とすると、FF2っぽくなります。
そんな話は横に置き、
>>> query("ドラえもんの登場人物をJSONで")
vram 17171480576 used 8015347712 free 9156132864
以下はドラえもんの登場人物をJSONで表示します。
```json
{
"Alice": "ドラえもんの主人公の妹である『マイマイ』、身長180cm、体重50kg、身体能力を得意としている魔法使い 。小さい頃から人間として育てられ、変身を使用する際はその魔法で自分を変身させることができる。"
}
```
このJSONデータには、ドラえもんの主人公である「Alice」の情報が含まれています。彼女はドラえもんの主人公であり 、身長180cm、体重50kg、身体能力を得意としている魔法使いであり、小さい頃から人間として育てられ、変身を使用す る際はその魔法で
>>>
はい、できました。
メモリの使用量など
メモリの使用量ですが、3.7GBほど。かるい!
