japanese-stableLM-alphaのqloraにおけるtarget_modulesの設定
最近、Llama2やStablelmなど日本語の性能がある程度見込めそうなモデルが多く出てきており、qloraを用いたファインチューニングを行う機会が増えてきています。
今後、新しいモデルをqloraファインチューニングする際に、自身で適切な設定ができるようになりたく、StableLMが公開されたことを機に調べてみました。
1. qloraについて
qloraは、
・軽量なファインチューニング手法であるlora
・量子化
を合わせることでより少ない計算リソースでLLMをチューニングできる様に提案された手法。
HuggingFace関連のライブラリや、qloraの論文で使用されたスクリプトが用意されているのでこれらを使って実施することができます。
2. 結論
上記、npaka先生の記事を見ていただければわかる通り、qloraにおいて以下のようなLoraConfigの設定が必要となります。
結論として、target_modulesは以下の通りだと思われます。
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=8,
bias="none",
task_type="CAUSAL_LM",
target_modules=[
"query_key_value",
"dense",
"packed_input_proj",
"out_proj",
]
)
3. 調べ方①
qloraの論文を読むと以下の記載がある通り、qloraにおけるloraターゲットモジュールとしては全てのlinear block layerを対象にする必要がある様です。
下記でモデルを読み込んだのち、
import torch
from transformers import AutoModelForCausalLM
model_name = "stabilityai/japanese-stablelm-base-alpha-7b"
model = AutoModelForCausalLM.from_pretrained(
model_name,
load_in_4bit=True,
torch_dtype=torch.float16,
trust_remote_code=True
)
以下でモデルのアーキテクチャを調べることができます。
model
結果は以下の通りです。Transformer内のLinear層はquery_key_value、dense、packed_input_proj、out_projであり、これらがtarget_modulesとなります。
JapaneseStableLMAlphaForCausalLM(
(transformer): JapaneseStableLMAlphaModel(
(embed_in): Embedding(65536, 4096)
(layers): ModuleList(
(0-31): 32 x DecoderLayer(
(input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=False)
(post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(attention): Attention(
(rotary_emb): RotaryEmbedding()
(query_key_value): Linear4bit(in_features=4096, out_features=12288, bias=False)
(dense): Linear4bit(in_features=4096, out_features=4096, bias=False)
)
(mlp): MLP(
(packed_input_proj): Linear4bit(in_features=4096, out_features=22016, bias=False)
(out_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
(act): SiLU()
)
)
)
(final_layer_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
)
(embed_out): Linear(in_features=4096, out_features=65536, bias=False)
)
ちなみにloraを適用した後のアーキテクチャは以下の様になります。
PeftModelForCausalLM(
(base_model): LoraModel(
(model): JapaneseStableLMAlphaForCausalLM(
(transformer): JapaneseStableLMAlphaModel(
(embed_in): Embedding(65536, 4096)
(layers): ModuleList(
(0-31): 32 x DecoderLayer(
(input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=False)
(post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(attention): Attention(
(rotary_emb): RotaryEmbedding()
(query_key_value): Linear4bit(
in_features=4096, out_features=12288, bias=False
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=4096, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=12288, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
)
(dense): Linear4bit(
in_features=4096, out_features=4096, bias=False
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=4096, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=4096, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
)
)
(mlp): MLP(
(packed_input_proj): Linear4bit(
in_features=4096, out_features=22016, bias=False
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=4096, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=22016, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
)
(out_proj): Linear4bit(
in_features=11008, out_features=4096, bias=False
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=11008, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=4096, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
)
(act): SiLU()
)
)
)
(final_layer_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
)
(embed_out): Linear(in_features=4096, out_features=65536, bias=False)
)
)
)
4. 調べ方②
これまでの記事をゆっくり書いていたら、npaka先生が上記を調べるためのコードを記事に書いてくださっていました。
発想は同じでLinear層をアーキテクチャから検索して表示しているようです。
import bitsandbytes as bnb
def find_all_linear_names(model):
cls = bnb.nn.Linear4bit # (default:torch.nn.Linear,4bit:bnb.nn.Linear4bit,8bit:bnb.nn.Linear8bitLt)
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
names = name.split('.')
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if 'lm_head' in lora_module_names: # needed for 16-bit
lora_module_names.remove('lm_head')
return list(lora_module_names)
find_all_linear_names(model)
これでstablelmを調べると確かにquery_key_value、dense、packed_input_proj、out_projとなりました。
['out_proj', 'dense', 'packed_input_proj', 'query_key_value']
目で調べるよりもこちらを使った方が確実で良いですね。
この記事が気に入ったらサポートをしてみませんか?