見出し画像

japanese-stableLM-alphaのqloraにおけるtarget_modulesの設定

最近、Llama2やStablelmなど日本語の性能がある程度見込めそうなモデルが多く出てきており、qloraを用いたファインチューニングを行う機会が増えてきています。
今後、新しいモデルをqloraファインチューニングする際に、自身で適切な設定ができるようになりたく、StableLMが公開されたことを機に調べてみました。


1. qloraについて

qloraは、
・軽量なファインチューニング手法であるlora
・量子化
を合わせることでより少ない計算リソースでLLMをチューニングできる様に提案された手法。
HuggingFace関連のライブラリや、qloraの論文で使用されたスクリプトが用意されているのでこれらを使って実施することができます。

2. 結論

上記、npaka先生の記事を見ていただければわかる通り、qloraにおいて以下のようなLoraConfigの設定が必要となります。
結論として、target_modulesは以下の通りだと思われます。

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "query_key_value",
        "dense",
        "packed_input_proj",
        "out_proj",
    ]
)

3. 調べ方①

qloraの論文を読むと以下の記載がある通り、qloraにおけるloraターゲットモジュールとしては全てのlinear block layerを対象にする必要がある様です。

we find that the most critical LoRA hyperparameter is how many LoRA adapters are used in total and that LoRA on all linear transformer block layers are required to match full finetuning performance.

https://arxiv.org/pdf/2305.14314.pdf

下記でモデルを読み込んだのち、

import torch
from transformers import AutoModelForCausalLM

model_name = "stabilityai/japanese-stablelm-base-alpha-7b"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    trust_remote_code=True
)

以下でモデルのアーキテクチャを調べることができます。

model

結果は以下の通りです。Transformer内のLinear層はquery_key_value、dense、packed_input_proj、out_projであり、これらがtarget_modulesとなります。

JapaneseStableLMAlphaForCausalLM(
  (transformer): JapaneseStableLMAlphaModel(
    (embed_in): Embedding(65536, 4096)
    (layers): ModuleList(
      (0-31): 32 x DecoderLayer(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=False)
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): Attention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear4bit(in_features=4096, out_features=12288, bias=False)
          (dense): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MLP(
          (packed_input_proj): Linear4bit(in_features=4096, out_features=22016, bias=False)
          (out_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act): SiLU()
        )
      )
    )
    (final_layer_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (embed_out): Linear(in_features=4096, out_features=65536, bias=False)
)

ちなみにloraを適用した後のアーキテクチャは以下の様になります。

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): JapaneseStableLMAlphaForCausalLM(
      (transformer): JapaneseStableLMAlphaModel(
        (embed_in): Embedding(65536, 4096)
        (layers): ModuleList(
          (0-31): 32 x DecoderLayer(
            (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=False)
            (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
            (attention): Attention(
              (rotary_emb): RotaryEmbedding()
              (query_key_value): Linear4bit(
                in_features=4096, out_features=12288, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=12288, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (dense): Linear4bit(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
            )
            (mlp): MLP(
              (packed_input_proj): Linear4bit(
                in_features=4096, out_features=22016, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=22016, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (out_proj): Linear4bit(
                in_features=11008, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=11008, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (act): SiLU()
            )
          )
        )
        (final_layer_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      )
      (embed_out): Linear(in_features=4096, out_features=65536, bias=False)
    )
  )
)

4. 調べ方②

これまでの記事をゆっくり書いていたら、npaka先生が上記を調べるためのコードを記事に書いてくださっていました。
発想は同じでLinear層をアーキテクチャから検索して表示しているようです。

import bitsandbytes as bnb

def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit  # (default:torch.nn.Linear,4bit:bnb.nn.Linear4bit,8bit:bnb.nn.Linear8bitLt)
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

find_all_linear_names(model)

これでstablelmを調べると確かにquery_key_value、dense、packed_input_proj、out_projとなりました。

['out_proj', 'dense', 'packed_input_proj', 'query_key_value']

目で調べるよりもこちらを使った方が確実で良いですね。

この記事が気に入ったらサポートをしてみませんか?