誤差逆伝播法を簡単に説明してみる
簡単な例を使って誤差逆伝播法をおさらいします。
モデルを更新する目的
体重から身長を予測することを考えます。そのための訓練データとして、次のように入力値と正解値がペア(対)として与えられているとします。
$$
(\text{体重}_1, \text{身長}_1), (\text{体重}_2, \text{身長}_2), \ldots, (\text{体重}_N, \text{身長}_N)
$$
数式にしやすいように、これらの値を$${(x_i, y^*_i) = (\text{体重}_i, \text{身長}_i)}$$と表現することにします。なお、*印をつけたのは、後で正解の値とモデルによる予測の値とを区別しやすくするためです。
モデルは体重$${x_i}$$から身長を予測して$${y_i}$$を出力します。
$$
y_i = \text{model}(x_i)
$$
ここではモデルの具体的な形として、直線の式を考えます。
$$
y_i = a x_i + b
$$
$${a}$$は直線の傾きで、$${b}$$は直線の切片です。$${a}$$と$${b}$$をまとめて、このモデルのパラメータと呼びます。これはモデルの一部と考えられ、モデルを更新することはパラメータを調節することです。
我々が目指すのは、モデルの予測$${y_i}$$が正解$${y^*_i}$$に近くなるようにモデルのパラメータを調節しモデルを更新することです。
予測を変えると誤差が変化する
そこで、モデルの予測誤差を評価するために、例として平均二乗誤差(Mean Squared Error、MSE)を使用します。
$$
\text{MSE} = \frac{1}{N} \sum\limits_{i=1}^N \left(y^*_i - y_i\right)^2
$$
ここで、$${N}$$はデータポイントの数です。
このMSEの値を小さくするために、ある予測値$${y_j}$$を変化させた場合のMSEの反応を考えます。つまり、MSEを$${y_j}$$で偏微分します。
$$
\begin{aligned}
\frac{\partial \text{MSE}}{\partial y_j} &= \frac{\partial }{\partial y_j} \frac{1}{N} \sum\limits_{i=1}^N \left(y^*_i - y_i\right)^2 \\
&= \frac{1}{N} \sum\limits_{i=1}^N \frac{\partial }{\partial y_j} \left(y^*_i - y_i\right)^2 \\
&= \frac{1}{N} \sum\limits_{i=1}^N 2 \left(y^*_i - y_i\right) \frac{\partial (-y_i)}{\partial y_j} \\
\end{aligned}
$$
ここで、$${y_i}$$が$${y_j}$$に依存するのは、$${i = j}$$の時だけです。したがって、次のように簡略化できます。
$$
\begin{aligned}
\frac{\partial \text{MSE}}{\partial y_j} &= \frac{1}{N} \cdot 2\left(y^*_j - y_j\right) \cdot (-1) \\
&= \frac{2}{N} \left( y_j - y^*_j \right)
\end{aligned}
$$
得られた式から、次のことが理解できます。
$${y_j < y^*_j}$$の場合は$${\dfrac{\partial \text{MSE}}{\partial y_j} < 0}$$なので、$${ y_j}$$を増やすとMSEが減少する
$${y_j > y^*_j}$$の場合は$${\dfrac{\partial \text{MSE}}{\partial y_j} > 0}$$なので、$${ y_j}$$を減らすとMSEが減少する
この理解に従って予測値$${y_j}$$を調節したいのですが、$${y_j}$$を直接変更することはできません。
パラメータを変えると予測が変わる
そこで、各パラメータ($${a}$$と$${b}$$)を調節して$${y_j}$$を変更し、MSEの値を小さくすることを考えます。
まず、$${y_j}$$を$${a}$$と$${b}$$に関してそれぞれ偏微分します。
$$
\begin{aligned}
\frac{\partial y_j}{\partial a} &= \frac{\partial (a x_j + b)}{\partial a} = x_j \\[2ex]
\frac{\partial y_j}{\partial b} &= \frac{\partial (a x_j + b)}{\partial b} = 1
\end{aligned}
$$
この結果をまとめると、以下になります。
$${a}$$を1増やすと予測値$${y_j}$$が$${x_j}$$だけ増える
$${b}$$を1増やすと予測値$${y_j}$$が$${1}$$だけ増える
よって、$${a}$$と$${b}$$を調節し、$${y_j}$$の値を変更することで、MSEを最小化することが可能です。そこで、偏微分の連鎖律(チェーンルール)を使って$${a}$$と$${b}$$の各パラメータによるMSEの変化率を表現します。
$$
\begin{aligned}
\frac{\partial MSE}{\partial a} &= \sum\limits_{j=1}^N \frac{\partial MSE}{\partial y_j} \cdot \frac{\partial y_j}{\partial a} = \frac{2}{N} \sum\limits_{j=1}^N \left( y_j - y^*_j \right) x_j \\
\frac{\partial MSE}{\partial b} &= \sum\limits_{j=1}^N \frac{\partial MSE}{\partial y_j} \cdot \frac{\partial y_j}{\partial b} = \frac{2}{N} \sum\limits_{j=1}^N \left( y_j - y^*_j \right)
\end{aligned}
$$
つまり、パラメータ$${a}$$や$${b}$$を変化させると$${y_j}$$の値が変わり、それに連なってMSEが変化します。この変化を全ての予測値$${y_1, \ldots, y_N}$$に対して計算し足し合わせれば、パラメータ$${a}$$や$${b}$$を変化させた時のMSEへの影響度がわかります。
勾配と逆方向にパラメータを調節する
また、各パラメータによるMSEの変化率から以下の関係が成り立ちます。
$${\dfrac{\partial \text{MSE}}{\partial a} < 0}$$場合は、$${a}$$を増やすとMSEが減少する。
$${\dfrac{\partial \text{MSE}}{\partial a} > 0}$$場合は、$${a}$$を減らすとMSEが減少する。
$${\dfrac{\partial \text{MSE}}{\partial b} < 0}$$場合は、$${b}$$を増やすとMSEが減少する。
$${\dfrac{\partial \text{MSE}}{\partial b} > 0}$$場合は、$${b}$$を減らすとMSEが減少する。
これを簡潔に表現するために、MSEのパラメータに関する勾配を次のように定義します。
$$
\nabla \text{MSE} = \left(\frac{\partial \text{MSE}}{\partial a}, \frac{\partial \text{MSE}}{\partial b}\right)
$$
つまり、この勾配は各パラメータによる誤差の偏微分をベクトルにしたものです。これによって、パラメータ空間においてどの方向にパラメータを動かすと、誤差が増えるのかがわかります。よって、勾配の逆方向にパラメータを調節することで誤差を減少させることができます。これが勾配降下法の基本的な仕組みです。
なお、勾配を計算する過程では、誤差からパラメータへと偏微分を連鎖させています。これはフィードフォワード(順方向)とは逆向きになっており、この手法が「誤差逆伝播法」と呼ばれる所以です。
もっと複雑に連鎖する計算や非線形の変換などを含めても基本的な考え方は同じです。
関連記事
この記事が気に入ったらチップで応援してみませんか?