パラメータ効率が圧倒的に高いLLM学習手法ReFT(Representation Finetuning)を試してみた。
こんにちは!株式会社IZAI、エンジニアチームです。
今回は従来のLLMファインチューニング手法よりも10~50倍効率的とされているReFT(Representation Finetuning)を試してみます。
現論文はこちら
1. ReFTとは
ファインチューニング
ReFTとはRepresentation Finetuningの名前の通りファインチューニングにおける学習法です。今回紹介する手法は以下の図の赤枠の部分になります。
図を見て分かる通りァインチューニングでは、すでに学習されたモデルを、適用したいタスクに合わせて再学習させ、モデルのパラメータを微調整していきます。
このとき、モデルの全てのパラメータを更新していると効率が悪いため、一部のパラメータのみを更新する手法が使われています。それら手法をPEFT(Parameter-efficient fine-tuning)と呼びます。
ReFTの立場
先ほど説明したように、ファインチューニングの際にモデルの一部のパラメータのみを更新する手法をPEFTと呼び、ReFTはその一種です。
まずは従来のPEFTとして一般的に利用されている、LoRA(ロラ、ローラ)という手法を紹介します。
従来手法 「LoRA」
従来のPEFTの代表的な手法としてLoRA(Low-Rank Adaptation)というものがあります。原著論文では以下図を用いて説明されています。
この手法は低ランクの分解行列を使用することでパラメータを部分的に更新する手法です(以下)。この手法に関してはこちらが大変参考になります。
今回紹介する新手法 「ReFT」
LoRAが低ランク行列を用いてパラメータを部分的に更新していたのに対し、ReFTは中間層に介入するのが特徴です。
原著論文では以下のように紹介されています。
本手法はニューラルネットワークの内部表現の解釈性の研究から発想を得ており、そこからRepresentationという名前がついています。
2. とりあえずReFTを動かしてみよう
2-1. ReFTのインストール
ReFTのpythonライブラリがGitHubで公開されているのでインストールしましょう。
pip install git+https://github.com/stanfordnlp/pyreft.git
2-2. データセットの用意
今回はこちらのお嬢様会話データセットを使用しました。
git clone <https://github.com/matsuvr/OjousamaTalkScriptDataset.git>
df = pd.read_csv('./OjousamaTalkScriptDataset/ojousamatalkscript200.csv')
sample_df = df.sample(20)
2-3. 事前学習モデルの読み込み
今回はこちらのモデルを使用します
prompt_no_input_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
device='cpu'
model_id = "rinna/llama-3-youko-8b"
model = transformers.AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map=device,
trust_remote_code=True)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_id, model_max_length=2048,
padding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
2-4. データモジュールの作成
data_module = pyreft.make_last_position_supervised_data_module(
tokenizer, model, [prompt_no_input_template % row['prompt'] for _, row in sample_df.iterrows()],
[row['completion'] for _, row in sample_df.iterrows()])
2-5. ReFTの設定
reft_config = pyreft.ReftConfig(representations={
"layer": 8, "component": "block_output",
"low_rank_dimension": 4,
"intervention": pyreft.LoreftIntervention(embed_dim=model.config.hidden_size,
low_rank_dimension=4)})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device(device)
reft_model.print_trainable_parameters()
2-6. 学習する
training_args = transformers.TrainingArguments(
per_device_train_batch_size = 4,
gradient_accumulation_steps = 8,
warmup_steps = 100,
num_train_epochs = 1,
learning_rate = 5e-4,
# bf16 = True,
logging_steps = 1,
optim = "paged_adamw_32bit",
weight_decay = 0.0,
lr_scheduler_type = "cosine",
output_dir = "outputs",
report_to=[]
)
trainer = pyreft.ReftTrainerForCausalLM(model=reft_model, tokenizer=tokenizer, args=training_args, **data_module)
_ = trainer.train()
学習したモデルをテスト
eval_df = df.iloc[[i for i in df.index if i not in sample_df.index]].sample(5)
results = []
for _, row in eval_df.iterrows():
prompt = prompt_no_input_template % row["prompt"]
prompt = tokenizer(prompt, return_tensors="pt").to(device)
base_unit_location = prompt["input_ids"].shape[-1] - 1 # last position
_, reft_response = reft_model.generate(
prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
intervene_on_prompt=True, max_new_tokens=512, do_sample=True,
eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
results.append({
"output": tokenizer.decode(reft_response[0], skip_special_tokens=True),
"completion": row["completion"]
})
結果
ColabのT4 GPUで学習時間は18分でした。
おまけ - ReFTのコア概念についてClaude3と一緒に勉強してみた
原著論文の3章でReFTの詳しい説明がなされており、以下のキーワードが出てきます。
Casual abstraction analysis
Distributed interchange intervention
これらの論文を辿っても難解で説明が難しいのでClaude3に聞いてみました。
結論、Casual Abstraction Analysisはニューラルネットワークの内部表現を解釈するための手法で、Distributed Interchange Interventionはそのうちの具体的な手法の一つであることがわかります。
(要望を頂けたら解説記事を出します。)
1. Casual Abstraction Analysis
NNの内部表現を解釈するための手法ということですね。
2. Distributed Interchange Intervention
Distributed Interchange Interventionはそのうちの具体的な手法の一つということですね。
最後に
以上、参考になった方はいいねやコメント頂けるととても嬉しいです。
株式会社IZAIでは、AI・ロボティクスを用いたサービス開発・研究をしております。興味をお持ちいただいた方はお気軽にお問い合わせください。それではまた!
参考
この記事が気に入ったらサポートをしてみませんか?