LLMファインチューニングのためのNLPと深層学習入門 #11 バッチ正規化(1)
今回はレイヤー正規化・・・と思ったんですが、レイヤー正規化がバッチ正規化から派生したことを考えると、バッチ正規化を先に知っておいた方が良いのかなと思ったので、今回はCVMLエキスパートガイドより、『バッチ正規化(Batch Normalization) とその発展型』を勉強していきます。
ちょっと自分にはカロリー高かったので、前半部分と後半部分に分けて読み進めていきます。
具体的には、バッチ正規化の計算手順までです。
1. バッチ正規化とは
バッチ正規化(Batch Normalization)とは、畳み込みニューラルネットワーク(CNN)の隠れ層において、ミニバッチ内のデータ分布をもとに各チャンネルごとに特徴を正規化したのち、スケール・シフトを行うことで、学習の高速化と安定化を図る層です。
バッチ正規化を各中間層で行うことで、もとのDNNの表現力の高さを保ちつつも、学習の収束の高速化と安定化を達成できるうえに、正規化の役割も果たすことができます。
バッチ正規化は以下に示す3つの発展型があります。
レイヤー正規化
Transformer系モデルでよく使用される系列データ向けの改善インスタンス正規化
GANでの画像生成や、画像スタイル変換向けの改善グループ正規化
小バッチ・多タスク学習向けの改善
これらの発展型の登場により、バッチ正規化が応用されるタスク(もといネットワーク構成)が増えていったとも言えます。
バッチ正規化や上記の発展型の各層は、特徴マップや特徴ベクトルの正規化効果により、ResNetやTransformerなどの「層の多いDNN」を安定して最適解近くまで学習するうえで必須の技術です。
また、バッチ学習である確率的勾配降下法(Stochastic Gradient Descent, SGD)においても、非常に助けになる重要技術です。
バッチ正規化は、[Ioffe and Szegedy, 2015]で提案されてから、CNNやDNNの中間層として標準的に用いられる層となりました。
それ以前のCNNでは、以下のように3種の層からなる1ブロックを繰り返す設計が標準的でした。
「畳み込み層 → ReLU → プーリング層」
それが、バッチ正規化が登場して以降、次に示すように4種の層で1ブロックを構成する設計がCNNでは標準的となりました。
「畳み込み層 → バッチ正規化 → ReLU → プーリング層」
つまり、CNNブロック内に、バッチ正規化も定番の層として新たに加わった、ということです。
2. 大体の処理内容
バッチ正規化層を用いた学習では「畳み込み層 → ReLU層」の出力である活性化後の特徴マップのうち、それぞれのチャンネルについて、正規化された固定分布を学習中に求めます。
ここで、$${k}$$チャンネル目の出力$${x^{(k)}}$$とします。
SGD学習中において、バッチ正規化の処理では、N個のサンプルから構成されるミニバッチ$${X=x_1, x_2, …, x_n, …, x_N}$$のうち、同一チャンネルに含まれる$${m}$$個の特徴ベクトル$${B=\{x_i^{(k)}\}_{i=1}^N}$$に対して、以下の処理を行います。
バッチ内で各特長$${x_i^{(k)}}$$を正規化し、単位ベクトル$${\hat{x}}$$を得ます。
つまり、平均0、スケール1の分布にします。正規化済みの$${\hat{x}}^{(k)}}$$をスケーリング・シフトして、固定分布$${y^{(k)}=\lambda^{(k)}\hat{x}^{(k)}+\beta^{(k)}}$$へと変換します。
重要な点として、これらの$${\lambda^{(k)}}$$と$${\beta^{(k)}}$$は学習可能なパラメータであり、学習過程中に更新されます。
手順1のみだと全チャンネルが同じ分布になり、CNNの表現が落ちてしまいますが、手順2で学習可能なパラメータを導入し、それらを用いてスケーリングとシフトを行うことで、異なる特徴を持つ各チャンネルが独自のスケールとバイアスを持つことが可能になります。
そのため、ネットワークの表現力を維持することが可能になります。
バッチ正規化は、各中間層の出力を特定の「学習済みの固定分布」に変換するため、画像認識向けのCNNやTransformer系のネットワークなどを学習する際に、SGD学習の安定化と高速化が見込めます。
また、バッチ正規化が導入されて以来、深い層をもつCNNを学習する際に、学習率を高めに設定しても、勾配爆発や勾配消失問題を起こさないように学習を安定させることができました。
これにより、早い時間での学習の収束を実現しやすくなりました。
さらに、バッチ正規化はモデルの正則化効果もあるので(後述)、バッチ正規化を導入することによりモデルの汎化性能、つまり未知のデータに対する予測性能が向上します。(ただし、Transformer登場以降の巨大なDNNではドロップアウトが良く使用されています。)
3. バッチ正規化の詳細
3.1 計算手順
先ほども示した通り、バッチ正規化層は学習可能なパラメータを持っているので、(A)学習時と(B)テスト時で少し異なる計算を行います。
3.1.1(A)学習時の計算
線形層(全結合層 or 畳み込み層)の出力応答のc番目の出力$${x^{(c)}}$$に対して、1バッチ内で正規化したあと、スケーリング・シフトを行います。
SGD中に「スケール係数・シフト係数」を同時に学習することで、チャンネル内の出力$${x^{(c)}}$$は「学習済みの固定分布」から出力されるようになっていきます。(全てのパラメータが最適化され、それ以上変化しなくなった時点の分布が、「学習済みの固定分布」を意味します。)
推論時には、その固定分布から出力します。
詳しい説明
以下に、学習時のチャンネル$${c}$$での、バッチ正規化層の出力である$${BN_{\gamma, \beta}(x_i)}$$の計算手順を示します。
まず、チャンネル$${c}$$内のバッチ平均$${\mu_B}$$を計算します。
ここで、mは同一チャンネルに含まれる特徴ベクトルの個数、$${x_i}$$はバッチ正規化が適用される前の、特定のチャンネル$${c}$$における$${i}$$番目の入力データを表します。
$$
\mu_B \leftarrow \frac{1}{m} \sum_{i=1}^mx_i
$$
そして、チャンネル$${c}$$内のバッチ分散$${\sigma_B^2}$$を計算します。
$$
\sigma_B^2 \leftarrow \frac{1}{m} \sum_{i=1}^{m}(x_i - \mu_B)^2
$$
1つのミニバッチ内の$${N}$$個の各特徴マップに対して、それぞれ正規化とスケーリング・シフトを行います。
各マップ内内の各特徴量$${x_i}$$について1,2の順で計算します。
Nとmについて再度整理すると、Nはミニバッチ内の特徴マップの数を表しており、それぞれの特徴マップ(またはチャネル)内には、m個の特徴量(またはピクセル)が存在しているということです。
1.(チャンネルcでの)正規化
$$
\hat{x_i} \leftarrow \frac{1}{\sqrt{\sigma_B^2+\epsilon}}(x_i - \mu_B)
$$
ここで、$${\epcilon}$$はゼロ除算を防ぐための定数です。
2. スケーリングとシフト
$$
y_i \leftarrow \gamma \hat{x_i} + \beta \equiv BN_{\gamma, \beta}(x_i)
$$
ここでは、バッチ正規化の操作$${BN_{\gamma, \beta}(x_i)}$$を$${y_i}$$と定義し、これはスケーリング・シフトの結果$${\gamma \hat{x_i} + \beta}$$と一致する、という意味を持ちます。
スケーリングとシフトに使用されるパラメータ$${\gamma}$$と$${\beta}$$は、ネットワークが訓練データに対して学習を進める中で、誤差逆伝播法と確率的勾配降下法(SGD)によって更新されます。
具体的には、ネットワークの予測結果と正解との間の誤差を最小化するように、これらのパラメータが調整されます。
CNNの学習が終了すると、バッチ正規化層のパラメータの学習も終了します。
このアルゴリズムにより、各cチャンネルのニューロン出力値$${x_i}$$は、どれもバッチ内の正規化と$${\gamma}$$, $${\beta}$$で表現される固定分布へと少しずつ収束していきます。
上記で示した処理はc次元目の特徴$${x_1, x_2, …, x_{N \times W \times H}}$$の全てにおいて行われるため、学習が終了するとパラメータ$${\{\gamma^{(c)}, \beta^{(c)}\}_{c=1}^C}$$が手に入ります。
学習中に、各チャンネル$${c}$$での分布のスケール量$${\gamma^{(c)}}$$とシフト量$${\beta^{(c)}}$$がそれぞれ独立で固定されていきます。
それに伴い、全チャンネルの隠れ層の出力分布がそれぞれ固定分布に安定し、ニューラルネットワークの各層の入力と出力が一定の範囲に収まることにより「効率的な勾配降下=早くて安定した学習」が可能になります。
バッチ正規化の挿入箇所
線形層(全結合層または畳み込み層)では、バッチ正規化を挿入しない場合、つぎのように入力xがアフィン変換(ベクトル空間において線形変換と平行移動を組み合わせること)されます。
$$
h = f(W^Tx+b)
$$
ここで、$${f(\cdot)}$$はReLUなどの活性化関数、$${W}$$が結合重み行列、$${b}$$がバイアスベクトルです。
線形層の出力を活性化関数に通す前にバッチ正規化の変換$${BN_{\sigma, \gamma}(x)}$$を挿入します。
このとき、以下のような変換式になります。
$$
h=f(BN_{\sigma, \gamma}(W^Tx))
$$
ここでは、まず入力$${x}$$に結合重み$${W}$$を適用し($${W^Tx}$$)、次にその結果にバッチ正規化を適用します($${BN_{\sigma, \gamma}(W^Tx)}$$)。最後にその結果を活性化関数$${f(\cdot)}$$に通すという過程を示しています。
このような流れに従うと、バイアス項$${b}$$は不要となります。なぜなら、バッチ正規化の変換の一部である$${\beta}$$がバイアスと同じ機能を持つからです。
以上が学習時のバッチ正規化処理です。
3.1.2(B)テスト時の計算
学習時にはランダムなミニバッチ内を毎回正規化して、その各隠れ層の暫定的な正規化結果によってチャンネル$${c}$$ごとにパラメータ$${\gamma^{(c)}, \beta^{(c)}}$$を学習していました。
しかし、テスト時にはミニバッチを得ることはできないため、データセット全体から決定しておいた平均・分散を常に利用して$${BN_{\sigma, \gamma}(x_i)}$$を実施する必要があります。
そこで、まず学習時に、学習データ全体のチャンネル$${c}$$におけるバッチ平均$${E^{(c)}[x]}$$とその分散$${Var^{(c)}[x]}$$を、それぞれ算出します。
$$
E^{(c)}[x] \leftarrow E^{(c)}_B[\mu_B]
$$
$$
Var^{(c)}[x] \leftarrow \frac{m}{m-1}E_B^{(c)}[\sigma_B^2]
$$
次に、テスト時はこれらを用いて出力$${y^{(c)}=BN_{\gamma^{(c)}, \beta^{(c)}}(x^{(c)})}$$を以下のように計算します。
$$
y^{(c)}=\frac{\gamma ^ {(c)}}{\sqrt{Var^{(c)}[x]+\epsilon}} \cdot x^{(c)} + (\beta^{(c)} - \frac{\gamma E^{(c)}[x]}{\sqrt{Var^{(c)}[x]+\epsilon}})
$$
この式は、学習フェーズで学習されたパラメータ(スケールパラメータ $${\gamma^{(c)}}$$ とシフトパラメータ $${\beta^{(c)}}$$)と、全ての学習データを通じて算出した平均 $${E^{(c)}[x]}$$ と分散 $${Var^{(c)}[x]}$$ を使用して、入力データ $${x^{(c)}}$$ を正規化します。
この正規化のプロセスは次のように解釈できます。
$${\frac{\gamma ^ {(c)}}{\sqrt{Var^{(c)}[x]+\epsilon}} \cdot x^{(c)}}$$:
まず、入力データ $${x^{(c)}}$$ をスケール調整します。
これは入力データを標準偏差(ルート分散)で割り、スケールパラメータ $${\gamma^{(c)}}$$ を乗じます。
ここで、$${\epsilon}$$ はゼロ除算を防ぐための微小な数値です。$${\beta^{(c)} - \frac{\gamma E^{(c)}[x]}{\sqrt{Var^{(c)}[x]+\epsilon}}}$$:
次に、シフト調整を行います。
これはシフトパラメータ $${\beta^{(c)}}$$ から、平均値をスケール調整した値を引くことで行います。
この部分の結果は、新たな入力データが平均値と一致する場合に $${\beta^{(c)}}$$ となるように調整されます。
以上の2つの部分を合わせることで、新たな入力データ $${x^{(c)}}$$ が平均値と分散によって正規化され、それにスケール調整とシフト調整が加えられた出力 $${y^{(c)}}$$ が得られます。
こうすることで、テスト時はデータセット全体の統計を用いてのバッチ正規化を実行できます。
4. おわりに
今回はバッチ正規化について学びました。
ちょっと自分にはカロリー高かったので、前半と後半に分けようと思います。
結構詳しいところまで学びましたが、実際LLMのファインチューニングをする程度でしか使わないのなら、バッチ正規化がどういうものなのか、どんな効果があるのかだけ知っていればいい気もします。
次回はレイヤー正規化について残りの部分を勉強していきます。
それでは。
進捗上げてます
「#AIアイネス」で日々の作業内容を更新しています。
ぜひ覗いてみて下さい。
参考
バッチ正規化(Batch Normalization) とその発展型, CVMLエキスパートガイド, 林 昌希, 2021
更新履歴
タイトルの「レイヤー正規化」という表記は、「バッチ正規化」の間違いです。
元々レイヤー正規化について勉強する予定だったところを、途中でレイヤー正規化の発展元であるバッチ正規化にスケジュールを変更した際に、タイトルを変え忘れていました。