Weights & Biases (wandb) を用いたLLMファインチューニング
こんにちは。Weights & Biases Japanの山本です。今回はWeights & Biases (wandb) を用いたOpen LLMのファインチューニングについてご紹介したいと思います。
はじめに
自然言語処理 (NLP) の世界は驚くほどの速度で進化しています。特に、OpenAI社によってリリースされたChatGPTの登場以来、それまで着実な進歩を重ねてきた大規模言語モデル(Large Language Model:LLM)に一気に注目が集まり、今や多くの開発者や研究者がこれらのモデルの可能性を追求しています。
大規模言語モデルを活用する方法は多様ですが、ここでは大きく分けて3つのアプローチにカテゴライズしたいと思います。
基盤モデルから自社で構築する
オープンな学習済み基盤モデルをファインチューニングする
基盤モデル(主にOpenAIのモデルをAPI経由で)そのまま使う
1.のアプローチはまさにOpenAIやGoogleが行なっていることですが、最も大きな競争優位性を獲得し得る一方で、非常に高い技術力と戦略レベルでの投資が必要となり、多くの企業にとっては必ずしも選択肢とならないかもしれません。そこで、本稿では2.のファインチューニングについてご紹介したいと思います。このアプローチでは自社内の技術情報や顧客情報を反映させることができ、かつ社内で完結し得るために検討している企業が多いかと思います。
(なお、3.についても極めて強力なOpenAIのGPT-3.5/GPT-4をコアに、様々な外部ツールや巧みなコンテクスト情報の活用をLangChainなどの最新の開発フレームワークで統合していくエキサイティングな分野であり、このいわゆるプロンプトエンジニアリング領域でもWeights & Biasesが大活躍するのですが、これについては次回以降に譲りたいと思います。)
今回の実験内容
さて、既存の基盤モデルをファインチューニングすることに決めたわけですが、現時点で様々なモデルが公開されています。ここではモデルウエイトが公開されているオープンなモデルの中から、さらに商用利用可能であり、日本語データでAbeja社が学習して公開している"abeja/gpt-neox-japanese-2.7b"をチューニングしてみたいと思います。
ファインチューニングの内容は実はなんでも良かったのですが、ここではKagglerのツイートを学習させて、いわゆる「月刊Kaggleは役に立たない」的な会話をさせてみたいと思います。競技プログラミングやその他の趣味でもそうだと思うのですが、機械学習コンペティションのKaggleでも時折Kaggleに批判的な発言を元にSNS上で激論が巻き起こるというのが季節の風物詩になっていますので、これを再現してみたいと思います。"abeja/gpt-neox-japanese-2.7b"は比較的日本語は流暢に話してくれますが、Kaggleの知識はあまり持っていないため、チューニングによって詳しくなってくれることを期待したいと思います。
さて、ここではファインチューニングはParameter Efficientな手法であり、近年広く用いられているLoRA (Low-Rank Adaptation) で行います。学習中のvalidation lossなどの各種metricやGPU使用率などをWeights & Biasesでモニタリングするために、Hugging FaceのTrainerでreport_to="wandb"と以下の例のように指定しておきます。たったこの一行でリアルタイムかつインタラクティブに学習中の多くの情報を一元的に管理モニタリングすることができます。
training_args = TrainingArguments(
... # ここでその他の設定をします
report_to="wandb",# お手軽一行、これだけ!
)
さらに、学習中のモデルを用いて一定ステップごとに上記の「月刊Kaggleは役に立たない」的な会話を実際に発生させることで、仕上がり具合をリアルタイムに確認していきたいと思います。これは以下のようにコールバック内でLangChainに学習中のモデルをセットして、KaggleアンチとKaggleファンからなるSequentialChainにすることで簡単に実現できます。LangChainは2023年5月現在でかなり新しいライブラリですが、Weights & Biasesは既にインテグレーションがありますのでWandbTracer()と書くだけでより複雑なシークエンスであっても各Chainの入出力や処理内容、処理時間、ステータスなどが全て自動的に記録され、LLM専用のインターフェース上で確認することができます。
class JapanesePromptLoggingCallback(TrainerCallback):
def __init__(self, prompt, log_interval=100):
super().__init__()
self.prompt = prompt
self.log_interval = log_interval
def on_train_begin(self, args, state, control, **kwargs):
self.log_prompt_response(state=state, **kwargs)
def on_step_end(self, args, state, control, **kwargs):
if state.global_step % self.log_interval == 0:
self.log_prompt_response(state=state, **kwargs)
def log_prompt_response(self, state=None, global_step=0, **kwargs):
model = kwargs["model"]
model.eval()
# ...
# ここで学習中のモデルをLangChainにセットして、簡単な会話をさせるためのSequencialChainを構築します
# ...
response = overall_chain(input_statement, callbacks=[WandbTracer()]) # That's all!
model.train()
特に生成系モデルでは学習時のタスクと実際に行いたい下流タスクが異なることが通常ですので、lossだけ見ていても実際の状況が分かりづらいですし、また学習中のモデルの成長具合をリアルタイムかつ直接見るのは単純に楽しいことでもあります。以上をまとめると以下の図のようになります。
実験結果
さて、学習結果を見てみましょう。
以下のようにWeights & Biasesのダッシュボード上で、学習時のlossやGPUユーセージに加えて、今回はコールバックで多数発生させた会話内容とそこで用いたLangChainの構造と内部の入出力や処理時間がリアルタイムにこの画面にアップデートされていきます。
一つずつ詳しく見てみましょう。まず、学習時のlossやGPUメモリの使用率などについては言うまでもないかと思います。実際には学習関連のmetricsもシステム関連のmetricsも様々な項目を自動的に取得してくれていますが、ここでは特に関心のある項目に絞って表示しています。
次にLLM関連の項目を順に見ていきましょう。トレーステーブルでは、LLM/チェーンへの呼び出しの入力と出力、およびそのチェーンの構成要素とエラーを全てのトレースについて視覚的に確認することができます。リアルタイムでコラボレーティブな分析とそのイテレーションを可能にしてくれます。また、エクスポート機能も備えています。
トレースタイムラインでは、個々のトレースについてステップとアクティビティを可視化してくれます。個々のエージェント、チェーン、ツール、モデルがどのように相互作用しているのかが簡単に把握でき、それぞれの実行にどれくらいの時間がかかったのかもわかります。
具体的にモデルの学習進行に従ってどのように会話内容が変わっていくのか見ていきましょう。まずは学習前ですが、どうも両者ともにKaggleをデータを集める場所と勘違いしているようです。もちろん、Kaggle Datasetsはデータセットを共有する有用なプラットフォームですが、最もKaggleを特徴づけているのはモデルの予測精度を競うコンペティションですから、これはKaggleをよく知らない状態と言えるでしょう。
しばらく学習を進めると、以下のようにややそれっぽいことを言うようになってきました。「単にfitしてpredictしてるだけの作業じゃん」と言うのは良くある批判の仕方で、今回の実験の意図にかなっていると言えます。また、それに対する批判もコミュニティとしての交流機能について触れており、良い指摘と言えるでしょう。
その後しばらく学習を続けていくと、会話内容が微妙になっていくとともに、ハッシュタグや短縮URLがやたら付加されるようになっていきました。Tweetを学習していることによる影響で、ちゃんとデータを見て前処理しろと言えばそれまでですが、lossだけ見ていてもわからない品質低下をリアルタイムに捉えられたとも言えるかもしれません。
まとめ
以上、学習済みモデルをLoRAでファインチューニングしながら、LangChainと組み合わせることで学習中のモデルによる会話品質をリアルタイムに捉えることができました。Weights & BiasesはLLMの活用においてもHugging FaceやLangChainとの連携、プロンプトに特化したUIを含むWeights & Biases Promptsによって強力にMLエンジニアをサポートしてくれますので是非ご活用ください。