torchtuneとWandBを使ったLlama3.1のファインチューニングと自動評価
Weights & Biases のNoteをフォローしてください
はじめに
この一年あまりの間にOpenLLMの性能は飛躍的に向上してきました。その中でも、Meta社のが開発した大規模言語モデル:LlamaシリーズはOpenLLMに非常に強力なベースラインを築き、かつその水準を継続的に向上させてきました。その最新版がLlama3 (Llama3.1)です。本記事では、このLlama3の特徴と可能性、そしてそのポテンシャルを引き出すためのファインチューニング技術について深く掘り下げていきます。さらに、torchtuneとWandBのインテグレーションによって、ファインチューニングプロセスをいかに効率化し、自動化できるかを探ります。
Llama3とは
Llama3.1は、世界最大級かつ最も高精度なOpenLLMの一つです。その特徴として、多言語対応、コーディング能力、ツール使用の機能を持っています。
基本的な情報は以下の通りです:
パラメータ数:8B/70B/405Bが存在(405Bはオープンソースで世界最大級, 2024年8月時点)
対応言語:英語、ドイツ語、フランス語、イタリア語、ポルトガル語、ヒンディー語、スペイン語、タイ語
コンテキスト長:128K(Llama2は4k、GPT-4oは128k)
ライセンス:Llama 3.1 Community License(商用利用可能、モデルの出力を使用して他のモデルを改善することも可能、条件あり)
Llama3.1の405Bモデルは、弊社が運営するNejumi LLMリーダーボード3を含む各種ベンチマークの結果において、GPT-4o、Claude 3.5 Sonnetとに近い水準の性能を有することが示されています。さらに、軽量なモデルでも他の主要なモデルの重量級のサイズに匹敵ないし超える性能を有しています。
Llamaのファインチューニングとは
Llamaのファインチューニング¹⁾には多くの事例があり、以下のようなことが可能です:
応答性能を向上させる
マルチモーダル化
ドメイン特化
これらは、プロンプトエンジニアリングでは実現できず、事前学習を実施できるリソースがない場合に有効な手段となります。
ユースケース①:日本語化
バージョンアップに伴って改善されているものの、Llamaは日本語対応が不十分なため、日本語が怪しい部分があります。そこで、Llamaを日本語データでファインチューニングすることで、日本語能力を向上させることができます。例えば、Llama 3 Swallowは日本語データでファインチューニングを実施し、日本語能力を大幅に向上させました。学習データによっては、日本語化だけでなく、話し方の指向性を変えることも可能です。
ユースケース②:ドメイン特化(金融・医療など)
医療ドメイン特化のLlama3として、Llama3-Preferred-MedSwallow-70Bがあります。このモデルは日本医師国家試験において、GPT-4を上回る成績を記録しました(2024/7/17時点で初めて日本医師国家試験を合格できるスコア)。QLoRAを使用することでA100 2台で学習を実施できたことも特筆すべき点です。
このように、ファインチューニングを活用することで、少ないリソースで特定のドメインに特化したLLMを開発することが可能となります。
torchtuneとは
torchtuneは、LLMのファインチューニングと実験のためのPyTorchライブラリです。Llamaと同じくMata社が開発元であることもあり、Llama3.1にも当初より対応しているなど親和性が高い点もポイントです。
torchtuneの大まかな構造
torchtuneは大まかには以下の3つの主要コンポーネントから構成されています:
torchtuneのレシピとは、LLMの学習や評価のための特定のタスクに焦点を当てたパイプラインで、上図のようにYAML設定、レシピスクリプト、レシピクラスから構成されています。これらのレシピの各コンポーネントは、torchtune開発チームによってよくテストされ、最適化されています。
torchtuneの特徴は、この階層構造によって、レシピ通りの内容をYAMLで手軽に実行できる点と、同時にネイティブPyTorch実装で思う存分レシピをカスタマイズできる点です。つまり、簡便性と必要に応じた高度なカスタマイズ性を両立したフレームワークといえます。Hugging Face/TRLのSFTTrainerも簡単にファインチューニングが可能ですが、カスタマイズは基本的にハイレベルAPIの範囲内での設定変更に限られる点が、torchtuneとの主な違いです。
torchtuneの基本的な使い方
1.インストール
git clone https://github.com/pytorch/torchtune.git
cd torchtune
pip install -e .
2.モデルのダウンロード
tune download meta-llama/Meta-Llama-3.1-8B \
--output-dir /tmp/Meta-Llama-3.1-8B/ \
--hf-token <ACCESS TOKEN>
3.レシピをお好みでカスタマイズ
4.チューニングの実行
tune run lora_finetune_single_device --config llama3_1_wandb_lora.yaml
torchtuneとwandbのインテグレーション
torchtuneはWeights & Biases(W&B)と連携することで、効率的なモデル開発と管理が可能になります。主な連携ポイントは以下の2つです:
WandBLoggerによるメトリクスのロギング
wandb.Artifactへのモデル保存
前者の連携は、YAMLファイルの設定により簡単に、後者の連携はレシピクラスのsave_checkpointメソッドのカスタマイズにより柔軟に実現できます。
カスタマイズ例
①YAMLレベルでのカスタマイズ
以下の例では、datasetにsourceの設定を追加することで、指示チューニング用のデータセットをデフォルトの設定 ('tatsu-lab/alpaca')からDolly日本語訳 ('kunishou/databricks-dolly-15k-ja') 変更しています。さらにmetric_loggerにWandBLoggerを追加することでメトリックをW&Bにロギングするように設定しています。ここでprojectはロギング先のWandBのプロジェクト名です。
# Dataset
dataset:
_component_: torchtune.datasets.alpaca_dataset
train_on_input: True
source: kunishou/databricks-dolly-15k-ja
seed: null
shuffle: True
# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.WandBLogger
project: llama3_lora
log_every_n_steps: 1
上記変更を加えたYAMLファイルを保存して、再度tune runすると、実験のメトリクスやGPUユーセージなどのシステムメトリクスがWandBがロギングされて、下図のように可視化されます。
上記のカスタマイズを加えたYAML設定も自動的にキャプチャされており、WandBのFilesタブから以下のように確認することができますので、高い再現性を担保することができます。
②レシピクラスの編集
既存のレシピの範囲で設定を変えるだけであれば、上記のようにYAMLで簡単に実行できますが、torchtuneの強みはカスタマイズ性の高さです。レシピクラスのsave_checkpointメソッドを編集して、LoRAアダプターのみを保存し、W&Bにアーティファクトとしてログするように以下のように変更してみましょう。さらにモデルアーティファクトをレジストリに登録します。
def save_checkpoint(self, epoch: int) -> None:
"""
アダプターとアダプターコンフィグを保存し、W&Bにアーティファクトとして
ログする
"""
# アダプターウェイトとコンフィグの取得
adapter_state_dict = {...} # アダプターウェイトの取得
adapter_config = {...} # アダプターコンフィグの作成
# チェックポイントの保存
checkpoint_file = Path(self._checkpointer._output_dir) / f"adapter_checkpoint_{epoch}.pt"
torch.save(ckpt_dict, checkpoint_file)
# W&Bにチェックポイントをログ
wandb_at = wandb.Artifact(
name=f"adapter_checkpoint_{epoch}",
type="model",
description="LoRA adapter checkpoint",
metadata={...} # メタデータの設定
)
wandb_at.add_file(str(checkpoint_file))
wandb.log_artifact(wandb_at)
wandb.link_artifact(wandb_at, 'meta-llama-Meta-Llama-3_1-8B-instruct-lora-adaptor')
WandBのアーティファクトとしてログすることで、そのモデルの学習から推論、用いたデータセットなどのリネージが有向グラフとして可視化され、一連の実験のトレーサビリティを担保します。
Registry Automationによるモデルの自動評価
先ほどのレシピクラスのカスタマイズ例では、モデルチェックポイントをアーティファクトとしてログするだけでなく、レジストリへの登録も同時におこないました。レジストリにモデルを登録することで、モデルのバージョン管理やライフサイクル管理のより高度な実践が可能になり、またW&B Automationsを介して各種自動化フローの起点として機能します。この機能は、レジストリに新しいモデルの登録、ないし指定のタグ(例: 'production', 'staging')が付与された時に後述のWandB LaunchのジョブないしWebhooksをキックすることができます。
先ほどのレシピクラスでは指定されたエポックごとにモデルチェックポイントをWandBアーティファクトとして保存し、レジストリに登録するようにカスタマイズしていました。そのため、任意のジョブをこのレジストリにAutomationsとして登録することでそれらを指定エポックごとに自動実行し、WandBにロギングする仕組みを容易に構築することができるのです。一例としては、MT Benchによる評価をジョブとして登録することでファインチューニングの進行中にマルチターン会話形式でのLLM-as-a-Judgeのスコアをリアルタイムにモニタリングし続けることができます。
このスキームはLLM評価以外にも様々なユースケースに適用な可能で、その他に以下のような応用が挙げられます。このように、レジストリはAutomationsを介して様々な機能との連携を行うハブのようにも機能するのです。
Sweeps on Launch
KubernetesクラスタやAWS SageMakerのようなサービスへの水平スケーリングにより、ハイパーパラメータ探索を大幅に短縮
Optunaインテグレーションを用いて、より高度な(サンプリング戦略やプルーニングなど)をスケーラブルに実行
Automations
モデルデプロイ(量子化、フォーマット変換、エッジデバイスへのデプロイ準備、マネージドエンドポイント等へのデプロイ)
モデル評価
データセットの前処理
まとめ
torchtuneはOpenLLMを牽引するモデルの筆頭であるLlamaシリーズとも相性が良く、簡便性と必要に応じた高度なカスタマイズ性を両立したフレームワークです。その他のLLM開発エコシステムとのインテグレーションも進んでおり、Hugging FaceやWandBと組み合わせて使用することができます。特にWandBとのインテグレーションはtorchtune自体の設計の柔軟さによって高いポテンシャルを有しており、ファインチューニング中にダウンストリームタスクでの自動評価などへの拡張も容易に実現可能です。LLMのファインチューニングに興味のある方はぜひ試してみてください。