深層学習:ニューラルネットワークにおけるコスト関数最小化法 Classical Momentum,NAG,Sutskever NAG

勾配降下法は、コスト関数$${J({\bm w})}$$の勾配にのみ注目しているため、局所的最小値に取り込まれやすい。このため、確率勾配降下法などで、重みを大きく振らしているが、この振動のためにコスト関数の形状によって収束に時間がかかることがある。
 これを克服するために、勾配だけでなく、重みの変化割合(速度)に応じた情報も入れて更新するのが、モーメンタム法である。
 以下、重みの更新時を$${t}$$とし、$${t}$$から$${t+1}$$での重みの更新を
$${{\bm w}_{t+1}={\bm w}_t + \Delta {\bm w}_{t+1}}$$と表し、
$${ \Delta {\bm w}_{t+1}={\bm w}_{t+1}-{\bm w}_t}$$とする。

Classical Momentum

モーメンタムとして$${\mu\Delta{\bm w}_{t}}$$を入れる。$${\mu}$$は質量項とも呼ばれ、質量$${\times}$$速度で物理学のモーメンタムのように解釈されることから、モーメンタム法と呼ばれる。
$${ \Delta {\bm w}_{t+1}=\mu_t{\Delta \bm w}_{t}-\eta_t{\bm \nabla}J({\bm w}_t)}$$
で与えられ、
$${{\bm v}_{t+1} = \Delta {\bm w}_{t+1}}$$の速度は、
$${{\bm v}_{t+1}=\mu_t{\bm v}_{t}-\eta_t{\bm \nabla}J({\bm w}_t)}$$
となる。

NAG

上記のモーメンタム法では、$${{\bm w}_t}$$での勾配を計算し、モーメンタム項を入れて更新するが、NAG では最初に$${{\bm w}_t}$$をモーメンタム方向に変更した後勾配をとり、モーメンタム項を入れて更新する。
$${{\bm v}_t=\mu_t{\bm v}_{t-1} - \eta_t{\bm \nabla}_{{\bm w}_t}J({\bm w}_t-\mu{\bm v}_{t-1})}$$
と表せるが、実際はモーメンタムを入れた変更を、以下のように最初に行い、
$${{\bm \varphi}_{t+1}={\bm w}_{t} - \eta_t{\bm \nabla}_tJ({\bm w}_t)}$$
これを用いて、重みの更新をする二段階更新を行う。
$${{\bm w}_{t+1}={\bm \varphi}_{t+1}+\mu_t({\bm \varphi}_{t+1} - {\bm \varphi}_{t})}$$
 NAGの利点は、重みの変更が、コスト関数の局所最小近傍で形状に対して穏やかになることから雑音が少なくなることである。

Sutskever Nestrov Momentum

上記のNAGをRNNに特化した最小問題解法であり、更新ステップを組み替えていることが特徴である。
NAGの更新ステップを$${t+1}$$と$${t+2}$$で連続に書く。

  1. $${{\bm \varphi}_{t+1}={\bm w}_{t} - \eta_t{\bm \nabla}_tJ({\bm w}_t)}$$

  2. $${{\bm w}_{t+1}={\bm \varphi}_{t+1}+\mu_t({\bm \varphi}_{t+1} - {\bm \varphi}_{t})}$$

  3. $${{\bm \varphi}_{t+2}={\bm w}_{t+2} - \eta_t{\bm \nabla}_{t+1}J({\bm w}_{t+1})}$$

  4. $${{\bm w}_{t+2}={\bm \varphi}_{t+2}+\mu_t({\bm \varphi}_{t+2} - {\bm \varphi}_{t+1})}$$

ここで、2の$${{\bm w}_{t+1}}$$と3の$${{\bm \varphi}_{t+2}}$$を一つの更新ステップに組み替える。
$${{\bm \varphi}}$$と$${\eta}$$をシフトバックして、新たな更新ステップの二段階手続きは以下のようになる。
$${{\bm w}_{t+1}={\bm \varphi}_{t}+\mu_t({\bm \varphi}_{t} - {\bm \varphi}_{t-1})}$$
$${{\bm \varphi}_{t+1}={\bm w}_{t+1} - \eta_t{\bm \nabla}_{t+1}J({\bm w}_{t+1})}$$
$${{\bm v}_{t}={\bm \varphi}_{t}-{\bm \varphi}_{t-1}}$$と書くと、
$${{\bm w}_{t+1}={\bm \varphi}_{t}+\mu_t{\bm v}_t}$$で与えられるから、
$${\displaystyle{{\bm \varphi}_{t+1}={\bm \varphi}_{t}+\mu_t{\bm v}_t - \eta_t\frac{\partial J ({\bm \varphi}_t + \mu_t{\bm v}_t)}{\partial({\bm \varphi}_t + \mu_t{\bm v}_t)} }}$$
と更新される。
ここで、モーメンタム法の表記に則り、
$${{\bm v}_{t+1}=\displaystyle{\mu_t{\bm v}_t - \eta_t\frac{\partial J ({\bm \varphi}_t + \mu_t{\bm v}_t)}{\partial({\bm \varphi}_t + \mu_t{\bm v}_t)}}}$$
と書けば、
$${\bm{\varphi}_{t+1}={\bm \varphi}_t + {\bm v}_{t+1}}$$
となる。
この更新の実装は、$${{\bm \varphi}}$$をダイレクトに更新するのではなく、
$${\tilde{{\bm \varphi}}_{t+1}={\bm \varphi}_t + \mu{\bm v}_t}$$
の変数を導入し、
$${\tilde{{\bm \varphi}}_{t+1}=\displaystyle{{\bm \varphi}_t +{\bm v}_{t+1} + \mu_t^2{\bm v}_t - \eta_t\mu_t\frac{\partial J(\tilde{{\bm \varphi}}_t)}{\partial \tilde{\bm \varphi}_t} }}$$
$${\displaystyle{={\bm \varphi}_t +\mu_t{\bm v}_t -\eta_t \frac{\partial J(\tilde{{\bm \varphi}}_t)}{\partial \tilde{\bm \varphi}_t}+\mu_t^2{\bm v}_t - \eta_t\mu_t\frac{\partial J(\tilde{{\bm \varphi}}_t)}{\partial \tilde{\bm \varphi}_t} }}$$
よって、$${\tilde{{\bm \varphi}}}$$の更新は、
$${\tilde{{\bm \varphi}}_{t+1}\displaystyle{=\tilde{{\bm \varphi}}_t + \mu_t^2{\bm v}_t - (1+\mu_t)\eta_t \frac{\partial J(\tilde{{\bm \varphi}}_t)}{\partial \tilde{\bm \varphi}_t} }}$$
で行われる。

この記事が気に入ったらサポートをしてみませんか?