見出し画像

速報:話題の 1ビットLLMとは何か?

2024-02-27にarXiv公開され,昨日(2024-02-28)あたりから日本のAI・LLM界隈でも大きな話題になっている、マイクロソフトの研究チームが発表した 1ビットLLMであるが、これは、かつてB-DCGAN(https://link.springer.com/chapter/10.1007/978-3-030-36708-4_5; arXiv:https://arxiv.org/abs/1803.10930 )という「1ビットGANのFPGA実装」を研究していた私としては非常に興味をそそられる内容なので、論文を読んでみた。今回は速報として、その内容のポイントを概説したい。

論文情報
Ma, S. et al. (2024) ‘The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits’, arXiv [cs.CL]. Available at: http://arxiv.org/abs/2402.17764.

そもももこれは何なのか?

  • LLMのモデルで中心的役割を果しているトランスフォーマー(Transformer)の、入力および出力の埋め込み変換行列 $${W}$$ の数値の精度を大幅に削減(1ビット!化)して小型化したもの。

  • 普通のトランスフォーマーでは、この部分の数値には16ビットや32ビットの浮動小数点が使われているが、1ビットモデルではそこを1ビット整数で済ませてしまう。

  • このように、計算精度を大幅に削減しても(驚くべきことに!)言語モデルとしての性能はさほど悪化しない。(いくつかの条件を整える必要がありますが、非現実的なものではありません。詳細は原論文参照)

  • 精度の削減には BNN(https://arxiv.org/abs/1602.02830)のアイデアがほぼそのまま使われている。

  • 正確には 1-bit すなわち2値化ではなく3値化が使われており、その情報量は$${ \log_2^3\approx1.58 }$$なので、これが論文のタイトルになっている。

何が嬉しいのか?

  • トランスフォーマーはLLMなどの生成AIにおいて重要であると同時に、それらのAIにおいて最もコストがかかる部分でもある。これを1ビット化すれば、推論用途ではメモリ容量で十数分の1に小型化できると同時に、複雑で消費電力の多い浮動小数点演算回路=GPUを大幅に削減できる

  • すなわち、巨大なクラウド基盤などに頼らなくても学習済みの生成AIが動かせることになるので、スマホや各種組み込み機器でChatGPTレベルのAIを単独で動作できる可能性が出てくる。つまり、ありとあらゆる工業製品(日本なら自動車、家電製品、自動販売機?など)にLLMが組み込まれる=「あらゆる機械がしゃべりだす」時代が、ぐっと近づいたのである。

  • (一方、推論用のGPUへのニーズが相対的には減るので、NVIDIA社のようなGPUメーカーのビジネスにとっては逆風となりかねない。)

仕組み

この 1.58 ビットトランスフォーマーのもとになったモデルが BitNet(https://arxiv.org/abs/2310.11453) である。そしてさらに BitNet のもとになったアイデアが BNN である。(※ ちなみにBNN論文の著者には、あの Y. Bengioさんも名を連ねている。)

BNNでは1ビット化の対象は畳み込みNN(CNN)であったが、BitNetでは同じ数学的原理をTransformerに適用したものだ。

というわけで、BNNの仕組みを理解すれば、BitNetも1.58ビットモデルも理解できる。というわけで、BNNの仕組みを見てみよう。

学習中の重みは高精度のまま、順伝搬は1ビット化、逆伝搬はSTEで

BNNにおいて1ビット化を目指す対象は、学習の結果最終的に得られるモデル、つまり学習済みモデルである。一方、学習途上のモデルは深層学習のキモであるバックプロパゲーション(誤差逆伝搬)を使いたいので、そこは1ビット化せず従来どおり高ビットの浮動小数点を使うというのがポイントである。

モデル全体を1ビット化する必要はなく、ある層は1ビット、別の層は多ビットという形で層別にセッティングを変えることもできる。

具体的には1ビット化したい層について、学習中の重みは従来通り浮動小数点値で持っておき、順伝搬の計算のときはその層の出力をバイナリ変換(+1 or -1)して次の層に渡す、という細工をしてやる。(学習が収束した後、この浮動小数点の重みを同じ変換則で1ビット化して取り出せば、それが1ビット版のランタイムモデルとなる。)

一方の逆伝搬では、勾配の計算が必要だが、ここにも一工夫がいる。計算グラフにあるバイナリ変換関数は不連続関数なのでそのままでは微分できない。そこでSTE(Straight-Through Estimator)という手法で近似的に微分を実現している。

STEでは、下流側から伝搬されてきた勾配値 $${r}$$ の値が $${-1 \le r \le +1 }$$ならば微分値を r とし、$${r < -1}$$ ならば微分値は $${-1}$$, $${+1<r}$$ ならば微分値を$${+1}$$とする、というものである。文章で書くと分かりにくいが、この挙動はつまり hardtanh関数と同じである。hardtanhの入力を横軸、出力を縦軸としてグラフに書くと以下の図のようになる:

図 hardtanh の応答グラフ

別の説明としては、これはバイナリ変換関数を恒等写像とみなしているとも言える。(※ちなみに、STEを考案したのは、あのG. Hinton氏である。初出は2012年のCoursera のビデオ講義らしい。)

1.58ビットモデルの仕組み

ここまでのBNNの説明が理解できていれば、1.58ビットトランスフォーマーモデルを理解するのは簡単である。

前述のとおりBNNでは、バイナリ化する対象はDCNNの層であって、その重み行列を伝搬するときに2値化するという仕組みであった。1.58ビットトランスフォーマーでは、それと同じことをトランスフォーマーの入出力部分にある線形写像の重み行列に適用する。

まず順伝搬。BNNあるいはBitNetでは、+1/-1の2値にしていたところを、1.58ビットモデルでは、下式のように RoundClip() という関数で +1/0/-1 の3値にする(式は論文より引用):

ここで、$${W}$$ はトランスフォーマーの埋め込み重み行列である。$${\gamma}$$はその行列の重みの平均値である。 計算は見ての通り特に難しいことはない。

逆伝搬のほうもBNNと全く同じである。STEを使ってそのままでは微分できないRoundClip関数を擬似的に微分する。

仕組みの要点は、たったこれだけである。このある意味単純な仕組みで画期的な効果が生み出せることに驚かされる。

まとめ

LLM分野では、モデルの小型化や学習の高速化手法のトレンドは直近では蒸留やLoRA系の低ランク近似が主流であったが、この1.58ビットトランスフォーマーは大きなゲームチェンジャーになり、低ビットモデルが大流行するかもしれない。

AIの社会実装という意味でも、この成果は大きなブレークスルーになりそうである。家電、自動車、産業機器などなど・・・メーカー各社は一斉に製品戦略の書き換えを始めるだろう。2020年代後半には街中に生成AIを搭載した機器たちが溢れていることになるかもしれない。

<以下は宣伝です>

当社オープンストリームでは、お客様のビジネスでご活用いただくべく、LLMや生成AIを使った研究や技術実証やアプリの試作、ハッカソンなどの活動を実施して知見を蓄えています。ご興味のある方はお気軽に私または当社までお問い合わせください。よろしくお願いいたします。

以上、お読みいただきありがとうございました。


いいなと思ったら応援しよう!