LLMファインチューニングのためのNLPと深層学習入門 #11 バッチ正規化(1)

今回はレイヤー正規化・・・と思ったんですが、レイヤー正規化がバッチ正規化から派生したことを考えると、バッチ正規化を先に知っておいた方が良いのかなと思ったので、今回はCVMLエキスパートガイドより、『バッチ正規化(Batch Normalization) とその発展型』を勉強していきます。

ちょっと自分にはカロリー高かったので、前半部分と後半部分に分けて読み進めていきます。
具体的には、バッチ正規化の計算手順までです。


1. バッチ正規化とは

バッチ正規化(Batch Normalization)とは、畳み込みニューラルネットワーク(CNN)の隠れ層において、ミニバッチ内のデータ分布をもとに各チャンネルごとに特徴を正規化したのち、スケール・シフトを行うことで、学習の高速化と安定化を図る層です。

図1 . バッチ正規化 (batch normalization) [物体認識CNN]
CVMLエキスパートガイドより

バッチ正規化を各中間層で行うことで、もとのDNNの表現力の高さを保ちつつも、学習の収束の高速化と安定化を達成できるうえに、正規化の役割も果たすことができます。

バッチ正規化は以下に示す3つの発展型があります。

  1. レイヤー正規化
    Transformer系モデルでよく使用される系列データ向けの改善

  2. インスタンス正規化
    GANでの画像生成や、画像スタイル変換向けの改善

  3. グループ正規化
    小バッチ・多タスク学習向けの改善

これらの発展型の登場により、バッチ正規化が応用されるタスク(もといネットワーク構成)が増えていったとも言えます。

バッチ正規化や上記の発展型の各層は、特徴マップや特徴ベクトルの正規化効果により、ResNetやTransformerなどの「層の多いDNN」を安定して最適解近くまで学習するうえで必須の技術です。

また、バッチ学習である確率的勾配降下法(Stochastic Gradient Descent, SGD)においても、非常に助けになる重要技術です。

『SGD』
機械学習における最適化アルゴリズムの一種で、特に深層学習(Deep Learning)などのニューラルネットワークの学習によく使用されます。

確率的勾配降下法は、全データセットを使って勾配を計算する通常の勾配降下法とは異なり、ランダムに選ばれたサブセット(ミニバッチ)のデータを使って勾配を近似的に計算します。これにより、一度の学習ステップで使うデータ量を減らすことができ、大規模なデータセットに対しても高速に学習を進めることが可能になります。

また、確率的勾配降下法は学習過程にランダム性を導入するため、局所最適解に捕まるリスクを減らす効果もあります。これにより、より良い汎化性能を持つモデルを学習することができます。

文中にある「バッチ学習であるSGD」とは、この確率的勾配降下法を用いた学習方法を指しています。バッチ正規化などの技術は、SGDのような最適化手法がモデルのパラメータをより良い方向に更新できるように、各層の入力分布を整える役割を果たします。

バッチ正規化は、[Ioffe and Szegedy, 2015]で提案されてから、CNNやDNNの中間層として標準的に用いられる層となりました。
それ以前のCNNでは、以下のように3種の層からなる1ブロックを繰り返す設計が標準的でした。

「畳み込み層 → ReLU → プーリング層」

それが、バッチ正規化が登場して以降、次に示すように4種の層で1ブロックを構成する設計がCNNでは標準的となりました。

「畳み込み層 → バッチ正規化 → ReLU → プーリング層」

つまり、CNNブロック内に、バッチ正規化も定番の層として新たに加わった、ということです。

2. 大体の処理内容

バッチ正規化層を用いた学習では「畳み込み層 → ReLU層」の出力である活性化後の特徴マップのうち、それぞれのチャンネルについて、正規化された固定分布を学習中に求めます。
ここで、$${k}$$チャンネル目の出力$${x^{(k)}}$$とします。

図2. バッチ正規化の処理. バッチ内全体で,チャンネル単位の正規化を行う.
CVMLエキスパートガイドより

SGD学習中において、バッチ正規化の処理では、N個のサンプルから構成されるミニバッチ$${X=x_1, x_2, …, x_n, …, x_N}$$のうち、同一チャンネルに含まれる$${m}$$個の特徴ベクトル$${B=\{x_i^{(k)}\}_{i=1}^N}$$に対して、以下の処理を行います。

  1. バッチ内で各特長$${x_i^{(k)}}$$を正規化し、単位ベクトル$${\hat{x}}$$を得ます。
    つまり、平均0、スケール1の分布にします。

  2. 正規化済みの$${\hat{x}}^{(k)}}$$をスケーリング・シフトして、固定分布$${y^{(k)}=\lambda^{(k)}\hat{x}^{(k)}+\beta^{(k)}}$$へと変換します。

重要な点として、これらの$${\lambda^{(k)}}$$と$${\beta^{(k)}}$$は学習可能なパラメータであり、学習過程中に更新されます。

手順1のみだと全チャンネルが同じ分布になり、CNNの表現が落ちてしまいますが、手順2で学習可能なパラメータを導入し、それらを用いてスケーリングとシフトを行うことで、異なる特徴を持つ各チャンネルが独自のスケールとバイアスを持つことが可能になります。
そのため、ネットワークの表現力を維持することが可能になります。

『スケーリング・シフト』
正規化されたデータに対してスケーリング(スケール変換)とシフト(平行移動)を行うという操作のことを指します。

今回は、スケーリングとは正規化されたデータ$${\hat{x}}^{(k)}$$にスケールパラメータ$${\lambda}^{(k)}$$を掛けることで、データの尺度(スケール)を変更する操作を指します。一方、シフトとはその結果にシフトパラメータ$${\beta}^{(k)}$$を足すことで、データをある値だけ平行移動する操作を指します。

バッチ正規化は、各中間層の出力を特定の「学習済みの固定分布」に変換するため、画像認識向けのCNNやTransformer系のネットワークなどを学習する際に、SGD学習の安定化と高速化が見込めます。

また、バッチ正規化が導入されて以来、深い層をもつCNNを学習する際に、学習率を高めに設定しても、勾配爆発や勾配消失問題を起こさないように学習を安定させることができました。
これにより、早い時間での学習の収束を実現しやすくなりました。

さらに、バッチ正規化はモデルの正則化効果もあるので(後述)、バッチ正規化を導入することによりモデルの汎化性能、つまり未知のデータに対する予測性能が向上します。(ただし、Transformer登場以降の巨大なDNNではドロップアウトが良く使用されています。)

3. バッチ正規化の詳細

3.1 計算手順

先ほども示した通り、バッチ正規化層は学習可能なパラメータを持っているので、(A)学習時と(B)テスト時で少し異なる計算を行います。

3.1.1(A)学習時の計算

  • 線形層(全結合層 or 畳み込み層)の出力応答のc番目の出力$${x^{(c)}}$$に対して、1バッチ内で正規化したあと、スケーリング・シフトを行います。

  • SGD中に「スケール係数・シフト係数」を同時に学習することで、チャンネル内の出力$${x^{(c)}}$$は「学習済みの固定分布」から出力されるようになっていきます。(全てのパラメータが最適化され、それ以上変化しなくなった時点の分布が、「学習済みの固定分布」を意味します。)

  • 推論時には、その固定分布から出力します。

詳しい説明

以下に、学習時のチャンネル$${c}$$での、バッチ正規化層の出力である$${BN_{\gamma, \beta}(x_i)}$$の計算手順を示します。

  • まず、チャンネル$${c}$$内のバッチ平均$${\mu_B}$$を計算します。
    ここで、mは同一チャンネルに含まれる特徴ベクトルの個数、$${x_i}$$はバッチ正規化が適用される前の、特定のチャンネル$${c}$$における$${i}$$番目の入力データを表します。

$$
\mu_B \leftarrow \frac{1}{m} \sum_{i=1}^mx_i
$$

  • そして、チャンネル$${c}$$内のバッチ分散$${\sigma_B^2}$$を計算します。

$$
\sigma_B^2 \leftarrow \frac{1}{m} \sum_{i=1}^{m}(x_i - \mu_B)^2
$$

  • 1つのミニバッチ内の$${N}$$個の各特徴マップに対して、それぞれ正規化とスケーリング・シフトを行います。
    各マップ内内の各特徴量$${x_i}$$について1,2の順で計算します。

    Nとmについて再度整理すると、Nはミニバッチ内の特徴マップの数を表しており、それぞれの特徴マップ(またはチャネル)内には、m個の特徴量(またはピクセル)が存在しているということです。

1.(チャンネルcでの)正規化

$$
\hat{x_i} \leftarrow \frac{1}{\sqrt{\sigma_B^2+\epsilon}}(x_i - \mu_B)
$$

ここで、$${\epcilon}$$はゼロ除算を防ぐための定数です。

2. スケーリングとシフト

$$
y_i \leftarrow \gamma \hat{x_i} + \beta \equiv BN_{\gamma, \beta}(x_i)
$$

ここでは、バッチ正規化の操作$${BN_{\gamma, \beta}(x_i)}$$を$${y_i}$$と定義し、これはスケーリング・シフトの結果$${\gamma \hat{x_i} + \beta}$$と一致する、という意味を持ちます。

  • スケーリングとシフトに使用されるパラメータ$${\gamma}$$と$${\beta}$$は、ネットワークが訓練データに対して学習を進める中で、誤差逆伝播法と確率的勾配降下法(SGD)によって更新されます。
    具体的には、ネットワークの予測結果と正解との間の誤差を最小化するように、これらのパラメータが調整されます。

  • CNNの学習が終了すると、バッチ正規化層のパラメータの学習も終了します。

このアルゴリズムにより、各cチャンネルのニューロン出力値$${x_i}$$は、どれもバッチ内の正規化と$${\gamma}$$, $${\beta}$$で表現される固定分布へと少しずつ収束していきます。

上記で示した処理はc次元目の特徴$${x_1, x_2, …, x_{N \times W \times H}}$$の全てにおいて行われるため、学習が終了するとパラメータ$${\{\gamma^{(c)}, \beta^{(c)}\}_{c=1}^C}$$が手に入ります。

学習中に、各チャンネル$${c}$$での分布のスケール量$${\gamma^{(c)}}$$とシフト量$${\beta^{(c)}}$$がそれぞれ独立で固定されていきます。
それに伴い、全チャンネルの隠れ層の出力分布がそれぞれ固定分布に安定し、ニューラルネットワークの各層の入力と出力が一定の範囲に収まることにより「効率的な勾配降下=早くて安定した学習」が可能になります。 

バッチ正規化の挿入箇所

線形層(全結合層または畳み込み層)では、バッチ正規化を挿入しない場合、つぎのように入力xがアフィン変換(ベクトル空間において線形変換と平行移動を組み合わせること)されます。

$$
h = f(W^Tx+b)
$$

ここで、$${f(\cdot)}$$はReLUなどの活性化関数、$${W}$$が結合重み行列、$${b}$$がバイアスベクトルです。

『結合重み行列』
ニューラルネットワークの特定の層から次の層への接続の重みを表現する行列。これらの重みは学習中に更新され、最終的な出力の計算に対して各入力特徴がどれくらい影響を及ぼすかを決定します。

『バイアスベクトル』
各ニューロンがどれだけ容易に活性化(つまり出力が0以上になること)するかを制御します。バイアスは各ニューロンに対して1つずつ存在し、その値は学習中に更新されます。ニューロンの総入力は、重み付けされた入力の和にバイアスが加算されたものとなります。バイアスの存在により、ニューロンの活性化閾値を0から別の値にシフトすることが可能になります。

線形層の出力を活性化関数に通す前にバッチ正規化の変換$${BN_{\sigma, \gamma}(x)}$$を挿入します。
このとき、以下のような変換式になります。

$$
h=f(BN_{\sigma, \gamma}(W^Tx))
$$

ここでは、まず入力$${x}$$に結合重み$${W}$$を適用し($${W^Tx}$$)、次にその結果にバッチ正規化を適用します($${BN_{\sigma, \gamma}(W^Tx)}$$)。最後にその結果を活性化関数$${f(\cdot)}$$に通すという過程を示しています。

このような流れに従うと、バイアス項$${b}$$は不要となります。なぜなら、バッチ正規化の変換の一部である$${\beta}$$がバイアスと同じ機能を持つからです。
以上が学習時のバッチ正規化処理です。

3.1.2(B)テスト時の計算

学習時にはランダムなミニバッチ内を毎回正規化して、その各隠れ層の暫定的な正規化結果によってチャンネル$${c}$$ごとにパラメータ$${\gamma^{(c)}, \beta^{(c)}}$$を学習していました。

しかし、テスト時にはミニバッチを得ることはできないため、データセット全体から決定しておいた平均・分散を常に利用して$${BN_{\sigma, \gamma}(x_i)}$$を実施する必要があります。

そこで、まず学習時に、学習データ全体のチャンネル$${c}$$におけるバッチ平均$${E^{(c)}[x]}$$とその分散$${Var^{(c)}[x]}$$を、それぞれ算出します。

$$
E^{(c)}[x] \leftarrow E^{(c)}_B[\mu_B]
$$

$$
Var^{(c)}[x] \leftarrow \frac{m}{m-1}E_B^{(c)}[\sigma_B^2]
$$

次に、テスト時はこれらを用いて出力$${y^{(c)}=BN_{\gamma^{(c)}, \beta^{(c)}}(x^{(c)})}$$を以下のように計算します。

$$
y^{(c)}=\frac{\gamma ^ {(c)}}{\sqrt{Var^{(c)}[x]+\epsilon}} \cdot x^{(c)} + (\beta^{(c)} - \frac{\gamma E^{(c)}[x]}{\sqrt{Var^{(c)}[x]+\epsilon}})
$$

この式は、学習フェーズで学習されたパラメータ(スケールパラメータ $${\gamma^{(c)}}$$ とシフトパラメータ $${\beta^{(c)}}$$)と、全ての学習データを通じて算出した平均 $${E^{(c)}[x]}$$ と分散 $${Var^{(c)}[x]}$$ を使用して、入力データ $${x^{(c)}}$$ を正規化します。

この正規化のプロセスは次のように解釈できます。

  1. $${\frac{\gamma ^ {(c)}}{\sqrt{Var^{(c)}[x]+\epsilon}} \cdot x^{(c)}}$$:
    まず、入力データ $${x^{(c)}}$$ をスケール調整します。
    これは入力データを標準偏差(ルート分散)で割り、スケールパラメータ $${\gamma^{(c)}}$$ を乗じます。
    ここで、$${\epsilon}$$ はゼロ除算を防ぐための微小な数値です。

  2. $${\beta^{(c)} - \frac{\gamma E^{(c)}[x]}{\sqrt{Var^{(c)}[x]+\epsilon}}}$$:
    次に、シフト調整を行います。
    これはシフトパラメータ $${\beta^{(c)}}$$ から、平均値をスケール調整した値を引くことで行います。
    この部分の結果は、新たな入力データが平均値と一致する場合に $${\beta^{(c)}}$$ となるように調整されます。

以上の2つの部分を合わせることで、新たな入力データ $${x^{(c)}}$$ が平均値と分散によって正規化され、それにスケール調整とシフト調整が加えられた出力 $${y^{(c)}}$$ が得られます。

こうすることで、テスト時はデータセット全体の統計を用いてのバッチ正規化を実行できます。

4. おわりに

今回はバッチ正規化について学びました。
ちょっと自分にはカロリー高かったので、前半と後半に分けようと思います。

結構詳しいところまで学びましたが、実際LLMのファインチューニングをする程度でしか使わないのなら、バッチ正規化がどういうものなのか、どんな効果があるのかだけ知っていればいい気もします。

次回はレイヤー正規化について残りの部分を勉強していきます。
それでは。

進捗上げてます

「#AIアイネス」で日々の作業内容を更新しています。
ぜひ覗いてみて下さい。

参考

更新履歴

  • タイトルの「レイヤー正規化」という表記は、「バッチ正規化」の間違いです。
    元々レイヤー正規化について勉強する予定だったところを、途中でレイヤー正規化の発展元であるバッチ正規化にスケジュールを変更した際に、タイトルを変え忘れていました。


いいなと思ったら応援しよう!