バッチ正規化とは
この記事では、バッチ正規化(Batch Normalization)の仕組みや利点の解説をします。
バッチ正規化は画像系のモデルでよく使われます。入力データの正規化や標準化とは異なり、隠れ層の間で使われます。特に畳み込み層と活性化関数の間で使われることが多いです。
まず、バッチ正規化には次の効果があります。
訓練が速く進む(損失値がはやく収束する)
重みの初期化に神経質にならなくとも良い
学習率を大きめに指定できる
バッチ正規化では以下の問題点もあります。
小さいバッチでは効果が出ない
再帰型ニューラルネットワークでは使えない
本記事では、上記の利点や問題点の理由を解説していきます。
バッチ正規化はバッチによる訓練をするところから始まるので、まずはその話から始めます。
バッチによる訓練
画像を使ったモデルの訓練を想定して話を進めます。
訓練用の画像がたくさんあるとします。メモリに収まりきらないので画像データを一定数のグループに分けて、それぞれのグループに対して処理を行うことにしたとします。
このように画像をまとめたものを一つの単位とし、バッチ(Batch)あるいはミニバッチ(Mini Batch)と呼びます。今記事では単にバッチと呼びます。また、分割する大きさをバッチサイズと呼びます。バッチサイズは色々と試して学習中の損失率の変化の具合などを見て決めます。
例えば、画像が16万枚あり、それを16枚ごとに分割したとするとバッチサイズは16で、データセットは1万個のバッチに分割されます。もっと大きなデータセットでもバッチ単位の処理になるので問題ありません。
またバッチ単位に小分けにすることによってニューラルネットワークのパラメータ(重みやバイアスなど)をより頻繁に更新することになるので学習が速く進みます。仮にバッチを使わずにデータセット全体を一度に処理できたとしてもパラメータの値を大幅に更新することはできないので、バッチ毎に処理しながら少しづつパラメータの更新をした方が結局は速く訓練が終わります。
つまり、ディープラーニングではバッチ処理を使用することで、大量のデータを効率的に処理することができます。
なお、バッチサイズとしては16、32、64、128、256…などがよく使われます。
画像モデルを訓練するときはバッチ単位での入力を行います。よって、画像モデルはバッチ形式での入力データを想定した作りになっています。そのため、テストやプロダクションで一つの画像だけを使って推論を行う場合は、1画像のみのバッチを作って入力します。
入力データのシャッフル
なお、バッチ毎にモデルの訓練を始める前に、訓練用の画像の順番をシャッフル(Shuffle)します。シャッフルとは順番をランダムに並び替えることです。これは画像の並び順に不自然な規則性がないようにするためです。
例えば、データセット内で犬の画像が160枚連続で順番に並んでいたすると、バッチサイズが16ならば10個のバッチで連続して犬の画像ばかりになります。このようにデータに偏りがあると学習がスムーズに行えません。犬ばかりたくさん現れたらモデルが犬の確率が高くなるようにパラメータを変更するでしょう。その後に猫ばかりが続いたら今度は猫の確率が高くなるようにパラメータを変更して犬のことは忘れてしまうかもしれません。こういう問題は実際によくあることです。よって訓練データは必ずシャッフルします。
なお、シャッフルしても偶然によって偏ったバッチが生じる可能性はあります。なので、エポックが終了(すべてのバッチを処理)したら、再び訓練データをシャッフルします。こうすることでエポックをたくさん繰り返してもバッチの中身は毎回異なるものになります。つまり、偶然の偏りがあるバッチが生じてもそれが次回のエポックへと継続することはありません。
画像の標準化
さて、入力画像をモデルへと渡す前に正規化や標準化をします。画像データでは標準化を行うことが多いです。訓練用のデータセットのすべての画像の平均と標準偏差を使います。また、標準化は画像の各チャンネルごとに行われます。
標準化によって訓練用の画像データの各チャンネルは平均が0で標準偏差が1の分布に従うように変換されます。
なお、テストデータなどで推論を行う際も訓練データセット全体からの平均と標準偏差を使って標準化を行います。これはテストやプロダクションのデータも訓練画像と同様の分布に従っていると想定しているからです。逆に言うと訓練用のデータセットは最終的な使用環境での画像データの分布を代表するように集めたものである必要があります。
共変量シフト
ここまでの話をまとめると、入力データはシャッフルされ、標準化され、バッチごとにモデルへと渡されます。さらに、モデルの隠れ層(畳み込み層など)が画像データを処理してから活性化関数を適用し、次の層へと渡していきます。
ところが、ここで問題が生じる可能性があります。隠れ層によってデータの分布が変わります。それが次の隠れ層を通過すると、さらに分布が変わります。つまり層を通過するたびにデータの分布が変わります。つまり転がる雪のように分布の変化が積み重なっていきます。
ニューラルネットワークのパラメータ(重みやバイアス)が更新され、次回のバッチが中間層を通り抜けると、データの分布は以前とは異なる変化をすることになります。つまり、ニューラルネットワークの中間層から見るとバッチ毎にデータの分布が大幅に変更されることになり学習を困難にします。
わかりやすく言い切ってしまうと、毎回分布が変化し続けるデータを渡されて「特徴量を抽出しろ」と言われているようなものです。大きくランダムに変化するデータの分布から役に立つ特徴を引き出すのは困難です。
つまり、せっかく入力データを標準化したとしても中間層から見るとデータの分布が毎バッチごとにスケールが異なったり平均が上下のシフトしていることになります。
これを内部共変量シフト(Internal Covariate Shift)あるいは単に共変量シフト(Covariate Shift)と呼びます。「内部」とつけるのは、分布のシフトがニューラルネットワークの内部で起きていることを強調しています。
「内部」ではない共変量シフトは外部の要因で起きます。例えば、若い年齢の男性のデータで訓練したモデルを年配の女性のデータに適応した場合は、外部要因による共変量シフトが生じます。これは訓練用のデータセットが最終的な使用環境での画像データの分布を代表するように集めたものであれば問題ありません。
よって、この記事では解説する共変量シフトは「内部」である前提なので、単に「共変量シフト」と呼びます。
バッチ正規化はこの共変量シフトを解決する手法として提案されました。
バッチ正規化の仕組み
バッチ正規化はニューラルネットワークの隠れ層の出力を一定の分布へと調節する手法です。通常は活性化関数が適用される前に行われます。よって活性化関数からの出力がバッチ毎に大きくシフトすることもありません。つまり、次の隠れ層への入力もバッチ毎に大きくシフトしないので全体として共変量シフトの問題が軽減されます。
ただし、曲者なのが「バッチ正規化」という名前です。実は、バッチ正規化では、いわゆる正規化ではなく標準化を使います。しかし、一旦意味を理解すれば迷うことはないでしょう。
以下は標準化の一般的な計算式になります。
$$
\begin{align*}
\text{平均値}\ \mu &= \frac{1}{N} \sum\limits_{i=1}^{N} x_i \\
\\
\text{分散}\ \sigma^2 &= \frac{1}{N} \sum\limits_{i=1}^{N}(x_i - \mu)^2 \\
\\
\text{標準偏差}\ \sigma &= \sqrt{\sigma^2} \\
\\
\text{標準化}(x_i) &= \frac{x_i - \mu}{\sigma}
\end{align*}
$$
また、$${\sigma}$$が0になった場合に0で割り算するとエラーになってしまうので、実装では小さな数$${\epsilon}$$を追加しておきます。
$$
{\sigma = \sqrt{\sigma^2 + \epsilon} \\}
$$
ここで、$${\epsilon}$$としては、0.00001(1e-5)などが指定されます。データに偏りがなくよくシャッフルされていれば$${\epsilon}$$がなくとも、通常は問題ないです。しかし、バッチ正規化は隠れ層で行われるので場合によってはデータの値がすごく近くなるとコンピュータが数値を収納する時の精度にも限界があるので問題が絶対に起きないとは言い切れません。よって、長い訓練をエラーによって台無しにされないようにこのような設定があります。
バッチ正規化では、バッチ内のデータを使って標準化の計算をします。よって、上記の$${N}$$を画像データで計算するとバッチ内のチャンネル毎における総ピクセル数(総画素数)になります。以下のように求められます。
$$
N = \text{1画像の1チャンネル内の画素数} \times \text{バッチサイズ}
$$
例えば、画像のサイズが28x28だとします。バッチサイズが16ならば各チャンネルの総画素数は28x28x16=12544となります。よって$${N = 12544}$$です。
なお、チャンネルの数は各層によって変わります。畳み込み層ではこれが特徴量マップを形成します。この記事では一貫してチャンネルと呼ぶことにしています。チャンネルの数は入力時は3チャンネルだったものが中間層では64チャンネルになったりもっと増えたり減ったりするわけです。各チャンネルがある特徴を表現しており、バッチ正規化はチャンネル毎に行われます。
このようにしてバッチ内の各チャンネルのデータ分布は平均が0で標準偏差が1になります。さらに、このバッチ内で標準化されたデータを$${\gamma}$$(ガンマ)でスケールし$${\beta}$$(ベータ)でシフトします。
$$
\text{バッチ正規化}(x_i) = \gamma \, \text{標準化}(x_i) + \beta
$$
$${\gamma}$$と$${\beta}$$はチャンネル毎のパラメータであり、それぞれ1と0に初期化されます。つまり、バッチ正規化は各チャンネルの標準化としてスタートしますが、$${\gamma}$$と$${\beta}$$の値は訓練を通して更新されるのでデータの分布をニューラルネットワークは調節することが可能です。つまり隠れ層の各チャンネルの出力はニューラルネットワークが選んだ平均と標準偏差を持つ分布に従うことになります。
この手法では共変量シフトが生じません。なぜならバッチ正規化ではまず標準化が行われるため以前の層からの分布の影響を受けないからです。このように、バッチ正規化は隠れ層の出力をチャンネル毎に調節し、活性化関数への入力データの分布を決めています。
例えば、活性化関数としてReLUを使っているとします。ニューラルネットワークがあるチャンネルのデータに正の$${\beta}$$を与えれば、ReLUを通過しやすくなります。逆に、負の$${\beta}$$を与えれば、ReLUを通過しにくくなります。よって、どの特徴を通しやすくするのかを決めることができます。
なお、隠れ層にもバイアスがあるのですが、標準化によってバイアスの意味がなくなります。また、機能としてはバッチ正規化の$${\beta}$$と重複します。よって、通常は隠れ層のバイアスは使わないように設定して余計なパラメータを減らします。
バッチ正規化の基本的な仕組みは以上となりますが、モデルの評価を行う際にもう一点注目すべき機能があります。
バッチ正規化を評価モードで使う
テストやプロダクションで画像処理を行う場合には、訓練中に計算した平均と標準偏差からの指数加重平均を使います。ここではこの指数加重平均を解説します。
モデルを評価モードで使う時は、入力バッチからの平均や標準偏差は使いません。代わりに、訓練中に計算した平均と標準偏差の値を使います。ただし、平均や標準偏差はバッチ毎に異なるので、以下のような指数加重平均の計算を行い、訓練データ全体における隠れ層の出力の平均と標準偏差の近似とします。
$$
\hat{x} \leftarrow (1 - \text{m}) \times \hat{x} + \text{m} \times x
$$
$${\hat{x}}$$は現段階で推測された平均あるいは分散の値です。ここでの$${x}$$は訓練中のバッチ内の平均あるいは分散の値です。$${m}$$はモーメンタム(Momentum)で0から1の間の値が使われます。よく使われる値は0.1などですが、設定により変更可能です。よって現在の推測値と新しいバッチからの値とで加重平均をとっていることになります。
$$
\begin{align*}
\text{平均の推測値} &\leftarrow (1 - m) \times \text{平均の推測値} + \text{m} \times \text{バッチの平均} \\
\\
\text{分散の推測値} &\leftarrow (1 - m) \times \text{分散の推測値} + \text{m} \times \text{バッチの分散}
\end{align*}
$$
その際に0.1をモーメンタムとするならば9割は現状の推測値から1割は最新のバッチからの平均あるいは分散の値を取り入れることになります。バッチ毎にこの更新が生じるので常に最新のバッチからの値を取り込んでいきます。新しいバッチからの平均や分散は、最新の隠れ層の重みによるのでより推測値も徐々に最新の重みに対応した値へと移動していきます。その意味で上記のような計算を移動平均とも呼ばれます。ある程度長く訓練すれば平均と分散の推測値も安定した値になります。
よって、上記の指数加重平均の式によって推測される平均と分散の値がバッチ毎に更新されることにより全体の平均と分散に近づいていことになります。
以上により、訓練中に平均と分散の予測値を更新し続けることで、訓練データの分布に適応した平均と分散を準備することができます。よって、推論時での正規化でも同じ値を使うことができるため、訓練時と推論時の性能を統一することができます。もちろん、ここでもテストやプロダクションでの画像データの分布は訓練画像の分布とある程度は似たものであるという仮定がなされています。
次にバッチ正規化の利点と問題点を解説します。
バッチ正規化の利点と問題点
最大の利点としては、ニューラルネットワークの中間層を通過するバッチの分布が安定するので訓練が速く進む(損失値がはやく収束する)ようになります。
また、「勾配消失」や「勾配爆発」といった問題が起きにくくなります。活性化関数への入力分布が極端に大きくなったり小さくなったりしないのでReLUで全ての値がゼロになったり、Sigmoidで勾配がほぼゼロになるようなことが起こりにくいです。また、入力値によって重みの更新が大きくなりすぎることもありません。
さらに、バッチ正規化では各層への入力値の分散が抑えられるためパラメータの更新がスムーズになります。入力画像の標準化の話でも同様の解説をしました。
また、標準化されることによって外れ値(Outlier)をあまり気にする必要がなくなります。なぜなら多くのデータは平均値の近くに集まるようになるからです。よって外れ値による重みの更新への影響が少ないので学習率を大きめに設定できるようになります。
さらに活性化関数への入力分布がニューラルネットワークによって決められるので隠れ層の重みの初期化(初期分布)を活性化関数に合わせて調節する必要がなくなります。つまり、重みの初期化に神経質にならなくとも良いわけです。
バッチ正規化の問題点としては、バッチサイズがある程度の大きさである必要があることです。標準化をする際に平均や標準偏差をバッチ内のデータだけで計算するのですが、データが多い方がより全体のデータの分布を反映するからです。
もし仮にデータセットの画像をすべてニューラルネットワークに通して、隠れ層での平均と標準偏差を計算できれば正確な標準化になるわけですが、バッチを使う時点でそれは不可能です。また、データセット全体を使えたとしても、それはおすすめではありません。バッチを使う利点が無くなってしまうからです。よって、バッチサイズをある程度大きくすることで、サンプルサイズが大きくなるとバッチによる平均や標準偏差が良い近似として使えます。
なお、バッチ正規化は再帰型ニューラルネットワークではうまく機能しません。なぜなら、再帰の構造ではシーケンス(時系列など順番に並んだデータ)の長さが毎回異なるためにバッチ毎に平均や分散を計算するのが困難だからです。そのため、再帰型バッチ正規化というLSTMのために工夫された手法もあります。しかし、単純なバッチ正規化の適用とは異なります。
むしろ、再帰型ニューラルネットワークではレイヤー正規化が使われます。ちなみに、レイヤー正規化はトランスフォーマーでも使われています。
関連記事
この記事が気に入ったらチップで応援してみませんか?