分子物性を予測する大規模言語モデルを強化学習で追加訓練する
こちらの研究の続きです。
はじめに
既存のケモ・マテリアルズインフォマティクスは基本的に言語モデルを使いません。
そのような特化型モデルは、予測過程が人間に分かりづらかったり、科学的にはありえないロジックで推論してしまうデメリットが生じえます。
そこで、大規模言語モデルに、「理由」も生成・学習させることによって、
科学的な思考をさせる研究を進めています。
これまでの課題
分子構造ー融点データベースをもとに、「なぜその分子が特定の融点を示すのか?」をGPT-4に考察させることで、「説明」を付与するデータセットを生成しました。
そのような「構造ー物性ー理由」データを2.5k件ほど学習させることで、大規模言語モデルの予測性能は上がりました。
予測性能はもっと上げられそうです。
改善点としては、以下のものを想定中です。
ファインチューニング・プロンプトチューニングの作り込み
まだ殆ど最適化していません。手伝ってくれる方も募集中です(23/1/6現在)。
colabのコードもあります。
足し算ができてない
llama2-13bを使っていますが、最後の足し算をミスったりしています
もっと賢いモデルを使ったほうが良さそうです
データセットの課題(今回の着眼点)
GPT-4に「理由」を自動生成させていますが、ロジックがやや強引な印象です。たとえば…
a. 個別のrecordに対して、overfitting的な思考が起きている可能性
例: メチル基の効果は+28℃
実測と予測の帳尻合わせをしている可能性があり、この理論が他の分子には通用しない可能性
b. そもそも、「理由」から正しいPredictionができるとは限らない
理由がいい加減な例あり。
自動生成されたデータセットの精製が必要
現在、個人的にはa.の「個別のrecordに対して、overfitting的な思考が起きている可能性」が気になっています。
この点について、色々と解決策を考えてきました。
たとえば、「素性の良い学習データ」を遺伝的アルゴリズムやBorutaなどで選ぼうとすると、膨大な計算コストになります。また、純粋に学習データを増やせば、個々の寄与がいい感じに平均化される可能性はありますが、うまくいくかは未知数です。
今回の試行: 強化学習
今朝、思いついたアイデアは、「強化学習」です。
大規模言語モデル業界では、人間が選んだ「好ましい回答」を生成するように強化学習するフレームワーク(RLHF: Reinforcement Learning from Human Feedback)が確立しています。
今回は幸いなことに、敢えて人間がフィードバックするまでもなく、予測値が真値に近いかどうかを報酬関数として設定可能です。
うまく回るかはわかりませんが、
trainデータを軽く1epochほど読ませる → 強化学習で「問題」から「理由」と「回答」を生成させる訓練をする
というスキームを組んでみることにします。
うまく回れば、少なくともtrain/validationデータに対してはconsistentな化学思考の体系を作れる気がします。
敢えて人間に例えると、教科書を読んだ後に、練習問題を自分で問いてみるイメージです。
大規模言語モデルの強化学習のくわしい理論などは、松尾研のサマースクールのスライド(第六回)などにくわしいです。
実装
コードはこちら。
https://github.com/KanHatakeyama/LLMChem/blob/20240106ReinforcementLearning/240106llama2_LR.ipynb
実装にあたっては、主に以下の記事を参考にしました。
まずは前回の記事と同様に、テキストを読ませたファインチューニングモデルを作っておきます。
要点
実装のポイントとして、TRLという言語モデル用の強化学習ライブラリを使います。
基本的な実装はとても簡単です。
始めにtrainerを定義します。
from trl import PPOTrainer, PPOConfig
#trainerの定義
ppo_trainer = PPOTrainer(
config=PPOConfig(batch_size=1),
model=model,
#ref_model=ref_model,
tokenizer=tokenizer,
)
その後、
query_tensors: プロンプト
response_tensors: 出力
rewards: 報酬
を定義した上で、stepを逐次実行していくだけです。
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
ただ、実際の実装では、何かと苦労※がありました。
最近のライブラリなので、GPT-4に聞いてもあまり助けて貰えなかったことが大きいです。
※主な苦労
初めてまともにTRLライブラリを使ったので、暗中模索だった。
普通にテキストをLoRAで学習後、一旦モデルをmergeした。
推論が遅い気がしたので。
量子化
モデルを16 bitや8bitに変換する必要があり、その際に諸々のトラブルが発生した。
8 bitではloraがmergeできない
overflowする
scalingに失敗
など。
1/8追記
16bitで無事に動かせました。
8 bitだと推論速度がかなり落ちるので、学習に時間がかかります。
普段使っているAutoPeftModelForCausalLMクラスと、TRLに必要なAutoModelForCausalLMWithValueHeadクラスで微妙に仕様が異なる。
forwardした後のlogitの出し方がやや異なる。
訓練ループ
以下の通り。
for i in tqdm(range(lr_epochs)):
#ランダムに問題を設定して値を予測させる
train_id=random.randint(n_test,len(dataset))
#同じ問題に対して、良い結果が来るまで何回か問題を解かせる
for i in range(3):
#報酬、応答、入力の取得
reward,response,input_id=reward_model(train_id)
rewards=[torch.tensor(float(reward))]
query_tensors = [torch.tensor(input_id).reshape(-1)]
response_tensors=[torch.tensor(tokenizer.encode(response)).reshape(-1)]
#モデル更新
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
#結果がよかったら次の問題に移る
if reward>0.5:
break
とりあえず適当に訓練データを選び、そこから物性予測をさせていきます。
正解に近かった場合(reward>0.5)は次の問題に移り、ダメだった場合は同じ問題を最大3回まで解かせることにしました。
学習の様子は以下の通り。
試行1
----
- Base structure: Heptane has a melting point of 22°C.
- Ketone group (C=O): +20°C, due to the polarity and the ability to form dipoles, which increases intermolecular forces.
- Methyl group (CH3): +5°C, due to an increase in molecular weight and van der Waals forces.
- Semicarbazone group (C(=N)NN=C): +40°C, due to the ability to form strong hydrogen bonds and the rigidity it introduces to the molecule.
The sum of these effects gives us a predicted melting point of 97°C.
##Prediction: 97.0
#Problem
actual: 136.0, predicted: 97.0, reward: 0.21999999999999997
試行2
----
- Base unit: Benzene has a melting point of 5.5 degrees Celsius.
- Ketone group: The presence of a ketone group can increase the melting point due to the polar carbonyl group, which can engage in dipole-dipole interactions. This could add approximately +20 degrees Celsius.
- Alkyl chain: The presence of an alkyl chain can increase the melting point due to increased molecular weight and van der Waals forces. However, the length of the alkyl chain in this case is relatively short, so the effect might be less pronounced. We can estimate an increase of +5 degrees Celsius.
- Semicarbazone group: This group is a heterocyclic compound that can engage in hydrogen bonding and dipole-dipole interactions. The presence of this group could add approximately +40 degrees Celsius.
Adding these effects together gives us a predicted melting point. However, the actual melting point is known to be 142.5 degrees Celsius, so adjustments need to be made to better align with this value.
##Prediction: 140.0
#Problem
actual: 136.0, predicted: 140.0, reward: 0.92
→ 予測と実測が近いので、次の問題に移る。
推論があまり早くないので、ループを回すのにわりと時間を要するのが課題です。
さらなる高速化法をご存知の方がいたら、教えてください。RWKV
学習結果
結果やいかに…?
はじめに、llama2-7bを2.5件のデータでファインチューニングしました。コードはこちら。
学習によって、MSEが9891から3777まで下がりました。
次に、先述のアルゴリズムで強化学習しました。訓練セットからランダムにデータを選んで、1.3kほど学習させました(一晩ほど)。
報酬の変化は以下の通り。
本当は、iterationを繰り返すほど、報酬がプラスに上がっていくはずだったんですが、全くそのような兆候は見られませんでした。
肝心の予測結果はこちら。
残念ながら、MSEが3777から4817まで悪化してしまいました。
訓練がうまくいっているようにも見えないので、もう少しシステムを作り込んでみる必要がありそうです。(アドバイス求む)
1/15追記: 数日ほど強化学習を回してみましたが、特に改善しませんでした。
llama2-7bには、問題設定が難しすぎたのかもしれません。
参考: 強化学習の難しさ
強化学習は適切なシステム設計が成否を握ります。
今回のタスクを例えるなら、1)補助輪付きで自転車の練習をさせた後に、2)補助輪を外して走らせるイメージ※です。うまく誘導してあげないと適切な成果を挙げられません。
※今回のタスク
1)ファインチューニング段階で教科書(質問+解説+回答)を読ませる。
2)自力で問題集を解かせる(質問から解説と回答を自力で作らせる)。
上手く行かない例として、
たとえば、1)補助輪の練習(ファインチューニング)が不十分だと、2)でまったく自走できません。
また、2)がスパルタすぎる(≒評価関数がシビアな)条件だと、成功体験を詰めず、モデルがグレます(下記)。
適切な報酬設定が重要そうです。