AdamWとは
オプティマイザはニューラルネットワークの訓練を効率的に行うために重要な役割を果たしますが、その中でもAdam(アダム)は有名です。Adamの論文は2015年に発表されました。そして、多くの人々がほぼデフォルト的にAdamを使うようになりました。
しかし、その約3年後にAdamWの論文が発表され、Adamの実装における重要な欠陥を指摘し、その解決策を提示しました。
今回の記事では、Adam と AdamW について解説します。まず、Adam を解説する前にSGDについて復習します。その次に、SGDと比較しながら Adam の仕組みを解説します。さらに、Adam にどのような問題があるのかを説明します。最後に、AdamW が Adam の欠陥をどのように克服したのかを紹介します。
では、さっそく始めましょう。
SGDの仕組み
更新ルール
SGD(Stochastic Gradient Descent、確率的勾配降下法)は、最も基本的な最適化アルゴリズムです。訓練中の各更新ステップでランダムに選択された訓練データ(ミニバッチ)を使って勾配を計算し、その勾配方向にパラメータを更新します。
SGDの更新ルールは以下の通りです:
$$
\theta_{t+1} = \theta_{t} - \eta \, g_t
$$
ここで、$${\theta}$$(セータ)はモデルのパラメータ、$${\eta}$$(エータ)は学習率、$${g_t}$$は損失関数(コスト関数)のパラメータ$${\theta}$$による変化率で以下に述べるように損失関数の勾配(Gradient)から計算できます。
また、$${t}$$は訓練中のステップを意味します。つまり、ステップ$${t+1}$$におけるパラメータの値$${\theta_{t+1}}$$は、ステップ$${t}$$におけるパラメータの値$${\theta_t}$$と勾配$${g_t}$$によって決まります。
なお、SGDにおける学習率$${\eta}$$はすべてのパラメータに対して同じ値が使われます。基本的には固定値です。モーメンタムという概念もありますが、ここでは省略します。興味のある方は、こちらを参照してください。
上記では、一つのパラメータを扱っていますが、一般にはパラメータは複数個あるので、各パラメータを$${i}$$で表現すれば、上記のパラメータ更新式は以下のようになります。
$$
\theta_{t+1,i} = \theta_{t,i} - \eta \, g_{t, i}
$$
なお、マイナスがついているのは、勾配とは逆の方向、つまり損失関数が減少する方向にパラメータを更新するためです。
勾配との関係
パラメータがN個あるとすると、以下のようになります。
$$
\begin{aligned}
\theta_{t+1,1} &= \theta_{t,1} - \eta \, g_{t, 1} \\
\theta_{t+1,2} &= \theta_{t,2} - \eta \, g_{t, 2} \\
\vdots \\
\theta_{t+1,i} &= \theta_{t,i} - \eta \, g_{t, i} \\
\vdots \\
\theta_{t+1,N} &= \theta_{t,N} - \eta \, g_{t, N} \\
\end{aligned}
$$
なお、$${g_{t, 1}, g_{t, 2}, \dots, g_{t, i} \dots, g_{t, N}}$$は、各パラメータに対する損失関数の変化率なので、各パラメータに対する損失関数の偏微分を計算して求めることができます。
$$
g_{t, i} = \dfrac{\partial J(\boldsymbol{\theta}_t)}{\partial \theta_{t,i}}
$$
ここで、すべてのパラメータをまとめて$${\boldsymbol{\theta}_t = [\theta_{t,1}, \theta_{t,2}, \dots, \theta_{t,N]}}$$としています。また、$${J(\boldsymbol{\theta}_t)}$$は損失関数です。パラメータを変更すると損失値が変わるので、損失関数をパラメータの関数として扱っています。
以上から、すべてのパラメータの更新式をまとめて次のように勾配を使って書き表せます。
$$
\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \nabla J(\boldsymbol{\theta}_t)
$$
勾配は偏微分のリストです。
$$
\nabla J(\boldsymbol{\theta}_t) = \left[ \dfrac{\partial J(\boldsymbol{\theta}_t)}{\partial \theta_1}, \dfrac{\partial J(\boldsymbol{\theta}_t)}{\partial \theta_2}, \dots, \dfrac{\partial J(\boldsymbol{\theta}_t)}{\partial \theta_N} \right]
$$
なお、微分、偏微分、勾配に関してはこちらでも解説しています。
学習率の問題
さて、SGDの主な問題点は、すべてのパラメータに対して同じ学習率を使用するため、一部のパラメータが他のパラメータよりも速く学習される可能性があることです。
具体的な例を考えてみましょう。
深層学習モデルのパラメータは、通常、異なる特徴を学習します。例えば、畳み込みニューラルネットワーク(CNN)の初期の層は、一般的にエッジや色などの低レベルの特徴を学習します。一方、ネットワークの後半の層は、より高レベルの特徴(例えば、物体の形状やテクスチャ)を学習します。
これらの異なる層は、異なる速度で学習することが望ましいかもしれません。初期の層は、より単純なパターンや特徴を抽出します。これに対し、後半の層では、それ以前の層の特徴を利用してより複雑な特徴を学習します。よって、初期の層は、後半の層よりも収束が速くなりがちだと考えられます。逆に言うと、後半の層は学習が収束するのに、より多くの時間を必要とすることになります。
しかし、SGDではすべてのパラメータに対して同じ学習率が適用されます。その結果、初期の層が過学習する一方で、後半の層が未学習のままであるという状況が生じる可能性があります。これは、特に大規模なデータセットや複雑なモデルで問題となる可能性が高いと言えるでしょう。
この問題を解決するために、各パラメータに対して個別に学習率を調整する手法がいくつか提案されました。つまり、各パラメータが最適な速度で学習することが可能になります。これらのオプティマイザについては、こちらでも解説しています。
そのなかでもAdamは、一般的なタスクに対して良好に機能するものとして人気となりました。
Adamの仕組み
学習率を個別調整
Adam(Adaptive Moment Estimation)は、Diederik P. KingmaとJimmy Baによって提案された最適化アルゴリズムです。Adamのアルゴリズムは、各パラメータに対して適応的な(個別に調節した)学習率を使用することで、学習の効率性を向上させます。
具体的には、Adamは過去の勾配の情報を追跡し、それを用いて各パラメータの学習率を調整します。過去の勾配の情報は、過去の勾配の平均(一次モーメント)と過去の勾配の分散(二次モーメント)の形で保持されます。
一次モーメントは、過去の勾配の方向を示すため、パラメータの更新方向を決定します。一方、二次モーメントは、過去の勾配の変動を示すため、パラメータの更新量(つまり、学習率)を決定します。たとえば、過去の勾配が大きかったパラメータは、学習率を小さくすることで更新のスピードを抑制します。逆に、過去の勾配が小さかったパラメータは、学習率を大きくすることで更新のスピードを加速します。
このように、Adamは各パラメータに対して適応的な学習率を使用することで、学習の効率性を向上させます。これにより、各パラメータが最適な速度で学習することが可能になります。
以下に、Adamの仕組みを数式で表現します。
一次モーメント
まず、過去の勾配の平均(一次モーメント)は次のようになります。
$$
m_{t,i} = \beta_1 m_{t-1,i} + (1 - \beta_1) g_{t,i}
$$
$${m_{t,i}}$$は現在の一次モーメントです。これは、一つ前のステップ$${t-1}$$の一次モーメント$${m_{t-1},i}$$と現在の勾配からの変化率$${g_{t,i}}$$の加重平均になっています。その際の重みが$${\beta_1}$$です。
$${\beta_1}$$は低減率とも呼ばれ、どの程度過去の一次モーメントの影響を低減する(あるいは継続する)のかを決めています。よく、0.9といった値が使われます。つまり、過去の一次モーメントの影響が0.9倍、現在の勾配の影響が0.1倍となります。これにより、一次モーメントは過去の勾配の影響を大きく受けつつも、新たな勾配の情報を取り入れることができます。また、現在の勾配が突然大きく変化しても学習が不安定にならない効果があります。逆に、現在の勾配が突然小さくなってもこれまでの勢いをある程度継続してパラメータの更新を続けることができます。
なお、Adamの一次モーメントは、過去の勾配の情報を指数関数的に減衰させる(0.9倍を繰り返すので)ので、指数移動平均(Exponential Moving Average、EMA)とも呼ばれます。
以上のように、Adamは個別のパラメータに対しての一次モーメントを保持しています。
二次モーメント
次に、過去の勾配の分散(二次モーメント)は次のようになります。
$$
v_{t,i} = \beta_2 \, v_{t-1,i} + (1 - \beta_2) g_{t,i}^2
$$
$${v_{t,i}}$$は現在の二次モーメントです。これは、一つ前のステップ$${t−1}$$の二次モーメント$${v_{t−1,i}}$$と現在の勾配の二乗$${g_{t,i}^2}$$の加重平均になっています。二次モーメントを計算をするときの加重平均の重みが$${\beta_2}$$であり、減衰率です。よく、0.999と設定されます。
Adamの二次モーメントは、いわゆる統計学の分散ではなく、過去の勾配の二乗の指数移動平均(Exponential Moving Average、EMA)です。これは、新しい勾配の情報が加わるたびに過去の勾配の影響を指数関数的に減衰させる平均のことを指します。1次モーメントの指数移動平均と同様の考え方です。
二次モーメントは、過去の勾配の変動を含みます。二次モーメントが大きい場合、過去の勾配の絶対値が大きい傾向があるので、学習率を小さくすることで更新のスピードを抑制し学習を安定化させます。逆に、二次モーメントが小さいパラメータは、学習率を大きくすることで更新のスピードを加速します。
モーメントの修正
一次モーメントと二次モーメントは、初期値が0であるためにステップ数がすくない間は0に近い値になりがちです。なぜなら、0の0.9倍や0.999倍は0なので、一次モーメントと二次モーメントの値が0から離れた値になるまでしばらくかかります。新しい値の影響が大きくなるのをしばらく待つ必要があり、訓練の初期の学習が遅くなる恐れがあります。
よって、次の修正(バイアス補正)を加えます。
$$
\begin{aligned}
\hat{m}_{t,i} &= \frac{m_{t,i}}{1 - \beta_1^t} \\
\hat{v}_{t,i} &= \frac{v_{t,i}}{1 - \beta_2^t}
\end{aligned}
$$
$${\beta_1 = 0.9}$$、$${\beta_2 = 0.999}$$だとすると、$${\beta_1^t}$$と$${\beta_2^t}$$は、初期の頃は1に近く、上式の分母が0に近くなるので各モーメントの値が大きくなるように補正されます。ステップ$${t}$$が大きくなると$${\beta_1^t}$$と$${\beta_2^t}$$共に0へと低減していくので徐々に補正がなくなっていきます。
これによって初期の一次モーメントと二次モーメントが0に近い値になりがちな状況を回避しています。
パラメータの更新
最後に、これらのバイアス補正を行った一次モーメントと二次モーメントを用いて、パラメータの更新を行います。
$$
\theta_{t+1,i} = \theta_{t,i} - \eta \dfrac{\hat{m}_{t,i}}{\sqrt{\hat{v}_{t,i}} + \epsilon}
$$
$${\eta}$$は固定の学習率ですが、補正された一次モーメントと二次モーメントによって全体としての更新率が調節されます。
つまり、補正された一次モーメント$${\hat{m}_{t,i}}$$の大きければ全体としての更新率も大きくなるように調整されます。逆に、小さければ更新率も小さくなります。また、補正された二次モーメントが大きければ更新率は小さくなります。逆に、小さければ更新率は大きくなります。
なお、$${\epsilon}$$はゼロによる割り算にならないようにして、数値安定性をはかるための小さな値です。よく、1e-8 が使われます。
まとめると、Adamは、過去の勾配の一次モーメント(過去の勾配の指数移動平均)と二次モーメント(過去の勾配の二乗の指数移動平均)のバランスによって各パラメータの更新量を決定しています。
以上によって、Adamは各パラメータが最適な速度で学習することを可能にしています。
Adamの問題点
重み減衰の仕組み
Adamのアルゴリズムは多くの深層学習タスクで優れた性能を発揮します。しかし、重み減衰(weight decay)に関しての問題点が指摘されました。
重み減衰は、過学習を防ぐための一般的な手法で、モデルのパラメータが大きくなりすぎるのを防ぐために用いられます。具体的には、損失関数にパラメータを2乗したものを加えることで、パラメータの大きさを制限します。
これは機械学習モデルの訓練中によく使用される手法でL2正則化とも呼ばれます。損失関数に重みの二乗和に基づくペナルティ項を追加します。これにより、損失関数と重みの二乗和が小さくなるように学習が進むので、モデルの重みが大きくなりすぎることに起因する過学習を防ぐことができます。
L2正則化を含む損失関数$${J_{L2}}$$は以下のようになります:
$$
J_{L2}(\boldsymbol{\theta}) = J(\boldsymbol{\theta}) + \dfrac{\lambda}{2} ||\boldsymbol{\theta}||^2
$$
ここで、$${\lambda}$$はハイパーパラメータです。PyTorchでは、これをweight_decayと呼んでいます。
L2正則化を含む損失関数$${J_{L2}}$$をあるパラメータ$${\theta_i}$$で偏微分すると以下になります。
$$
\dfrac{\partial J_{L2}(\boldsymbol{\theta})}{\partial \theta_i} = \dfrac{\partial J(\boldsymbol{\theta}) + \dfrac{\lambda}{2} ||\boldsymbol{\theta}||^2}{\partial \theta_i} = g_i + \lambda \theta_i
$$
つまり、L2正則化を含む損失関数$${J_{L2}}$$の勾配の計算をするには、L2正則化なしの損失関数$${J}$$の勾配に、パラメータ$${\boldsymbol{\theta}}$$を $${\lambda}$$(weight_decay)倍したものを足すことになります。
$$
\nabla J_{L2}(\boldsymbol{\theta}) = \nabla J(\boldsymbol{\theta}) + \lambda \boldsymbol{\theta}
$$
このように重み減衰は、元の損失関数に重みのL2正則化項を加えることによって行われます。パラメータが大きくなると勾配に加わる項も大きくなるので、その分パラメータの更新によってパラメータが小さくなる訳です。つまり、モデルの重みが大きくなることにペナルティが課され、過学習を防ぐ効果があります。
しかし、Adamのアルゴリズムでは、この重み減衰が適切に機能しないことが指摘されています。
重み減衰の問題点
Adamは各パラメータに対して適応的な学習率を使用するため、L2正則化(パラメータの大きさ)による影響が相対的に小さくなってしまいます。その結果、重み減衰の効果が弱まり、過学習を防ぐ能力が低下する可能性があります。
具体的には、重み減衰が勾配に直接加えられる形で実装されているため、適応的な学習率によって重み減衰の効果が相対的に小さくなってしまいます。つまり、Adamのアルゴリズムが各パラメータに対して適応的な学習率を使用することで重み減衰が適切に機能しないという問題が生じます。
もう一度、Adamの重み更新式を見てみましょう。
$$
\theta_{t+1,i} = \theta_{t,i} - \eta \dfrac{\hat{m}_{t,i}}{\sqrt{\hat{v}_{t,i}} + \epsilon}
$$
ここで、$${\theta_{t,i}}$$は時刻$${t}$$での重み、$${\eta}$$は学習率、$${\hat{m}_{t,i}}$$と$${\hat{v}_{t,i}}$$はバイアス補正を行った一次モーメントと二次モーメント、$${\epsilon}$$は数値安定性のための小さな値です。
重み減衰を勾配に直接加える形で実装すると、重み更新式は以下のようになります:
$$
\theta_{t+1,i} = \theta_{t,i} - \eta \dfrac{\hat{m}_{t,i} + \lambda \theta_{t,i}}{\sqrt{\hat{v}_{t,i}} + \epsilon}
$$
ここで、$${\lambda}$$は重み減衰の強さを制御するパラメータです。
この更新式では、重み減衰項$${\lambda \theta_{t,i}}$$が適応的な学習率$${\dfrac{\eta}{\sqrt{\hat{v}_{t,i}} + \epsilon}}$$によってスケーリングされます。よって、大きくなったり小さくなったりするわけで、そもそものL2正則化による効果が正しく反映されなくなる訳です。これでは、L2正則化の本来の目的である「パラメータの大きさを制限する」こととは異なる振る舞いを示すため、過学習を防ぐ能力が低下する可能性があります。
この問題を解決するために、AdamWが提案されました。AdamWでは、重み減衰をAdamのアルゴリズムに直接組み込むことで、適応的な学習率が重み減衰の邪魔をするのを防ぎます。
これにより、AdamWはAdamの持つ効率性を保ちつつ、過学習をより効果的に防ぐことが可能になります。
AdamWによる改善
更新ルール
AdamWの重み更新式は以下のようになります:
$$
\theta_{t+1,i} = (1 - \eta \lambda) \theta_{t,i} - \eta \dfrac{\hat{m}_{t,i}}{\sqrt{\hat{v}_{t,i}} + \epsilon}
$$
この更新式では、重み減衰項が適応的な学習率によってスケーリングされるのではなく、重み自体に直接適用されます。つまり、重み減衰は各ステップで重みを一定の割合で減衰させる形で実装されています。これにより、重み減衰の効果が適応的な学習率によって相対的に小さくなるという問題が解消されています。
ちなみに、AdamWの論文のタイトルは、「Decoupled Weight Decay Regularization」(分離された重み減衰の正則化)となっています。このタイトルは、AdamW の目指すことを明確に示していると言えるでしょう。
また、AdamWでは、重み減衰の強さを制御するパラメータ$${\lambda}$$が学習率$${\eta}$$によってスケーリングされています。これにより、学習率が大きいときは重み減衰も強く、学習率が小さいときは重み減衰も弱くなるという、直感的な振る舞いを実現しています。
したがって、AdamW は Adam の持つ効率性を保ちつつ、過学習をより効果的に防ぐことが可能になります。これは、深層学習モデルの訓練において重要な性質です。この知識を持つことは、最適化アルゴリズムをより深く理解し、適切に使用するために役立つでしょう。
以上が、Adam と AdamW の違いとその背後にある理論的な考え方についての解説になります。
まとめ
この記事では、Adamの問題点として重み減衰の問題を解説した上で、それを解決したAdamWの仕組みを紹介しました。PyTorchの実装では、AdamW の weight_decay は、デフォルトで 0.01 となっています。
ちなみに、PyTorchなどのAdamの実装では weight_decay がデフォルトで0に設定されています。このことが、Adamの重み減衰の問題がしばらく気づかれなかった理由かもしれません。なぜなら、weight_decayが0に設定されていると、重み減衰(L2正則化)は全く適用されず、その結果、Adamの重み減衰の問題は表面化しないからです。そのため、この設定がデフォルトであるために、問題が見過ごされていた可能性はあります。
しかし、重み減衰は過学習を防ぐための重要な手段であり、多くの深層学習モデルでは非ゼロのweight_decayが使用されます。そのため、Adamの重み減衰の問題は、実際のモデル訓練においては無視できない問題であり、それが解決されたことでAdamWは広く受け入れられています。
関連記事
この記事が気に入ったらチップで応援してみませんか?