ABCIでMPT-7Bのファインチューニングを試す
前提知識
MPT-7Bは最近発表された商用利用可能な大規模言語モデルで、LLaMAに匹敵する性能を持っていると言われています。
ABCIは経産省が管轄する日本在住者なら誰でも安価に使えるスーパーコンピュータです。
(ただし登録がいろいろ大変なので法人が前提です/利用料は最低20万円から)
対象読者
行間が読める人。本文が間違っていても自分でソースコードに手を加えて修正できるスキルがある人。ABCIを使えるポジションの人。
僕も人間なのでミスはよくありますし、備忘録とこれからやろうとする人のために書いています。質問は受け付けません(自分でなんとかしてください)。
準備
思ったより大変だったのでメモ
まず、大前提として自宅のA6000x2のマシンでできるかと思ったら、ダメだった(12:57更新。ウソ:A6000x2でちゃんとできました)。
まず、MPTはTransformerなのでRWKVと違い、VRAMをめちゃくちゃ要求します。必要なVRAMの容量は、12*N(Nはパラメータ数)で概算できます。
たとえばGPT-13Bをやりたければ、12*13=156GBが必要ということになります。
12*7=84GBなので、もしかすると自宅のドスパラ製Memeplexサーバ
A6000x2(48GBx2=96GB)でもできるのかもしれない(12:57更新:できた)けど、とりあえず面倒だから確実に学習できるABCIで練習しました。
手順としてはまず、llm-foundryのリポジトリをgit cloneします。
$ git clone https://github.com/mosaicml/llm-foundry.git
$ cd llm-foundry
ABCIの場合は、ここですぐセットアップできません。
まず、moduleをロードします。
MPT-7Bが動作するモジュールの組み合わせは、python 3.10 / cuda 11.7.1 / cudnn 8.4.1です。はっきりいってこの情報だけでもメモっておきたいのでこのエントリを書いてます
$ module load python/3.10/3.10.10
$ module load cuda/11.7/11.7.1
$ module load cudnn/8.4/8.4.1
ここで注意しなければならないのは、venvを使う場合は、venvを設定した後でモジュールを読み込む必要があることです。間違うとパスの順番の関係でPythonが動かなくなります。
ここでようやくpip install を走らせることができます。
$ pip install -e ".[gpu]"
ここまでインタラクティブノード上でできるはずですが、エラーが出たら自分でなんとかしてください。
ここを突破できないスキルの人はこの先はもっと難しいと思います。
データセットの変換
まず最初は大人しくサンプルのページにある通りに動くかやってみましょう。ここで正常に動かなかったらセットアップのやり直しです。
僕はファインチューニングが試したかったので最初から用意されているサンプルではダメでした。
ファインチューニングをするには、まずhttps://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/mpt/finetune/7b_dolly_sft.yamlにある設定ファイルを使います。
そのままでも動くかもしれないけど、これだけでは芸がないので、dolly_15kを日本語化したhttps://huggingface.co/datasets/kunishou/databricks-dolly-15k-jaに変更してみます。
# Dataloaders
train_loader:
name: finetuning
dataset:
hf_name: kunishou/databricks-dolly-15k-ja
また、modelも、エラーが出る場合は、損失関数をtorch_crossentropyにするととりあえず動く
# Model
model:
name: mpt_causal_lm
init_device: meta
d_model: 4096
n_heads: 32
n_layers: 32
expansion_ratio: 4
max_seq_len: ${max_seq_len}
vocab_size: 50368
attn_config:
attn_impl: triton
loss_fn: torch_crossentropy
あと、最後の行、なぜかネットからデータ読み込むことになってる(雑だ)ので、コメントアウトしておく。完全なyamlは巻末に置いておきます。
#load_path: oci://my-bucket/my-folder/mpt-7b/checkpoints/some_checkpoint.pt
このまま実行するとエラーが起きるので、
vi /home/自分のABCIユーザー名/.local/lib/python3.10/site-packages/llmfoundry/data/finetuning/tasks.py
で以下のブロックを追加
@dataset_constructor.register('kunishou/databricks-dolly-15k-ja')
def dolly_preprocessing_function(inp: Dict):
"""Format the text string."""
PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
try:
if inp['input'] != '':
instruction = inp['instruction'] + '\n' + inp['input']
else:
instruction = inp['instruction']
prompt = PROMPT_FORMAT.format(instruction=instruction)
response = inp['output']
except Exception as e:
raise ValueError(
f'Unable to extract prompt/response from {inp=}') from e
return {'prompt': prompt, 'response': response}
インストールしたパッケージを弄るのはちょっと乱暴すぎるのでもう少しマシな方法が公式に説明されている。
が、面倒が増えそうなので無視した。ローカルのデータセットを学習するときにはこの方法を使う方が効率が良さそうだがとりあえず動くところを目指す。
$ cd scripts
$ composer train/train.py train/yamls/mpt/finetune/7b_dolly_sft.yaml
さあこれでうまくいけば学習が開始される。
ちなみにデフォルトのyamlだと1ep(1epoch)しか学習しない設定なので適宜ほしいエポック数に変えること。
また、yamlの末尾にセーブするディレクトリなどを設定する。
自分のユーザーディレクトリに保存すると、すぐにクォータがいっぱいになって死ぬので、保存先は必ずスクラッチパッドを使うこと
[epoch=1][batch=213/235]:
Train time/batch: 212
Train time/sample: 13457
Train time/batch_in_epoch: 212
Train time/sample_in_epoch: 13457
Train time/token: 27559936
Train time/token_in_epoch: 27559936
Train memory/allocated_mem: 17.6580
Train memory/active_mem: 17.6580
Train memory/inactive_mem: 2.2777
Train memory/reserved_mem: 36.8890
Train memory/alloc_retries: 2
Train trainer/device_train_microbatch_size: 8
Train loss/train/total: 4.3715
Train metrics/train/LanguageCrossEntropy: 4.3677
Train metrics/train/LanguagePerplexity: 78.8618
Train throughput/batches_per_sec: 0.2320
Train throughput/samples_per_sec: 14.7099
Train throughput/device/batches_per_sec: 0.0290
Train throughput/device/samples_per_sec: 1.8387
Train throughput/flops_per_sec: 1300668282604287.2500
Train throughput/device/flops_per_sec: 162583535325535.9062
Train throughput/device/mfu: 0.5211
Train time/train: 0.2841
Train time/val: 0.0000
Train time/total: 0.2841
Train lr-DecoupledAdamW/group0: 0.0000
Train time/remaining_estimate: 30.9840
さて、学習した。
とりあえず1epoch。こんなので何も変わらないと思うが学習できたということが大事だ。
推論
学習したら推論である。
しかし、MPT-7Bは、そう簡単に推論できない。
まず、MPT-7Bで推論できるようにするために、学習したチェックポイントを変換する必要がある。
$ python3 inference/convert_composer_to_hf.py --composer_path dolly15j/checkpoints/ep0-ba200-rank0.pt --hf_output_path /scratch/自分のABCIID/out --output_precision bf16
これで/scratch/ABCIユーザーID/outに変換された。
いよいよ推論だ。
$ python3 inference/hf_generate.py --name_or_path /scratch/aca10054zv/out --prompts "光の三原色"
Loading HF Config...
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Loading HF model to device=cuda:0 and dtype=torch.bfloat16...
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.
n_params=6658859008
Loading HF tokenizer...
Generate kwargs:
{'max_new_tokens': 100, 'temperature': 1.0, 'top_p': 1.0, 'top_k': 50, 'use_cache': True, 'do_sample': True, 'eos_token_id': 0, 'pad_token_id': 0}
Tokenizing prompts...
NOT using autocast...
Warming up...
Generating responses...
####################################################################################################
光の三原色、キ��はジ、フラシニアンレイロ
バリ�、デリ
バミ・ママ・アッドレはマカ(マは、デ、ラのルリ()年
ビ年
オント、ロメラーラーのカリ:パレホである(セフ�とベスメン」はあ月月
####################################################################################################
bs=1, input_tokens=array([6]), output_tokens=array([89])
total_input_tokens=6, total_output_tokens=89
encode_latency=73.43ms, gen_latency=1981.96ms, decode_latency=41.22ms, total_latency=2096.61ms
latency_per_output_token=23.56ms/tok
output_tok_per_sec=42.45tok/sec
やはり1エポックではなにもできないというか却って悪くなってる気さえする。
まあ本当はresponse形式にしなきゃなんないのかもしれないけど。
とにかく学習のようなものが回せて、推論できた。
あとはデータをどうするかハイパラをどうするか考えるだけだ。
疲れた。朝5時から取り掛かっていろいろハマって昼過ぎになってしまった。
おまけ:完全な設定ファイル(yaml)
個人のディレクトリ的なところだけ隠した完全な設定ファイルを置いておきます。参考まで
max_seq_len: 2048
global_seed: 17
# Run Name
run_name: dolly15j
# If left blank, will be read from env var $C/baOMPOSER_RUN_NAME
# Model
model:
name: mpt_causal_lm
init_device: meta
d_model: 4096
n_heads: 32
n_layers: 32
expansion_ratio: 4
max_seq_len: ${max_seq_len}
vocab_size: 50368
attn_config:
attn_impl: triton
loss_fn: torch_crossentropy
# Tokenizer
tokenizer:
name: EleutherAI/gpt-neox-20b
kwargs:
model_max_length: ${max_seq_len}
# Dataloaders
train_loader:
name: finetuning
dataset:
hf_name: kunishou/databricks-dolly-15k-ja
split: train
max_seq_len: ${max_seq_len}
allow_pad_trimming: false
decoder_only_format: true
shuffle: true
# Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` to profile
# this run's optimal packing_ratio
# packing_ratio:
drop_last: true
num_workers: 8
pin_memory: false
prefetch_factor: 2
persistent_workers: true
timeout: 0
# There is no validation split so we skip eval_loader
# Optimization
scheduler:
name: linear_decay_with_warmup # linear no warmup is HF default which dolly used
t_warmup: 0ba
alpha_f: 0
optimizer:
# mimic HF defaults to replicate dolly
name: decoupled_adamw
lr: 1.0e-5
betas:
- 0.9
- 0.999
eps: 1.0e-8
weight_decay: 0
algorithms:
gradient_clipping:
clipping_type: norm
clipping_threshold: 1.0
max_duration: 100ep
eval_interval: 1 # this is the only allowed value for no eval
# eval_first: false
# eval_subset_num_batches: -1
global_train_batch_size: 64 # assuming 8 gpus
# System
seed: ${global_seed}
device_eval_batch_size: 8
device_train_microbatch_size: 8
# device_train_microbatch_size: auto
precision: amp_bf16
# FSDP
fsdp_config:
sharding_strategy: FULL_SHARD
mixed_precision: PURE
activation_checkpointing: true
activation_checkpointing_reentrant: false
activation_cpu_offload: false
limit_all_gathers: true
verbose: false
# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba
callbacks:
speed_monitor:
window_size: 10
lr_monitor: {}
memory_monitor: {}
runtime_estimator: {}
# loggers:
# wandb: {}
# Checkpoint to local filesystem or remote object store
save_interval: 200ba
# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
save_folder: /scratch/あなたのABCIユーザー名/{run_name}/checkpoints
# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints
# Load from remote object store
# REPLACE THE BELOW with you own checkpoint!
#load_path: oci://my-bucket/my-folder/mpt-7b/checkpoints/some_checkpoint.pt