WSL2でRWKV-infctx-trainer-LoRAを試してみる...
「Infctx-LoRAに4ビット量子化モードを追加し、単一 24GB GPUで14Bモデルのトレーニングを可能にした」らしいRWKV-infctx-trainer-LoRAを試してみます。
24GBという文字列を見ると反応していますね。はい。
使用するPCはドスパラさんの「GALLERIA UL9C-R49」。スペックは
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU(16GB)・GPU: NVIDIA® GeForce RTX™ 4090 (24GB)
・OS: Ubuntu22.04 on WSL2(Windows 11)
です。
1. 準備
環境構築
python3 -m venv rwkv-infctx
cd $_
source bin/activate
リポジトリをクローンします。
git clone https://github.com/OpenMOSE/RWKV-infctx-trainer-LoRA
cd RWKV-infctx-trainer-LoRA
パッケージのインストール。
pip install -r requirements.txt
pip install bitsandbytes
モデルのダウンロード
RWKV 6の14Bモデルで試します。14Bのモデルは、temp-latest-training-modelsにアップロードされており、約2週間の頻度で差し替わっているようです。
今回は、本日(2024/5/30)時点で最新のモデルであるrwkv-x060-14b-world-v2.1-81%trained-20240527-ctx4k.pth を使用します。試すタイミングによってはファイル名(日付の箇所)が変更となっているかと思いますので、適宜読み替えてください。
4bit量子化向けのディレクトリに移動し、huggingface-cliコマンドを用いて、ダウンロードします。
# 移動
cd RWKV-v6-4bit
# RWKVv6 14Bのダウンロード
huggingface-cli download --repo-type model --local-dir ./model --local-dir-use-symlinks True BlinkDL/temp-latest-training-models rwkv-x060-14b-world-v2.1-81%trained-20240527-ctx4k.pth
2. 試してみる - enwiki_10k
続いて、学習です。本家 RWKV/RWKV-infctx-trainer の READMEも参考にしつつ、データセットの準備、設定ファイルの修正を行います。
設定ファイルの修正
本家で紹介されているデータセットである teven/enwiki_10k を使用して、まずは試します。
設定ファイル sample_config.yaml の修正差分は、以下です。
--- a/RWKV-v6-4bit/sample_config.yaml
+++ b/RWKV-v6-4bit/sample_config.yaml
@@ -84,11 +84,11 @@ trainer:
reload_dataloaders_every_n_epochs: 0
default_root_dir: null
model:
- load_model: rwkv-x060-14b-world-v2.1-26%trained-20240501-ctx4k.pth
+ load_model: model/rwkv-x060-14b-world-v2.1-81%trained-20240527-ctx4k.pth
n_embd: -1
n_layer: -1
vocab_size: -1
- ctx_len: 16384
+ ctx_len: 4096
ctx_len_cutoffs: []
ctx_len_warmup_steps: []
lora_r: 16.0 # LoRA Rank
@@ -121,8 +121,8 @@ model:
torch_set_float32_matmul_precision: high
data:
data_path: dataset/
- source: json
- source_data_dir: dataset_src/
+ source: teven/enwiki_10k
+ source_data_dir: data/
source_dataset_params: null
test_split: 0.01
test_split_shuffle: true
学習してみる
では、この設定ファイルを用いて、学習に向けた事前準備、学習データと検証用データセットの作成を行います。
python preload_datapath.py sample_config.yaml
実行後、datasetディレクトリに学習データと検証用データセットが作成されます。
$ find dataset -type f
dataset/train/data-00000-of-00001.arrow
dataset/train/dataset_info.json
dataset/train/state.json
dataset/dataset_dict.json
dataset/test/data-00000-of-00001.arrow
dataset/test/dataset_info.json
dataset/test/state.json
$
続いて、学習です。
CUDA_VISIBLE_DEVICES=0 python lightning_trainer.py fit -c sample_config.yaml
2時間20分ほど経過して、epoch 1が終わりました。設定ファイル上、epoch 100まで回すようになっているので、lossが下がった(下がらなかった)タイミングでctrl+cして、止めます。
なお、チェックポイントは checkpoint ディレクトリに epoch 単位に出力されています。
$ ls -l checkpoint/
total 4595508
-rw-r--r-- 1 user user 1568593860 May 30 20:49 epoch_1_weights.pth
-rw-r--r-- 1 user user 1568593860 May 30 23:10 epoch_2_weights.pth
-rw-r--r-- 1 user user 1568593860 May 31 01:31 epoch_3_weights.pth
$
では、この epoch 3の重みをベースモデルにマージし、推論できるようにしましょう。
CUDA_VISIBLE_DEVICES=1 python merge_lora.py --use-gpu 16 ./model/rwkv-x060-14b-world-v2.1-81%trained-20240527-ctx4k.pth checkpoint/epoch_3_weights.pth enwki_10k_epoch_3.pth
引数は順番に、
(1) --use-gpu: GPUを使う
(2) lora_alphaの値。ここでは16
(3) ベースモデルへのパス
(4) マージするチェックポイントへのパス
(5) マージ後の出力ファイル名
です。
学習中のvram使用量のMAXは 21.8GB、学習に要した時間は、1 epoch たり2時間20分ほどでした。
試してみる
以下の内容を、query.py として保存し、
import os
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
import sys
import argparse
import torch
from huggingface_hub import hf_hub_download
from typing import List, Dict
import time
# argv
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str)
parser.add_argument("--model-file", type=str)
parser.add_argument("--no-chat", action='store_true')
parser.add_argument("--no-use-system-prompt", action='store_true')
parser.add_argument("--max-tokens", type=int, default=256)
args = parser.parse_args(sys.argv[1:])
model_id = args.model_path
#if model_id == None:
# exit
model_file = args.model_file
if model_file == None:
exit
is_chat = not args.no_chat
use_system_prompt = not args.no_use_system_prompt
max_new_tokens = min(3500, args.max_tokens)
## Download the rwkv model
if model_id:
model_path = hf_hub_download(repo_id=model_id, filename=model_file)
else:
model_path = f"{model_file}"
## Instantiate model from downloaded file
model = RWKV(model=model_path, strategy='cuda fp16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
DEFAULT_SYSTEM_PROMPT = "わたしは誠実で優秀な日本人のアシスタントです。"
# generations params
pipeline_args = PIPELINE_ARGS(
temperature=1.0,
top_p=0.85,
top_k=0,
alpha_frequency=0.2,
alpha_presence=0.2,
alpha_decay=0.996,
token_ban=[],
token_stop=[],
chunk_len=256
)
#
def generate_chat_prompt(
conversation: List[Dict[str, str]],
add_generation_prompt=True,
) -> str:
prompt = ""
for message in conversation:
role = message["role"]
content = message["content"].strip().replace('\r\n','\n').replace('\n\n','\n')
if message["role"] == "system":
prompt += f"User: こんにちは\n\nAssistant: {content}\n\n"
else:
prompt += f"{role}: {content}\n\n"
if add_generation_prompt:
prompt += "Assistant:"
return prompt
#
def generate_prompt(
user_query: str,
instruction: str=None,
add_generation_prompt=True,
) -> str:
prompt = ""
prompt += f"Instruction: {instruction}\n\n"
prompt += f"Input: {user_query}\n\n"
if add_generation_prompt:
prompt += f"Output:"
return prompt
# callback function
def print_nolf(outstr):
print(outstr, end="")
def q(
user_query: str,
history: List[Dict[str, str]]=None,
instruction: str=None
) -> List[Dict[str, str]]:
start = time.process_time()
# messages
messages = ""
if is_chat:
messages = []
if use_system_prompt:
messages = [
{"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
]
user_messages = [
{"role": "User", "content": user_query}
]
else:
user_messages = user_query
if history:
user_messages = history + user_messages
messages += user_messages
# generation prompts
if is_chat:
prompt = generate_chat_prompt(
conversation=messages,
add_generation_prompt=True,
)
else:
prompt = generate_prompt(
user_query=messages,
instruction=instruction,
add_generation_prompt=True,
)
print("--- prompt")
print(prompt)
print("--- output")
# 推論
output = pipeline.generate(
ctx=prompt,
token_count=max_new_tokens,
args=pipeline_args,
callback=print_nolf
)
if is_chat:
user_messages.append(
{"role": "Assistant", "content": output}
)
else:
user_messages += output
end = time.process_time()
##
input_ids = pipeline.encode(prompt)
input_tokens = len(input_ids)
output_ids = pipeline.encode(output)
output_tokens = len(output_ids)
total_time = end - start
tps = output_tokens / total_time
print("\n---")
print(f"prompt tokens = {input_tokens:.7g}")
print(f"output tokens = {output_tokens:.7g} ({tps:f} [tps])")
print(f" total time = {total_time:f} [s]")
return user_messages
以下の通り、コマンドラインで実行します。
CUDA_VISIBLE_DEVICES=0 python -i ./query.py --model-file ./enwiki_10_epoch_3.pth
モデルがロードされpythonのプロンプトが表示されたら、いつものように聞いてみましょう。
>>> history = q("ドラえもんとはなにか")
--- prompt
User: こんにちは
Assistant: わたしは誠実で優秀な日本人のアシスタントです。
User: ドラえもんとはなにか
Assistant:
推論の内容が壊れてないので、マージはきちんとできていそう。
推論の性能は、こちら。マージしたモデルのファイルサイズが28GBほどなので、VRAM(24GB)から溢れていました。このため、推論がとても遅くなっています(秒間0.49トークン)。これは仕方がないですね。。
---
prompt tokens = 49
output tokens = 231 (0.491268 [tps])
total time = 470.211705 [s]
3. 試してみる - りんえもん
とりあえず手順がわかったので、今度は日本語のデータセットである bbz662bbz/databricks-dolly-15k-ja-gozarinnemon で学習してみます。
設定ファイルの修正
設定ファイル sample_config.yaml の修正差分は、以下です。
--- a/RWKV-v6-4bit/sample_config.yaml
+++ b/RWKV-v6-4bit/sample_config.yaml
@@ -84,11 +84,11 @@ trainer:
reload_dataloaders_every_n_epochs: 0
default_root_dir: null
model:
- load_model: rwkv-x060-14b-world-v2.1-26%trained-20240501-ctx4k.pth
+ load_model: model/rwkv-x060-14b-world-v2.1-81%trained-20240527-ctx4k.pth
n_embd: -1
n_layer: -1
vocab_size: -1
- ctx_len: 16384
+ ctx_len: 4096
ctx_len_cutoffs: []
ctx_len_warmup_steps: []
lora_r: 16.0 # LoRA Rank
@@ -121,8 +121,8 @@ model:
torch_set_float32_matmul_precision: high
data:
data_path: dataset/
- source: json
- source_data_dir: dataset_src/
+ source: bbz662bbz/databricks-dolly-15k-ja-gozarinnemon
+ source_data_dir:
source_dataset_params: null
test_split: 0.01
test_split_shuffle: true
学習してみる
enwiki_10kと同じように実行します。
# 学習データと検証用データの作成
python preload_datapath.py sample_config.yaml
# 学習
CUDA_VISIBLE_DEVICES=0 python lightning_trainer.py fit -c sample_config.yaml
epoch 7のチェックポイントファイルとベースモデルをマージします。
# チェックポイントとベースモデルのマージ
CUDA_VISIBLE_DEVICES=1 python merge_lora.py --use-gpu 16 ./model/rwkv-x060-14b-world-v2.1-81%trained-20240527-ctx4k.pth checkpoint/epoch_7_weights.pth databricks-dolly-15k-ja-gozarinnemon_epoch_7.pth
出力するファイル名は databricks-dolly-15k-ja-gozarinnemon_epoch_7.pth です。できあがったファイルは、こちらです。
-rw-r--r-- 1 user user 28138782218 Jun 1 00:10 databricks-dolly-15k-ja-gozarinnemon_epoch_7.pth
以下は、epoch 7までのwandb.aiによるtrain/lossのグラフです。右肩下がりになっていますね。
試してみる
enwiki_10kのセクションで作成(保存)した query.py を用いて、このマージしたモデル databricks-dolly-15k-ja-gozarinnemon_epoch_7.pth を試してみましょう。
CUDA_VISIBLE_DEVICES=0 python -i ./query.py --model-file ./databricks-dolly-15k-ja-gozarinnemon_epoch_7.pth --no-chat
VRAM 24GBは当然のことながら溢れています。
(CPUオフロードするのは仕方ないので)モデルがロードされpythonのプロンプトが表示されたら、いつものように聞いてみましょう。
>>> history = q("ドラえもんとはなにか")
--- prompt
Instruction: None
Input: ドラえもんとはなにか
Output:
「我、りんえもんは思う。」で推論が始まらない…。むぅ。
data.pyを見る限り、設定ファイル内の multi_column_keysの値がnullだと初期値が設定されるようになっていて、その初期設定値はりんえもんのデータセットの書式と合致しているのですが…。
4. 試してみる - ja_conv_wikipedia_orion14B
shi3zさん作成のデータセット shi3z/ja_conv_wikipedia_orion14B_10K を用いて試してみます。
設定ファイルの修正
yamlファイルの設定をデータセットのフォーマットに合わせて変更する必要があります。ポイントは conversation_formatとして 'sender' を利用すること、です。
conversation_format: 'sender'
conversation_key: 'conversations'
conversation_sender_key: 'from'
conversation_sender_value_map: {'human': 'User', 'gpt': 'Assistant', 'system': 'System'}
conversation_input_key_map: {'value': "\n\n{sender}: "}
conversation_sender_suffix: {'human': "", 'gpt': "", 'system': ""}
conversation_sender_mask: {'human': false, 'gpt': true, 'system': false}
conversation_end_of_conversation: "\n\nUser:"
src/data.pyを読みつつ、念のためprintデバックして設定したので、おそらく大丈夫かと。
差分だけですとコピペしにくいので、yamlの全量はこちら。これを sample_config-ja_conv_orion14b_10k.yaml として保存します。
# lightning.pytorch==2.0.5
seed_everything: 2144748874
trainer:
target_batch_size: 32
microbatch_size: 1
num_nodes: 1
devices: auto
accelerator: gpu
strategy: deepspeed_stage_2_offload
precision: bf16
callbacks:
- class_path: custom.SaveElementWeightsCallback # Replace with the actual module path
init_args:
save_dir: checkpoint # Use the YAML's checkpoint directory
elements_to_save:
- lora # Example element to save (replace with actual element names)
- emb # Example element to save (replace with actual element names)
- head # Example element to save (replace with actual element names)
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
name: null
save_dir: .
version: null
offline: false
dir: null
id: null
anonymous: null
project: RWKV6-infctx-LoRA
log_model: false
experiment: null
prefix: ''
checkpoint_name: null
job_type: null
config: null
entity: null
reinit: null
tags:
- RWKV
group: null
notes: null
magic: null
config_exclude_keys: null
config_include_keys: null
mode: null
allow_val_change: null
resume: null
force: null
tensorboard: null
sync_tensorboard: null
monitor_gym: null
save_code: null
settings: null
fast_dev_run: false
max_epochs: 100
min_epochs: null
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: null
limit_val_batches: null
limit_test_batches: null
limit_predict_batches: null
overfit_batches: 0.0
val_check_interval: null
check_val_every_n_epoch: 1
num_sanity_val_steps: 0
log_every_n_steps: null
enable_checkpointing: false
enable_progress_bar: null
enable_model_summary: null
accumulate_grad_batches: 1
gradient_clip_val: null
gradient_clip_algorithm: null
deterministic: null
benchmark: null
inference_mode: true
use_distributed_sampler: false
profiler: null
detect_anomaly: false
barebones: false
plugins: null
sync_batchnorm: false
reload_dataloaders_every_n_epochs: 0
default_root_dir: null
model:
load_model: model/rwkv-x060-14b-world-v2.1-81%trained-20240527-ctx4k.pth
n_embd: -1
n_layer: -1
vocab_size: -1
ctx_len: 4096
ctx_len_cutoffs: []
ctx_len_warmup_steps: []
lora_r: 16.0 # LoRA Rank
lora_alpha: 32.0 #LoRA Alpha it must be double LoRA Rank
lora_dropout: 0.01
lora_quant: true
lora_quant_type: nf4
head_learning: true #if true Enable Full-resolution Head Layer Learning
lr_init: 1.0e-05
lr_final: 1.0e-07
lr_period: -1
lr_period_type: epoch
dropout: 0.0
beta1: 0.9
beta2: 0.99
adam_eps: 1.0e-08
weight_decay: 0.01
warmup_steps: -1
position_loss_bias: 1.0
position_loss_bias_in_validation: false
grad_cp: true
bptt_learning: true
bptt_learning_range: -1
bptt_truncated_learning: false
layerwise_lr: true
dim_att: null
dim_ffn: null
substep_cuda_cache_clear: false
substep_logging: false
torch_set_float32_matmul_precision: high
data:
data_path: dataset.ja_conv_wikipedia_orion14B_10K/
source: shi3z/ja_conv_wikipedia_orion14B_10K
source_data_dir:
source_dataset_params: null
test_split: 0.01
test_split_shuffle: true
text_rechunk_size: 2048
text_rechunk_auto: true
text_rechunk_force: false
tokenizer: world
autoTokenizer: null
world_add_endoftext_token: true
min_token_size: 1
max_token_size: -1
sort_by_length: false
sort_asc: false
training_dataloader_shuffle_auto: true
dataset_offset: -1.0
dataset_length: -1.0
custom_text_key: null
multi_column_keys: null
multi_column_prefix: null
multi_column_suffix: null
multi_column_train_mask: null
multi_column_separator: null
conversation_format: 'sender'
conversation_key: 'conversations'
conversation_input_key_prefix_map: null
conversation_input_key_mask: null
conversation_sender_key: 'from'
conversation_sender_value_map: {'human': 'User', 'gpt': 'Assistant', 'system': 'System'}
conversation_input_key_map: {'value': "\n\n{sender}: "}
conversation_sender_suffix: {'human': "", 'gpt': "", 'system': ""}
conversation_sender_mask: {'human': false, 'gpt': true, 'system': false}
conversation_end_of_conversation: "\n\nUser:"
disable_prompt_completion_mask: false
packing_enable: true
packing_batchsize: 20160
packing_chunksize: 4096
packing_min_ctx_len: -1
packing_in_sequence: false
processing_max_batch_size: 100000
skip_datapath_setup: false
ckpt_path: null
学習してみる
学習用のデータセットがきちんと作成できるかを確認します。
# 学習データと検証用データの作成
python preload_datapath.py sample_config-ja_conv_orion14b_10k.yaml
実行時のログがこちら。なんとなく変換できていそうです。
Downloading readme: 100%|█████████████████████████████████████████████████████████████████| 565/565 [00:00<00:00, 1.75MB/s]
Downloading data: 100%|███████████████████████████████████████████████████████████████| 15.7M/15.7M [00:03<00:00, 4.67MB/s]
Downloading data files: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00, 3.37s/it]
Extracting data files: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1596.61it/s]
Setting num_proc from 24 back to 1 for the train split to disable multiprocessing as it only contains one shard.
Generating train split: 12964 examples [00:00, 297009.19 examples/s]
Map (num_proc=24): 100%|███████████████████████████████████████████████████| 12964/12964 [00:00<00:00, 37899.62 examples/s]
Filter (num_proc=24): 100%|████████████████████████████████████████████████| 12964/12964 [00:00<00:00, 41338.87 examples/s]
Map (num_proc=24): 100%|███████████████████████████████████████████████████| 12834/12834 [00:00<00:00, 55711.41 examples/s]
Map (num_proc=24): 100%|███████████████████████████████████████████████████| 12834/12834 [00:00<00:00, 35880.63 examples/s]
Saving the dataset (1/1 shards): 100%|█████████████████████████████████████████| 987/987 [00:00<00:00, 52917.36 examples/s]
Saving the dataset (1/1 shards): 100%|█████████████████████████████████████████| 130/130 [00:00<00:00, 39283.83 examples/s]
元のデータセットをトークン化した学習/検証用のファイルがこちら。
$ find dataset.ja_conv_wikipedia_orion14B_10K/ -type f -ls
22440256 23164 -rw-r--r-- 1 user user 23717472 Jun 1 13:13 dataset.ja_conv_wikipedia_orion14B_10K/train/data-00000-of-00001.arrow
22440258 4 -rw-r--r-- 1 user user 632 Jun 1 13:13 dataset.ja_conv_wikipedia_orion14B_10K/train/dataset_info.json
22440257 4 -rw-r--r-- 1 user user 247 Jun 1 13:13 dataset.ja_conv_wikipedia_orion14B_10K/train/state.json
22440255 4 -rw-r--r-- 1 user user 29 Jun 1 13:13 dataset.ja_conv_wikipedia_orion14B_10K/dataset_dict.json
22440259 240 -rw-r--r-- 1 user user 242944 Jun 1 13:13 dataset.ja_conv_wikipedia_orion14B_10K/test/data-00000-of-00001.arrow
22440261 4 -rw-r--r-- 1 user user 497 Jun 1 13:13 dataset.ja_conv_wikipedia_orion14B_10K/test/dataset_info.json
22440260 4 -rw-r--r-- 1 user user 247 Jun 1 13:13 dataset.ja_conv_wikipedia_orion14B_10K/test/state.json
$
つづいて、学習して、マージです。
# 学習
CUDA_VISIBLE_DEVICES=0 python lightning_trainer.py fit -c sample_config-ja_conv_orion14b_10k.yaml
# チェックポイントとベースモデルのマージ
CUDA_VISIBLE_DEVICES=0 python merge_lora.py --use-gpu 16 ./model/rwkv-x060-14b-world-v2.1-81%trained-20240527-ctx4k.pth checkpoint/epoch_1_weights.pth ja_conv_wikipedia_orion14B_10K_epoch_1.pth
試してみる
pythonコマンドを起動して、
CUDA_VISIBLE_DEVICES=0 python -i ./query.py --model-file ./ja_conv_wikipedia_orion14B_10K_epoch_1.pth
いくつか(100個ほど)聞いてみたのですが、
>>> history = q("「恋するユニバース」はいつリリースされましたか?")
違うんですよね…。うーん。
5. まとめ
VRAM 24GBあれば確かに学習プロセスは完了しました。ただ、推論結果を見ると、学習できているのかがなんとも言えず悶々しています…。
ただ、マージしたモデルの(ファイル)サイズが28GBほどあるので、24GBのVRAMだと推論時に溢れます。
この記事が気に入ったらサポートをしてみませんか?