見出し画像

パラメータ効率が圧倒的に高いLLM学習手法ReFT(Representation Finetuning)を試してみた。

こんにちは!株式会社IZAI、エンジニアチームです。
今回は従来のLLMファインチューニング手法よりも10~50倍効率的とされているReFT(Representation Finetuning)を試してみます。
現論文はこちら


1. ReFTとは


ファインチューニング

ReFTとはRepresentation Finetuningの名前の通りファインチューニングにおける学習法です。今回紹介する手法は以下の図の赤枠の部分になります。

図1 LLMの学習の全体像とReFTの位置付け
https://dalab.jp/archives/journal/llm-finetuning-part1/

ちなみに、この図はこちらから引用させていただいたのですが、LLMの全体像が大変わかりやすくまとめられているので、LLMがどのように学習されているか勉強したい方はぜひ読んでみてください。

図を見て分かる通りァインチューニングでは、すでに学習されたモデルを、適用したいタスクに合わせて再学習させ、モデルのパラメータを微調整していきます。

このとき、モデルの全てのパラメータを更新していると効率が悪いため、一部のパラメータのみを更新する手法が使われています。それら手法をPEFT(Parameter-efficient fine-tuning)と呼びます。

ReFTの立場

先ほど説明したように、ファインチューニングの際にモデルの一部のパラメータのみを更新する手法をPEFTと呼び、ReFTはその一種です。
まずは従来のPEFTとして一般的に利用されている、LoRA(ロラ、ローラ)という手法を紹介します。

従来手法 「LoRA」

従来のPEFTの代表的な手法としてLoRA(Low-Rank Adaptation)というものがあります。原著論文では以下図を用いて説明されています。
この手法は低ランクの分解行列を使用することでパラメータを部分的に更新する手法です(以下)。この手法に関してはこちらが大変参考になります。


図2 LoRAの概念図
https://arxiv.org/abs/2106.09685
図3 LoRAの具体的な計算方法
https://ledge.ai/articles/LoRA

今回紹介する新手法 「ReFT」

LoRAが低ランク行列を用いてパラメータを部分的に更新していたのに対し、ReFTは中間層に介入するのが特徴です。
原著論文では以下のように紹介されています。


図4 ReFTの概念図
https://arxiv.org/abs/2404.03592

本手法はニューラルネットワークの内部表現の解釈性の研究から発想を得ており、そこからRepresentationという名前がついています。

2. とりあえずReFTを動かしてみよう


2-1. ReFTのインストール

ReFTのpythonライブラリがGitHubで公開されているのでインストールしましょう。

pip install git+https://github.com/stanfordnlp/pyreft.git

2-2. データセットの用意

今回はこちらのお嬢様会話データセットを使用しました。

git clone <https://github.com/matsuvr/OjousamaTalkScriptDataset.git>
df = pd.read_csv('./OjousamaTalkScriptDataset/ojousamatalkscript200.csv')
sample_df = df.sample(20)

2-3. 事前学習モデルの読み込み

今回はこちらのモデルを使用します

prompt_no_input_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
device='cpu'

model_id = "rinna/llama-3-youko-8b"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=device,
    trust_remote_code=True)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id, model_max_length=2048, 
    padding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token

2-4. データモジュールの作成

data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model, [prompt_no_input_template % row['prompt'] for _, row in sample_df.iterrows()], 
    [row['completion'] for _, row in sample_df.iterrows()])

2-5. ReFTの設定

reft_config = pyreft.ReftConfig(representations={
    "layer": 8, "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(embed_dim=model.config.hidden_size,
    low_rank_dimension=4)})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device(device)
reft_model.print_trainable_parameters()

2-6. 学習する

training_args = transformers.TrainingArguments(
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 8,
    warmup_steps = 100,
    num_train_epochs = 1,
    learning_rate = 5e-4,
    # bf16 = True,
    logging_steps = 1,
    optim = "paged_adamw_32bit",
    weight_decay = 0.0,
    lr_scheduler_type = "cosine",
    output_dir = "outputs",
    report_to=[]
)

trainer = pyreft.ReftTrainerForCausalLM(model=reft_model, tokenizer=tokenizer, args=training_args, **data_module)

_ = trainer.train()

学習したモデルをテスト

eval_df = df.iloc[[i for i in df.index if i not in sample_df.index]].sample(5)

results = []
for _, row in eval_df.iterrows():
    prompt = prompt_no_input_template % row["prompt"]
    prompt = tokenizer(prompt, return_tensors="pt").to(device)

    base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
    _, reft_response = reft_model.generate(
        prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
        intervene_on_prompt=True, max_new_tokens=512, do_sample=True,
        eos_token_id=tokenizer.eos_token_id, early_stopping=True
    )
    results.append({
        "output": tokenizer.decode(reft_response[0], skip_special_tokens=True),
        "completion": row["completion"]
    })

結果

ColabのT4 GPUで学習時間は18分でした。

おまけ - ReFTのコア概念についてClaude3と一緒に勉強してみた


原著論文の3章でReFTの詳しい説明がなされており、以下のキーワードが出てきます。

  1. Casual abstraction analysis

  2. Distributed interchange intervention

これらの論文を辿っても難解で説明が難しいのでClaude3に聞いてみました。
結論、Casual Abstraction Analysisはニューラルネットワークの内部表現を解釈するための手法で、Distributed Interchange Interventionはそのうちの具体的な手法の一つであることがわかります。
(要望を頂けたら解説記事を出します。)

1. Casual Abstraction Analysis

Claude3による説明
(私) :
ニューラルネットワークにおけるcasual abstraction analysisとはなんですか
(Claude3) :
Causal abstraction analysis は、ニューラルネットワークが学習した表現を理解するための手法の1つです。
具体的には、以下の2ステップからなります。

1. 入力データに対するニューラルネットワークの中間層の活性化パターンを可視化します。これにより、ネットワークがどのような表現を学習しているかを観察できます。

2. この手法を用いることで、ニューラルネットワークが学習した抽象的な概念表現を可視化し、理解を深めることができます。ただし、完全に中間表現の意味を解釈することは難しく、あくまで人間が判断する過程におけるひとつの手がかりとなります。

ニューラルネットワークの解釈可能性を高める研究分野では、causal abstraction analysis に加えて、概念活性化ベクトル、積層representational解散性分析などの様々な手法が提案されています。

NNの内部表現を解釈するための手法ということですね。

2. Distributed Interchange Intervention

(私) :
それに関連してdistributed interchange interventionについて説明してください。
(Claude3) :
distributed interchange interventionは、causal abstraction analysisの一種で、中間層の表現ベクトルを操作する際の具体的な手法です。
手順は以下の通りです。

1. モデルの中間層から、ある2つのユニットの活性化値ベクトル(u,v)を取り出す
2. uとvの値を入れ替えた新しいベクトル(u',v')を作る
3. 元のベクトル(u,v)を持つ入力データに対して、u'とv'を当てはめた入力データを生成する
4. 生成された入力データをモデルに通し、出力の変化を観察する

この操作を、モデル内の様々なユニットの組み合わせで行うことで、各中間表現ベクトルがモデルの出力にどの程度影響を与えているかを特定できます。
従来のcausal abstraction analysisでは、ユニットの値を0に設定したり、ランダムノイズを加えたりする操作が行われてきました。しかしdistributed interchange interventionでは、実際にモデル内で学習された表現ベクトルを用いて介入を行うため、より自然なかたちでユニットの意味を調べられるとされています。
ただし、この手法には以下のような課題もあります。
・操作が大規模になるとモデル出力が不自然になる可能性がある
・ユニット間の相互作用があると解釈が難しくなる
・現実世界での介入と厳密には等価ではない

このように一長一短がありますが、ニューラルネットの解釈可能性研究において有力な手法の一つとされています。

Distributed Interchange Interventionはそのうちの具体的な手法の一つということですね。

最後に


以上、参考になった方はいいねコメント頂けるととても嬉しいです。
株式会社IZAIでは、AI・ロボティクスを用いたサービス開発・研究をしております。興味をお持ちいただいた方はお気軽にお問い合わせください。それではまた!




参考


この記事が気に入ったらサポートをしてみませんか?