7Bモデルをドメイン特化させる学習をLoRAとGaLoreで試し比較する
コーディングや医療など特定のドメインに特化させることで高い性能を発揮するモデルはドメイン特化モデルと呼ばれ、ベースモデルにドメインのコーパスを追加で学習させることで作成されます。
この図はベースモデルからドメイン特化のモデルを得るまでのフローです。本記事では赤枠で囲っている③のドメイン特化学習を試した内容を紹介します。
![](https://assets.st-note.com/img/1712838547805-7L028FGKLm.png?width=1200)
GaLoreについて
2024/3に新しいLLMのファインチューニング手法GaLoreが公開されました。論文によれば、VRAM24Gのコンシューマ向けGPUで7Bモデルの事前トレーニング(図の①や②)もできる手法です。
使い方
Transformersにはv4.39.0から組み込まれており、今までのトレーニングコードを変えずに、TrainingArgumentsにoptimとoptim_target_modulesを指定するだけで利用できます。
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"]
)
https://github.com/huggingface/transformers/releases/tag/v4.39.0
GaLoreのハイパーパラメータはデフォルトでは以下の値になっています。
galore_optim_kwargs = {
"rank": int(optim_args.pop("rank", 128)),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
"scale": float(optim_args.pop("scale", 0.25)),
"proj_type": optim_args.pop("proj_type", "std"),
}
学習データセット
ドメイン特化の学習をするために、特定のドメインのコーパスをある程度のまとまった量用意します。
今回はWebから収集されたR-18小説コーパスを利用しました。
さらに、コーパスのフィルタリングとしてModerationモデルoshizo/japanese-sexual-moderation-v2によってページ単位でスコアリングを行い、例えば長編小説の導入部などスコアの低い部分を除外しました。
フィルタリングによりR-18記述の多いページのみを学習対象とすることでデータ量をかなりの量減らしつつ、質を向上(?)させます。
結果として約6億文字程度(2文字1トークン計算で0.3Bトークン)の学習データを用意しました。
これは事前学習と比べるとはるかに少ない量ですが、スクラッチの事前学習ではなく日本語ベースモデルへの追加学習のため、この程度のデータ量でも学習効果が出ることを期待しています。
評価方法とベースモデルの選択
まずは、継続学習させるベースモデルを選ために、評価方法を決めます。
このドメインの評価データは知る限りでは存在しないため、自動評価が可能な評価データセットを作成しました。
ドメインの専門用語を答えさせるQAデータセット(50件)
ドメインの専門用語(1と同じ)を含む文を用意し、候補の単語と比較してその用語が適切と判断させるデータセット(50件)
専門用語については、Wikipediaでピンク色の警告マークがあるページのタイトルやイラストサイトのタグなどを参考に、人手の作業によって50の専門用語を選定しました。
データをそのまま書くとNoteにBANされそうなので一般的な言葉を例にデータセットの内容を紹介します。
「ドメインの専門用語を答えさせるQAデータセット(50件)」の形式はこのようなデータセットです。
質問:鉄道と道路が平面で交差している場所を何という?
回答:踏切(踏切り、踏み切りでも正解)
このようなQAを3-shotで答えさせ、完全一致の正解率を測ります。
今回は指示応答性能ではなくモデルに含まれる知識を測りたいため、指示チューニングモデルではなくベースモデルを使用し、few-shotにすることでベースモデルでも単語で回答が出力されることを確認しつつ評価しました。
「ドメインの専門用語を含む文を用意し、候補の単語と比較してその用語が適切と判断させるデータセット(50件)」は、専門用語を正しく説明している文と、類似の候補単語を5~10個用意したデータセットです。
説明文:{word}とは、鉄道と道路が平面で交差している場所のことである。
候補:踏切, 電車, 線路, 踏切, 道路, 停留所, 交差点
この文の{word}の部分に「踏切」を入れたのが正しい文で、他の候補を入れたものを不正解とします。
ベースモデルでそれぞれの候補を入れた文のlog生成確率をスコアとして候補の単語をランキングし、1/正しい文の順位を使って評価しました。
(1位なら1.0、2位なら0.5、3位なら0.333…など)
QAデータの作成作業は、ドメイン的にGPT-4が手伝ってくれないため手作業となり大変なのですが、ちょうど作業の途中でリリースされたcommand-r-plusを活用することでかなり省力化できました。
ベースモデルの比較
まずは、学習のベースとなるモデルを選定します。
日本語の主要なベースモデルを上記のデータセットで比較し、スコアの良いものを選択することにしました。
候補のモデルは、最近はモデルをマージさせることで様々な能力を強化することができるようになっているため、派生モデルの数が多くマージに参加させやすいLlama-2もしくはMistralの派生モデルのみを対象としました。
QAの完全一致正解率で降順ソートした結果です。
![](https://assets.st-note.com/img/1712843138111-x9Qoch7M4f.png?width=1200)
calm2-7bが良い結果となったため、ベースモデルとしてcyberagent/calm2-7bを選択し、さらに知識を向上させられるかを検証することにしました。
学習
環境は2xRTX3090の環境を使用し、学習率は1e-4、batchsizeは32(勾配累積込み)、max_lengthは1024です。
以下の4つの設定の初動(それぞれ8h程度の時点)を比較しました。
GaLore rank=128(デフォルト設定) … 紫
GaLore rank=1024 … 茶
LoRA rank=8 .. 緑
LoRA rank=256 … 白
train lossのグラフをみると、GaLoreのデフォルト設定(紫)が最も悪く、LoRAのr=256(白)が他と比べてかなりよさそうに見えます。
LoRAのrankは大きく学習結果に影響しないイメージを持っていましたが、今回のデータではrankを大きくすることでlossの下がり方がかなり変わるように見えます。
![](https://assets.st-note.com/img/1712842478806-qLd81wwt28.png?width=1200)
1epoch時点では、GaLore rank=128(紫)のlossは2.7程度だったのに対し、LoRA rank=256(白)は2.55と、大きく差が付きました。
![](https://assets.st-note.com/img/1712842700459-5qktnpkbCC.png?width=1200)
GaLoreは事前学習では有効なのかもしれませんが、ドメイン特化させる学習の場合はLoRAでも十分なのかもしれません。
VRAM消費はGaLoreで36~40GB(勾配チェックポイント=True)、LoRA rank=256で34GB程度(勾配チェックポイント=Falseにしたため多め)程度でした。
学習結果の評価
学習した結果のモデルに対し、同じように評価データセットでの評価を実施しました。
![](https://assets.st-note.com/img/1712843166542-Bc0iQDHpFh.png?width=1200)
ドメインの専門用語を答えさせるQAデータ(QA_exact)は+0.08、ドメインの専門用語を含む文を用意し、候補の単語と比較してその用語が適切と判断させるデータ(logp_mrr)は+0.09と、どちらもベースモデルと比べて大きく向上しました。
今回学習させたデータセットは小説形式のデータのみで、QA形式のデータやWikipediaのような説明文のデータは含んでいませんでしたが、そのようなデータからも専門用語の知識が得られているようで、なかなか面白いと思います。
以下はモデルごとに、どの質問番号を正解したかを可視化したものです。
学習によってできなくなってしまった問題もあるようです。
![](https://assets.st-note.com/img/1712927635254-lvnG2T8ScX.png?width=1200)
今後
3月にはmergekitを使ったevolutionary-model-mergeの論文や、Optunaを使ってmoeに使うtask vectorの加算割合をチューニングして作られたモデルLightChatAssistant-2x7B-optimized-experimentalなど、自動評価できる指標に基づいて良いマージモデルが作れることが実証されてきています。
私が今回行っている学習も、単独のモデルとして使うのではなくtask vectorとして他のモデルに加えたり、マージに利用する用途を想定しています。
マージの自動最適化のためには、低コストで自動評価できる評価用データセットが必要です。今回作ったドメイン知識を確認する評価データセット以外に、自分がやりたいタスクの評価データセットを整備してマージの自動最適化を試そうと考えています。