
re: WSL2でMedusaを試してみる
こちらの記事は以下の記事の続きで、
Xに書いた次の件(stage2だと速い!)をまとめた記事です。
Elyza-japanese-Llama-2-7b-instruct、データセットは @shi3z さん作成の一番小さい https://t.co/mlf5aqnTZI を使用。 stage1,2で生成されたMedusaの頭を使用し、RTX 4090(24GB)で推論したtokens/sはこちら。
— NOGUCHI, Shoji (@noguchis) January 30, 2024
stage1: 33, 39, 36, 37, 37
stage2: 45, 52, 52, 51, 53
stage2速い。noteにまとめねば
(注)まとめに際して、使用するデータセットを shi3z/Japanese_Wikipedia_Conversationから、shi3z/ja_conv_wikipedia_orion14B_100Kに変更しています。
1. 学習の前に
使用するモデル
ベースとするモデルは、elyza/ELYZA-japanese-Llama-2-7b-instruct です。ありがとうございます!
使用するデータセット
MedusaのREADMEにあったShareGPT_Vicuna_unfilteredをデータセットとして学習すると、日本語能力が奪われてしまいました。システムプロンプトに何を与えてもすべて英語で回答してしまいます。そりゃそうか…。
「これはとても困った」ので、shi3z さんが公開されている日本語マルチターンデータセット(10万会話)を入力にして試します。貴重なデータセットの公開、ありがとうございます!
ダウンロード
データセットのダウンロードです。
git clone https://huggingface.co/datasets/shi3z/ja_conv_wikipedia_orion14B_10K
設定ファイル
stage1はこちら。datasets.pathを適切に修正します。num_epochsは1にしています。
base_model: elyza/ELYZA-japanese-Llama-2-7b-instruct
base_model_config: elyza/ELYZA-japanese-Llama-2-7b-instruct
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: ja_conv_wikipedia_orion14B_100K/ja_conv_orion14b_100K.jsonl
type: sharegpt
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
- lm_head
lora_target_linear:
lora_fan_in_fan_out:
lora_modules_to_save:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0005
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 40
eval_steps: 40
save_steps:
save_total_limit: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
medusa_num_heads: 5
medusa_num_layers: 1
medusa_heads_coefficient: 0.2
medusa_decay_coefficient: 0.8
medusa_logging: true
medusa_scheduler: constant
medusa_lr_multiplier: 4.0
medusa_only_heads: true
ddp_find_unused_parameters: true
# Stage 1: only train the medusa heads
# Stage 2: train the whole model
stage2は、こちら。
base_model: elyza/ELYZA-japanese-Llama-2-7b-instruct
base_model_config: elyza/ELYZA-japanese-Llama-2-7b-instruct
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: ja_conv_wikipedia_orion14B_100K/ja_conv_orion14b_100K.jsonl
type: sharegpt
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage2
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
- lm_head
lora_target_linear:
lora_fan_in_fan_out:
lora_modules_to_save:
lora_model_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0005
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 40
eval_steps: 40
save_steps:
save_total_limit: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
medusa_num_heads: 5
medusa_num_layers: 1
medusa_heads_coefficient: 0.2
medusa_decay_coefficient: 0.8
medusa_logging: true
medusa_scheduler: constant
medusa_lr_multiplier: 4.0
# medusa_only_heads: true
# ddp_find_unused_parameters: true
# Stage 1: only train the medusa heads
# Stage 2: train the whole model
参考までに、stage1とstage2の差分はこちらです。
$ diff -u axolotl/examples/medusa/elyza_7b_qlora_stage[12]-01.yml
--- axolotl/examples/medusa/elyza_7b_qlora_stage1-01.yml 2024-01-30 11:20:20.591707705 +0900
+++ axolotl/examples/medusa/elyza_7b_qlora_stage2-01.yml 2024-01-30 11:20:38.939502607 +0900
@@ -13,7 +13,7 @@
type: sharegpt
dataset_prepared_path:
val_set_size: 0.01
-output_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1
+output_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage2
adapter: qlora
lora_model_dir:
@@ -33,6 +33,7 @@
lora_target_linear:
lora_fan_in_fan_out:
lora_modules_to_save:
+lora_model_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1
sequence_len: 4096
sample_packing: true
@@ -86,7 +87,7 @@
medusa_logging: true
medusa_scheduler: constant
medusa_lr_multiplier: 4.0
-medusa_only_heads: true
-ddp_find_unused_parameters: true
+# medusa_only_heads: true
+# ddp_find_unused_parameters: true
# Stage 1: only train the medusa heads
# Stage 2: train the whole model
$
2. 学習
stage2に進むためには、stage1の学習結果が必要となるようですので、順番に実行していきます。
途中でターミナルへの接続がタイムアウトしてプロセスがkillされたらとても悲しいので、nohupをかましてstdoutの内容はlogファイルに書き出すようにしています。
# stage1
CUDA_VISIBLE_DEVICES=0 nohup accelerate launch -m axolotl.cli.train \
./axolotl/examples/medusa/elyza_7b_qlora_stage1.yml \
> ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1.log &
# stage2
CUDA_VISIBLE_DEVICES=0 nohup accelerate launch -m axolotl.cli.train \
./axolotl/examples/medusa/elyza_7b_qlora_stage2.yml \
> ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage2.log &
そんなこんなでウン十時間経過。はい、できました。
Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1:
-rw-rw-r-- 1 user user 812 1月 30 14:17 adapter_config.json
-rw-rw-r-- 1 user user 2062971154 1月 30 14:17 adapter_model.bin
-rw-rw-r-- 1 user user 1132 1月 30 11:21 config.json
-rw-rw-r-- 1 user user 2545 1月 30 14:17 README.md
-rw-rw-r-- 1 user user 551 1月 30 11:21 special_tokens_map.json
-rw-rw-r-- 1 user user 1011 1月 30 11:21 tokenizer_config.json
-rw-rw-r-- 1 user user 499723 1月 30 11:21 tokenizer.model
Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage2:
-rw-rw-r-- 1 user user 812 1月 30 20:38 adapter_config.json
-rw-rw-r-- 1 user user 2062971154 1月 30 20:38 adapter_model.bin
-rw-rw-r-- 1 user user 1132 1月 30 14:26 config.json
-rw-rw-r-- 1 user user 2545 1月 30 20:38 README.md
-rw-rw-r-- 1 user user 551 1月 30 14:26 special_tokens_map.json
-rw-rw-r-- 1 user user 1011 1月 30 14:26 tokenizer_config.json
-rw-rw-r-- 1 user user 499723 1月 30 14:26 tokenizer.model
3. 試してみる
stage1を試す
まずは、stage1から。
CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli \
--model ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1 \
--conv-system-msg "あなたは誠実で優秀な日本人のアシスタントです。" \
--max-steps 256
5回聞きます。平均 秒あたり 39.0 トークンでした。
※「 !!reset 」は会話履歴をクリアする、「 !!exit 」は処理終了のそれぞれコマンドです。
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:57<00:00, 58.64s/it]
Some weights of MedusaModelLlama were not initialized from the model checkpoint at elyza/ELYZA-japanese-Llama-2-7b-instruct and are newly initialized: ['medusa_head.3.1.weight', 'medusa_head.2.1.weight', 'medusa_head.1.1.weight', 'medusa_head.2.0.linear.bias', 'medusa_head.4.0.linear.weight', 'medusa_head.2.0.linear.weight', 'medusa_head.3.0.linear.bias', 'medusa_head.1.0.linear.weight', 'medusa_head.1.0.linear.bias', 'medusa_head.0.0.linear.bias', 'medusa_head.0.0.linear.weight', 'medusa_head.3.0.linear.weight', 'medusa_head.0.1.weight', 'medusa_head.4.0.linear.bias', 'medusa_head.4.1.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[INST]: ドラえもんとはなにか
[/INST]: /mnt/data/shoji_noguchi/venv/medusa-llm/Medusa/medusa/model/medusa_model.py:232: UserWarning: Please specify medusa choice configuration!
warnings.warn('Please specify medusa choice configuration!')
承知しました。
ドラえもんは、藤子・F・不二雄によって生み出された漫画の主人公です。
ドラえもんは、小学生の男の子で、大のおっかけであるのが特徴です。彼は、「ひみつ道具」を持っていて、様々な場面で役に立つ道具を使います。彼の生活の中心は、のび太とのび太の家族である、のび太 の父、のび太の母、ジャイアン、スネ夫、しずか、どこでもドアの8人です。
ドラえもんは、1970年から始まったテレビアニメでも人気が高く、様々なメディアでの展開が行われています。れ
---
prompt tokens = 19
output tokens = 251 (35.315387 [tps])
total time = 7.107384 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: 承知しました。
ドラえもんは、藤子・F・不二雄によって生み出された漫画の主人公です。
ドラえもんは、小学生の男の子で、大のおっかけであるのが特徴です。彼は、「ひみつ道具」を持っていて、様々な場面で役に立つ道具を使います。彼の生活の中心は、のび太とのび太の家族である、のび太 の父、のび太の母、ジャイアン、スネ夫、しずか、どこでもドアの8人です。
ドラえもんは、1970年から始まったテレビアニメでも人気が高く、様々なメディアでの展開が行われています。れ
---
prompt tokens = 19
output tokens = 251 (39.240447 [tps])
total time = 6.396461 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: 承知しました。
ドラえもんは、藤子・F・不二雄によって生み出された漫画の主人公です。
ドラえもんは、小学生の男の子で、大のおっかけであるのが特徴です。彼は、「ひみつ道具」を持っていて、様々な場面で役に立つ道具を使います。彼の生活の中心は、のび太とのび太の家族である、のび太 の父、のび太の母、ジャイアン、スネ夫、しずか、どこでもドアの8人です。
ドラえもんは、1970年から始まったテレビアニメでも人気が高く、様々なメディアでの展開が行われています。れ
---
prompt tokens = 19
output tokens = 251 (40.192272 [tps])
total time = 6.244982 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: 承知しました。
ドラえもんは、藤子・F・不二雄によって生み出された漫画の主人公です。
ドラえもんは、小学生の男の子で、大のおっかけであるのが特徴です。彼は、「ひみつ道具」を持っていて、様々な場面で役に立つ道具を使います。彼の生活の中心は、のび太とのび太の家族である、のび太 の父、のび太の母、ジャイアン、スネ夫、しずか、どこでもドアの8人です。
ドラえもんは、1970年から始まったテレビアニメでも人気が高く、様々なメディアでの展開が行われています。れ
---
prompt tokens = 19
output tokens = 251 (40.316779 [tps])
total time = 6.225696 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: 承知しました。
ドラえもんは、藤子・F・不二雄によって生み出された漫画の主人公です。
ドラえもんは、小学生の男の子で、大のおっかけであるのが特徴です。彼は、「ひみつ道具」を持っていて、様々な場面で役に立つ道具を使います。彼の生活の中心は、のび太とのび太の家族である、のび太 の父、のび太の母、ジャイアン、スネ夫、しずか、どこでもドアの8人です。
ドラえもんは、1970年から始まったテレビアニメでも人気が高く、様々なメディアでの展開が行われています。れ
---
prompt tokens = 19
output tokens = 251 (40.712306 [tps])
total time = 6.165212 [s]
[INST]: !!exit
exit...
stage2を試す
続いて、stage2。こちらも5回聞きます。
CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli \
--model ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage2 \
--conv-system-msg "あなたは誠実で優秀な日本人のアシスタントです。" \
--max-steps 256
5回聞きます。平均 秒あたり 50.9 トークンでした。
ただ、推論結果がstage1よりも短くなる傾向にあるのですよね。
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:31<00:00, 15.73s/it]
Some weights of MedusaModelLlama were not initialized from the model checkpoint at elyza/ELYZA-japanese-Llama-2-7b-instruct and are newly initialized: ['medusa_head.1.0.linear.weight', 'medusa_head.0.1.weight', 'medusa_head.0.0.linear.weight', 'medusa_head.3.0.linear.bias', 'medusa_head.2.0.linear.bias', 'medusa_head.2.0.linear.weight', 'medusa_head.1.0.linear.bias', 'medusa_head.1.1.weight', 'medusa_head.4.1.weight', 'medusa_head.3.1.weight', 'medusa_head.4.0.linear.bias', 'medusa_head.3.0.linear.weight', 'medusa_head.2.1.weight', 'medusa_head.0.0.linear.bias', 'medusa_head.4.0.linear.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[INST]: ドラえもんとはなにか
[/INST]: /mnt/data/shoji_noguchi/venv/medusa-llm/Medusa/medusa/model/medusa_model.py:232: UserWarning: Please specify medusa choice configuration!
warnings.warn('Please specify medusa choice configuration!')
ドラえもんは、藤子・F・不二雄の漫画作品であり、日本の漫画家である藤子・F・不二雄の代表作です。
1970年から1977年まで週刊少年サンデーに連載され、1979年から1996年までテレビアニメ化されました。また、1990年代からは映画化も されています。9
---
prompt tokens = 19
output tokens = 142 (45.777373 [tps])
total time = 3.101969 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: ドラえもんは、藤子・F・不二雄の漫画作品であり、日本の漫画家である藤子・F・不二雄の代表作です。
1970年から1977年まで週刊少年サンデーに連載され、1979年から1980年まではテレビアニメ化されました。
---
prompt tokens = 19
output tokens = 117 (52.974490 [tps])
total time = 2.208610 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: ドラえもんは、藤子・F・不二雄の漫画作品であり、日本の漫画家である藤子・F・不二雄の代表作です。
1970年から1977年まで週刊少年サンデーに連載され、1979年から1980年まではテレビアニメ化されました。
---
prompt tokens = 19
output tokens = 117 (52.957893 [tps])
total time = 2.209302 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: ドラえもんは、藤子・F・不二雄の漫画作品であり、日本の漫画家である藤子・F・不二雄の代表作です。
1970年から1977年まで週刊少年サンデーに連載され、1979年から1980年まではテレビアニメ化されました。
---
prompt tokens = 19
output tokens = 117 (51.590991 [tps])
total time = 2.267838 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: ドラえもんは、藤子・F・不二雄の漫画作品であり、日本の漫画家である藤子・F・不二雄の代表作です。
1970年から1977年まで週刊少年サンデーに連載され、1979年から1980年まではテレビアニメ化されました。
---
prompt tokens = 19
output tokens = 117 (53.476681 [tps])
total time = 2.187869 [s]
[INST]: !!exit
exit...
4. まとめ
これまでの弊環境における 秒あたりトークン数をまとめると以下です。
# Medusa stage1 by Legacy: datasets = Aeala/ShareGPT_Vicuna_unfiltered
29.5
# Medusa stage1 by axolotl: datasets = Aeala/ShareGPT_Vicuna_unfiltered
23.4
# Medusa stage1 by axolotl: datasets = shi3z/ja_conv_wikipedia_orion14B_100K
39.0
# Medusa stage2 by axolotl: datasets = shi3z/ja_conv_wikipedia_orion14B_100K
50.9
# transfomers
11.9
# vLLM
54.1
データセットの違いで速度が変わるように見えます。頭が混乱しないからかしら。
vLLMより遅いように見えますが、普通に使用していると秒あたり 55とか59 トークンとvLLMを超えるときもあります。
[INST]: しずかちゃんについて詳しく教えてください
[/INST]: しずかちゃんは、ドラえもんの妹であり、藤子不二雄の漫画作品『ドラえもん』に登場する架空の人物です。彼女は小学生の女の子で、ドラえもんの妹であり、彼女の妹であるしずかちゃんも登場します。彼女は兄のドラえもんと同じく、冷凍庫に入っていることがあります。し
---
prompt tokens = 188
output tokens = 151 (55.873504 [tps])
total time = 2.702533 [s]
[INST]: もっと詳しく教えてください
[/INST]: もっと詳しく教えると、しずかちゃんはドラえもんの妹であり、彼女は小学生の女の子です。彼女はドラえもんと同じく、冷凍庫に入っていることがあります。彼女はドラえもんの妹であり、彼女は小学生の女の子です。
---
prompt tokens = 364
output tokens = 121 (59.690177 [tps])
total time = 2.027134 [s]
だだ、全体的に問いに対する回答が微妙なのですよね…。
追記 - 2024/2/2
nums_epoch:2 の結果
Medusaのstage1, 2の学習をepoch 2で回してlossを減らしたら、変わるかしら?と思い、試しました。
確かに、epoch 1と比較してlossは僅かですが小さくなりました。
・stage1: 1.9101 ⇒ 1.8838
・stage2: 2.3897 ⇒ 2.3564
ただ、推論結果の傾向は大きくは変わらずでした。
・stage1: ベースモデルに近しい内容
・stage2: 箇条書き、同一単語の利用も多くなり、応答が短くなる傾向あり
先読みが影響しているのですかね。
生成したMedusaの頭
stage1とstage2とも、Hugging Faceにアップロードしています。ぜひ、ご賞味ください。
・stage1
・stage2
P.S.
サムネイルはCopilotに「ElyzaとMedusaと聞いて想像するイメージ」として作成されたものです。