【論文要約】Megalodon: Efficient LLM Pretraining and Inference with Unlimited Context Length【自分用メモ】

イントロダクション

今回は『Megalodon: Efficient LLM Pretraining and Inference with Unlimited Context Length』という以下の論文を要約する。論文のpdfをClaude 3 Opusに渡して要約させた。


研究の目的と背景

<purpose>

本研究の目的は、大規模言語モデル(LLM)の事前学習と推論の効率を改善し、無制限の文脈長に対応できる新しいニューラルアーキテクチャMEGALODONを開発することである。Transformerモデルは、二次計算量と文脈長の制約により、長い系列への拡張が困難であり、線形アテンションや状態空間モデルなどの準二次解法は経験的にTransformerの事前学習効率とダウンストリームタスクの精度に及ばない。MEGALODONは、MEGAアーキテクチャを継承しつつ、複素指数移動平均(CEMA)、タイムステップ正規化層、正規化アテンション機構、2ホップ残差接続を備えたpre-norm構成などの技術的改良により、7Bパラメータと2兆トークンの事前学習規模において、Transformerを上回る効率と性能を達成することを目指している。本研究の意義は、多ターン会話、長文書理解、動画生成など、長い系列データを効率的に処理し、内部の長距離ダイナミクスを理解し、一貫性のある出力を生成する必要のある実世界のアプリケーションにおいて、LLMの実用性を大幅に向上させることにある。MEGALODONの新規性は、チャンク単位のアテンションによる線形計算量と無制限の文脈長対応、およびCEMAやタイムステップ正規化層などの独自の技術的改良点にある。

<background>

大規模言語モデル(LLM)の研究分野では、Transformerアーキテクチャが卓越した性能を示す一方で、系列長に対する二次計算量とアテンション範囲の制約から長距離のモデリングが課題となっている。例えば、7Bパラメータモデルで100万トークンの1ステップ学習を256個の4Kトークン系列で分散実行するのに比べ、単一の100万トークン系列で実行すると100倍以上も遅くなる。この課題に対し、線形アテンションや構造化状態空間モデルなどの準二次計算量手法が提案されてきたが、Transformerに比べて実用的な性能が不十分であった(Tay et al., 2022; Gu and Dao, 2023)。MEGAアーキテクチャ(Ma et al., 2023)は、ゲートアテンションと指数移動平均(EMA)を組み合わせることで、チャンク単位のアテンションによる線形計算量を実現したが、チャンク間の文脈損失や異なるタスク・データ型に対するアーキテクチャの不統一などの問題があり、大規模事前学習への拡張可能性は未検証であった。本研究のMEGALODONは、MEGAを基に複素EMA(CEMA)、タイムステップ正規化層、正規化アテンション、2ホップ残差pre-norm構成などの独自改良を加えることで、7Bパラメータ・2兆トークン規模のLLM事前学習において、データ効率と計算効率の両面でTransformerアーキテクチャを上回ることを実証する。

<note>専門用語の説明:

  • 大規模言語モデル(Large Language Model; LLM): 数十億から数千億のパラメータを持つ自然言語処理のための深層学習モデル。膨大なテキストデータによる事前学習と比較的少量のタスク固有データによるファインチューニングにより汎用的な言語理解・生成能力を獲得する。

  • 二次計算量: 計算時間やメモリ使用量が入力系列長のn個のトークンに対してO(n^2)のオーダーで増大すること。アテンション機構がn×nの行列計算を含むため、Transformerは系列長に対して二次計算量を要する。

  • チャンクアテンション: 系列を固定長のチャンクに分割し、各チャンク内でのみアテンションを計算する手法。チャンク長をcとすると、計算量はO(nc)に削減される。

MEGAアーキテクチャについて

MEGAアーキテクチャは、Transformerの注意機構における弱い帰納バイアスと二次計算量という課題を解決するために提案された手法です。主要な改良点は以下の通りです。

  1. 指数移動平均(EMA)の導入
    MEGAは時系列方向に指数関数的に減衰する局所的依存関係を取り入れるために、多次元減衰EMAを注意機構に組み込んでいます。EMAにより、位置に依存しない注意機構に位置を考慮した局所的な依存関係のバイアスを効果的に組み込むことができます。

  2. ゲート付き注意機構
    MEGAは、EMAの出力から共有表現を計算し、リセットゲートとアップデートゲートを備えたゲート付き注意機構を採用しています。単一ヘッドのゲート付き注意機構が、理論的にマルチヘッド注意機構と同等の表現力を持つことが示されています。

  3. チャンクアテンション
    MEGAは入力系列を固定長のチャンクに分割し、各チャンク内でEMAとゲート付きアテンションを適用することで、計算量を系列長に対して線形に抑えつつ、チャンク間の文脈を考慮できるMEGA-chunkという亜種を提案しています。これにより、文脈長に制限のない効率的な学習が可能になります。

  4. ラプラス関数に基づく注意機構
    頑健性と安定性を向上させるため、MEGAではソフトマックス関数の代わりにラプラス関数に基づく注意機構を提案しています。二乗ReLU関数よりも安定した学習が可能であることが実験的に示されています。

  5. 位置エンコーディング
    MEGAでは、回転式位置エンコーディング(RoPE)を用いることで、学習時よりも長い文脈での推論を可能にしています。

以上のように、MEGAはTransformerの課題を、EMA、ゲート機構、チャンクアテンション、ラプラス関数、位置エンコーディングなどの改良により解決し、効率的かつ頑健な長距離依存関係のモデル化を実現しています。長文書理解や会話モデルなど、大規模言語モデルの実用性向上に貢献することが期待されます。

使用した手法の概要

<methods>

MEGALODONは、MEGAアーキテクチャを基盤とし、複数の新しい技術要素を導入することで、大規模言語モデルの効率的な事前学習と推論を実現している。

まず、MEGAの指数移動平均(EMA)の多次元減衰化を複素数領域に拡張した複素EMA(CEMA)を提案している。CEMAは、実数のパラメータに加えて、複素数のηと基底角ωを導入し、次元ごとに異なる複素数での指数移動平均を行う。これにより、カーネル重みの減衰構造を保持しつつ、EMAの表現力を高めている。
次に、オートリグレッシブな系列モデリングにおけるタイムステップ方向の内部共変量シフトを低減するタイムステップ正規化層を提案している。これは、Group Normalizationをオートリグレッシブ設定に拡張したもので、特徴次元のグループ内と系列方向の2次元で累積平均・分散を計算し、正規化を行う。現代のハードウェア(GPU)上で効率的かつ安定的に累積統計量を計算する実装も提供している。
さらに、MEGAのゲートアテンションに特化した正規化アテンション機構を導入している。アテンション行列をソフトマックス関数の手前で正規化することで、飽和や不安定性の問題を緩和し、大規模学習の安定性を高めている。正規化アテンションにより、異なるタスクやデータ型に対して統一的なアーキテクチャを適用できるようになった。
加えて、Pre-Norm構成の不安定性を改善する2ホップ残差Pre-Norm構成を提案している。これは、Pre-Normの残差接続を単純に付け替えることで実現しており、アテンション層の前にタイムステップ正規化層、FFN層の前に層正規化を適用している。ゲート付き残差接続を削除しても安定的に学習できるようになった。
最後に、チャンクアテンションの並列化により、データ・テンソル・パイプラインの3次元に直交する新たな系列方向の並列化を実現している。各チャンク並列グループ間の通信はCEMAの最終隠れ状態とタイムステップ正規化の累積統計量のみであり、非同期通信により他の計算と重複させることで、オーバーヘッドを最小化している。

以上の手法を7Bパラメータ・2兆トークン規模のMEGALODONモデルに適用し、事前学習データとハイパーパラメータをLLAMA2と揃えて制御された比較実験を行うことで、データ効率と計算効率の優位性を実証している。

<comparison>

MEGALODONは、EMAを基盤としてゲートアテンションを導入したMEGAアーキテクチャを継承しつつ、CEMAによるEMAの表現力向上、タイムステップ正規化層による系列方向の統計量を考慮した正規化、正規化アテンションによる大規模学習の安定化、2ホップ残差Pre-Norm構成によるアーキテクチャの統一化と単純化、チャンクアテンションの並列化による計算効率の改善など、複数の新規手法を組み合わせることで、MEGAの問題点を解決し、Transformerを凌駕する性能を達成している。
特に、CEMAとタイムステップ正規化層の組み合わせにより、チャンク間の文脈損失を大幅に低減し、チャンクアテンションとフルアテンションの性能差を縮めている点が特徴的である。
また、Pre-Norm構成の残差接続の付け替えという単純な工夫により、ゲート機構などのパラメータを増やすことなく、深いモデルの学習を安定化できている。
さらに、4次元の並列化手法により、アテンション範囲を制限しつつ大規模な事前学習を高速に実行できるようになっている。

<note> 専門用語の説明:

  • 指数移動平均(Exponential Moving Average; EMA): 時系列データの平滑化手法の一種で、過去の値に指数関数的に減衰する重みを掛けて平均を取る。パラメータは減衰率を制御する。

  • 内部共変量シフト(Internal Covariate Shift): 深層ニューラルネットワークの学習において、各層への入力分布が学習中に変化する現象。勾配の不安定化や消失・爆発を招く。

  • Group Normalization: チャネル方向をグループに分割し、グループ内で正規化する手法。バッチサイズに依存せず、チャネル数が多い場合に有効。

  • 正規化アテンション(Normalized Attention): アテンション行列を正規化することで、飽和や不安定性を緩和する手法。正規化手法としてはコサイン類似度やQK正規化などがある。

  • Pre-Norm: 各層でアテンション・FFNの前に正規化を適用する構成。オリジナルのTransformerはPost-Norm(演算後に正規化)を採用していたが、学習の安定性と収束性からPre-Normが主流となった。

論文内の数式と手法

<equations>

MEGALODONで提案されている全ての数式とその役割は以下の通りである。

オリジナルのEMAの更新則を定義する数式(1):

$$
u_t^{(j)} = \beta_j x_{t,j} \\
h_t^{(j)} = \alpha_j \odot u_t^{(j)} + (1 - \alpha_j \odot \delta_j) \odot h_{t-1}^{(j)} \\
y_{t,j} = \eta_j^T h_t^{(j)}
$$

この数式は、EMAの基本的な動作を表している。入力$${x_{t,j}}$$に拡張行列$${\beta_j}$$を掛けて$${h}$$次元に拡張し、前の隠れ状態$${h_{t-1}^{(j)}
$$と減衰率$${\alpha_j, \delta_j}$$で加重平均することで、新しい隠れ状態$${h_t^{(j)}}$$を更新する。$${h_t^{(j)}}$$に射影行列$${\eta_j}$$を掛けて、出力$${y_{t,j}}$$を得る。

複素EMA(CEMA)の更新則を定義する数式(2)と(3):

$$
h_t^{(j)} = \alpha_j (\cos \theta_j + i \sin \theta_j) \odot u_t^{(j)} + (1 - \alpha_j \odot \delta_j)(\cos \theta_j + i \sin \theta_j) \odot h_{t-1}^{(j)} \\
y_{t,j} = \mathrm{Re}(n^T_jh^{(j)}_t) \\
\theta_{j,k} = \frac{2\pi k}{h}\omega_j, \forall k \in {1, 2, ..., h}
$$

これらの数式は、CEMAの隠れ状態を複素数の指数移動平均で更新する役割を持つ。式(2)は、入力$${u_t^{(j)}}$$と前の隠れ状態$${h_{t-1}^{(j)}}$$に複素数の減衰率$${\alpha_j (\cos \theta_j + i \sin \theta_j)}$$と$${(1 - \alpha_j \odot \delta_j)(\cos \theta_j + i \sin \theta_j)}$$を掛けて加重平均する。式(3)は、複素数の位相$${\theta_j}$$を基底角$${\omega_j}$$から生成する。

タイムステップ正規化の累積統計量を定義する数式(4):

$$
\mu_t = \frac{1}{t * d_g} \sum_{i=1}^t \sum_{j=1}^{d_g} x_{i,j},\\
\sigma_t^2 = \frac{1}{t * d_g} \sum_{i=1}^t \sum_{j=1}^{d_g} (x_{i,j} - \mu_t)^2
$$

この数式は、特徴次元をグループ化し、各グループ内で系列方向に累積平均$${\mu_t}$$と累積分散$${\sigma_t^2}$$を計算する。$${d_g}$$はグループ内の次元数、$${t}$$は系列の長さを表す。

正規化アテンションの各処理を定義する数式(5)から(9):

$$
X' = \mathrm{CEMA}(X) \in \mathbb{R}^{n \times d} \\
Z = X'W_z + b_z, Z' = \frac{Z}{||Z||} \in \mathbb{R}^{n \times z} \\
Q = \kappa_q \odot Z' + \mu_q \in \mathbb{R}^{n \times z} \\
K = \kappa_k \odot Z' + \mu_k \in \mathbb{R}^{n \times z} \\
O = f_\mathrm{softmax}(QK^T)V \in \mathbb{R}^{n \times v}
$$

これらの数式は、CEMAの出力$${X'}$$を正規化し、正規化表現$${Z'}$$からクエリ$${Q}$$とキー$${K}$$を導出し、アテンション行列$${QK^T}$$をソフトマックス関数$${f_\mathrm{softmax}}$$で正規化し、値$${V}$$との加重和を出力$${O}$$とする一連の処理を表現している。

通常のPre-Normの残差接続を定義する数式(10):

$$ \hat{Y} = \mathrm{Attention}(\mathrm{Norm}(X)) + X \\
Y = \mathrm{FFN}(\mathrm{Norm}(\hat{Y})) + \hat{Y} \\
= \mathrm{FFN}(\mathrm{Norm}(\hat{Y})) + \mathrm{Attention}(\mathrm{Norm}(X)) + X $$

この数式は、Pre-Normの標準的な実装を表している。アテンション層とFFN層の前にそれぞれ正規化を適用し、各層の出力にSkip Connectionで入力を加算する。式(10)の下段が示すように、最終出力$${Y}$$は、アテンション層とFFN層の出力、入力$${X}$$の3つの項の和になる。

Pre-Normの2ホップ残差接続を定義する数式(11):

$$
\hat{Y} = \mathrm{Attention}(\mathrm{Norm}(X)) + X \\
Y = \mathrm{FFN}(\mathrm{Norm}(\hat{Y})) + X
$$

この数式は、アテンション層の前のPre-Norm出力と入力$${X}$$の和$${\hat{Y}}$$をFFN層の入力とし、その出力にも$${X}$$を加えることで、残差接続を2ホップ分共有する構成を表している。

<derivation>

CEMAの数式(2)(3)は、EMAの数式(1)を複素数に一般化したものである。複素数の指数関数$${e^{i\theta} = \cos \theta + i \sin \theta}$$の性質を利用し、実数の減衰率$${\alpha, \delta}$$と複素数の位相$${\theta}$$を組み合わせることで、振幅と位相の両方を制御可能なEMAを実現している。$${h}$$次元の複素数を用いることで、$${h}$$個の周波数成分を表現でき、時系列の周期性とトレンドを捉えられる。θの値を式(3)のように$${\omega}$$から決定することで、$${h}$$個の基底角を等間隔に配置し、パラメータ数を削減している。

タイムステップ正規化の数式(4)は、バッチ正規化やレイヤー正規化を系列方向に拡張したものと見なせる。バッチ方向の統計量では、異なる系列長に対応できないため、系列方向に累積的に統計量を計算する。グループ正規化のように特徴次元を分割することで、次元数が大きい場合でも安定的に統計量を推定できる。オートリグレッシブ生成では、未来の情報を使えないため、累積統計量を用いる。

正規化アテンションの数式(5)-(9)は、被除数と除数のペアに着目し、大きさの比から類似度を定義する一般的な正規化の考え方に基づいている。通常のスケール付きドット積アテンションでは、$${\frac{QK^T}{\sqrt{d_k}}}$$の形式でスケール項による大きさの調整が必要だが、$${Q}$$と$${K}$$を予め正規化することで除算が不要になる。また、$${Q}$$と$${K}$$を同じ正規化表現$${Z'}$$から線形変換で求めることで、類似度の対称性や値の安定性が向上する。

Pre-Normの2ホップ残差接続の数式(11)は、Pre-Normのスキップ接続を再利用するシンプルなアイデアである。通常のPre-Normでは、各層でスキップ接続が新たに作られ、ネットワークが深くなるほど出力の分散が増大してしまう。それを防ぐために、アテンション層とFFN層で同じスキップ接続を共有することで、各層の入力分布を安定化させる効果が期待できる。この2ホップ構成は、Pre-Normの拡張として自然であり、原理的には3ホップ以上に拡張することも可能である。

得られた主な結果

<main_results>

MEGALODONは、大規模言語モデルの事前学習において、データ効率と計算効率の両面でTransformerアーキテクチャを上回る性能を達成した。Figure 1に示すように、MEGALODON-7BはLLAMA2-7Bと比較して、同じ学習トークン数に対して有意に低いパープレキシティ(NLL)を実現し、学習の終盤ではLLAMA2-13Bに匹敵する性能に到達した。これは、CEMAによる表現力の向上と、タイムステップ正規化層によるチャンク間の文脈損失の低減が効果的に働いたためと考えられる。また、Figure 4が示すように、32Kの長い文脈長においても、MEGALODON-7BはLLAMA2-7Bよりも32%高速に学習できた。これは、チャンクアテンションの並列化による計算効率の改善の結果である。 さらに、Table 1に示すように、MEGALODON-7Bは、MMSLUや言語理解、質問応答などの標準的なベンチマークにおいて、LLAMA2-7Bを上回る性能を示し、同規模の他の公開モデルと比較しても競争力のある結果を達成した。特に、Arc-eやTriviaQAでは、LLAMA2-13Bをも凌駕する高い精度を実現した。 無制限の文脈長を扱える能力については、Figure 5に示すように、検証用のデータセットに対して、文脈長を4Kから2Mまで変化させてパープレキシティを評価した結果、文脈長の増加に伴ってパープレキシティが単調に減少することを確認した。 また、Table 2に示すように、Scrollsデータセットの長文脈質問応答タスクにおいて、MEGALODON-7Bは、他の公開モデルを上回る高いF1スコアを達成し、特にNarrativeQAでは最高性能を記録した。この結果は、MEGALODONが長距離の依存関係を適切に捉えられていることを示唆している。

<details>

Figure 1のNLLは、2兆トークンの学習データに対する言語モデルの対数尤度の負値であり、モデルの予測性能を測る指標である。MEGALODON-7BのNLLは学習の終盤で1.70に到達し、LLAMA2-7Bの1.75を上回り、LLAMA2-13Bの1.67に迫る結果となった。ただし、学習の初期段階(500B以下)では、LLAMA2-7Bの方がわずかに優れていた。この原因としては、位置エンコーディングの初期化の違いが考えられる。 Figure 4は、NVIDIA A100 GPU 256台を用いて4Mトークンのバッチサイズで学習した際の1デバイス当たりの平均単語処理速度(WPS)を比較したものである。4Kの文脈長では、MEGALODON-7Bはトランスフォーマーの演算を最適化したLLAMA2-7Bよりも6%遅かったが、32Kの文脈長ではMEGALODON-7Bが32%高速化された。ただし、チャンク並列化のオーバーヘッドにより、MEGALODON-7B-32KはMEGALODON-7B-4Kの94%の効率にとどまった。 Table 1のベンチマークスコアは、標準的なデータセットに対する0-shot/5-shotでの性能であり、数値が大きいほど高精度であることを示している。全8つのタスクのうち、MMSLUとBoolQを除く6つでMEGALODON-7BがLLAMA2-7Bを上回った。ただし、Mistral-7BやGemma-8Bなど、より大規模なデータで学習されたモデルには及ばなかった。 Figure 5のパープレキシティは、検証用の書籍データに対して、文脈長を変えながら次トークンの予測確率を評価した値である。パープレキシティは、予測確率の逆数を文脈長で正規化した指標であり、小さいほど予測性能が高いことを意味する。MEGALODON-7Bは、4Kから2Mまでの全ての文脈長において、文脈が長いほどパープレキシティが減少する傾向を示した。 Table 2のF1スコアは、Scrollsデータセットの各タスクに対する予測と正解の一致度を表す。Xgen-7B-8KやLLAMA2-7B-4Kに比べ、MEGALODON-7Bは全てのタスクで高いスコアを記録した。32Kの文脈長に拡張したLLAMA2-7B-32Kには及ばないものの、NarrativeQAでは最高性能を達成するなど、MEGALODONの優位性が示された。 </details> <comparison> MEGALODON-7Bは、LLAMA2-7Bをはじめとする同規模のTransformerモデルと比較して、データ効率、計算効率、モデルの性能の全ての観点で優れた結果を示した。 Figure 1とTable 1に示されるように、学習データとハイパーパラメータを揃えて比較した場合、MEGALODON-7Bは一貫してLLAMA2-7Bを上回っており、その差は統計的にも有意であった(例えばArc-eで4.2ポイント、NQで5.5ポイントの差)。 Figure 4が示すように、32Kの長い文脈長において、MEGALODON-7Bの計算速度はLLAMA2-7Bを32%上回っており、その差は明確であった。この結果は、チャンクアテンションの導入による効果と考えられる。 Figure 5とTable 2の結果から、MEGALODONは文脈長に関わらず安定した性能を発揮できることが示された。特に、Figure 5の単調減少傾向は、MEGALODONが長距離依存関係を適切にモデル化できていることを強く示唆している。 一方で、学習データ量が多いMistral-7BやGemma-8Bとの比較では、MEGALODONの優位性は限定的であり、データスケールとモデルスケールのトレードオフについては慎重な評価が必要である。

この記事が気に入ったらサポートをしてみませんか?