ニューラルネットワークの予測の不確実性(stochastic variational inference・概要編)


はじめに

stochastic variational inference でニューラルネットワークの予測の不確実性を算出する方法の概要を説明します。

事前準備(変分ベイズ)

前提知識として必要なベイズ推論と変分推論について説明します。

ベイズ推論

通常のニューラルネットワークでは、出力は変数ですが、ベイズ推論ではラベル$${y}$$の分布$${p(y|x)}$$を考えます。
分布が尖った形になっている場合は予測の不確実性が低く、裾が広い場合は不確実性が高いことになります。

ラベルの分布は、パラメータの分布$${p(\theta)}$$を考え、積分で$${\theta}$$を消去することで求めます。

$$
p(y|x) = \int p(y|x, \theta)p(\theta) {\rm d}\theta
$$

変分推論

パラメータの事後分布$${p(\theta|X,Y)}$$を、計算しやすい分布$${q(\theta)}$$で近似することを考えます。
近似分布$${q(\theta)}$$は、$${p(\theta|X,Y)}$$とのKLダイバージェンスが小さければ、よい近似となります。

$$
D_{KL}[q(\theta)||p(\theta|X,Y)] = \int q(\theta)\ln\frac{q(\theta)}{p(\theta|X,Y)} {\rm d}\theta
$$

しかし、直接最小化することは難しいため、データの周辺分布 $${\ln p(X,Y)}$$ とその下界(ELBO)$${\mathcal{L}(\xi)}$$ を考えます。
これらと、上記KLダイバージェンスには、下記の関係があることが知られています。

$$
\ln p(X,Y) = \mathcal{L}(\xi) + D_{KL}[q(\theta;\xi)||p(\theta|X,Y)]
$$

ここで、$${\xi}$$は$${q(\theta)}$$の形を決めるパラメータで、変分推論では$${\xi}$$を最適化します。

この関係性を用いて、KLダイバージェンスを最小化する問題を、ELBOを最大化する問題に置き換えます。
ELBOは下記で定義されます。

$$
\mathcal{L}(\xi) = \int q(\theta;\xi)\ln \frac{p(X,Y,\theta)}{q(\theta;\xi)} {\rm d}\theta
$$

同時確率を条件付き確率に書き換えると、$${p(X,Y,\theta)=p(Y|X,\theta)p(X|\theta)p(\theta)}$$となり、$${X}$$と$${\theta}$$は独立のため $${p(X,Y,\theta)=p(Y|X,\theta)p(X)p(\theta)}$$ と書くことができます。
これをELBOに代入すると下記のようになります。

$$
\mathcal{L}(\xi) = \int q(\theta;\xi) \ln p(Y|X, \theta) {\rm d}\theta + \int q(\theta;\xi) \ln p(X) {\rm d}\theta + \int q(\theta;\xi) \ln \frac{p(\theta)}{q(\theta;\xi)} {\rm d}\theta 
$$

$$
= \mathbb{E}{q(\theta;\xi)}[\ln p(Y|X, \theta)] + \ln p(X) - D{KL}[q(\theta;\xi)||p(\theta)]
$$

$${\ln p(X)}$$は定数なので、ELBOを最大化するために、下記の負の対数尤度と$${q(\theta)}$$とパラメータの事前分布とのKLダイバージェンスの和を最小化すればよいことが分かります。

$$
\mathbb{E}{q(\theta;\xi)}[-\ln p(Y|X, \theta)] + D{KL}[q(\theta;\xi)||p(\theta)]
$$

Stochastic Variational Inference

負の対数尤度の期待値を最小化するために、各バッチで$${q(\theta;\xi)}$$からパラメータをサンプリングして、そのパラメータでの負の対数尤度を最小化します。
$${q(\theta;\xi)}$$は勾配降下法で最適化できる必要があるため、$${q(\theta;\xi)}$$の実装では下記のような方法が用いられます。

reparametrization trick

$${q(\theta;\xi)}$$を、平均$${\mu}$$、標準偏差$${\sigma^2}$$の正規分布$${\mathcal{N}(\mu,\sigma^2)}$$からサンプリングする場合、まず、平均0、標準偏差1の正規分布に従うノイズ$${\epsilon \sim \mathcal{N}(0,1)}$$をサンプリングし、$${\theta = \mu + \sigma^2 \epsilon}$$をパラメータとして用います。
$${\mu}$$と$${\sigma}$$を勾配降下法で最適化します。

flipout

上記の手法だと、ミニバッチ内のサンプルで$${\epsilon}$$が共有されてしまう問題があります。
そこで、ノイズにランダムな符号ベクトル$${r_ns_n^T}$$をかけ、$${\theta=\mu + \sigma^2\epsilon r_ns_n^T}$$を用いることで、学習の効率化を行います。

回帰タスクの対数尤度

サンプリングされたパラメータ$${\theta}$$でのデータ$${x}$$の予測を$${p(y|x,\theta)=\mathcal{N}(f(x|\theta),\sigma^2)}$$、正解を$${y}$$とすると、対数尤度は$${\ln p(y|x,\theta)=-\frac{\ln 2\pi\sigma^2}{2} - \frac{(y-f(x|\theta))^2}{2\sigma^2}}$$となります。
つまり、負の対数尤度を最小化するためには、2乗誤差を最小化すればよいことになります。

不確実性の算出

近似分布$${q(\theta;\xi)}$$からn個のパラメータ$${{\theta_1, \cdots, \theta_n}}$$をサンプリングし、事後分布の期待値と分散を計算します。

$$
\mu(y) = \frac{1}{n}\sum_{i=1}^n f(x|\theta_i)
$$

$$
{\rm Var}(y) = \sigma^2 + \frac{1}{n}\sum_{i=1}^n f(x|\theta_i)^T f(x|\theta_i) - \mu(y)^T \mu(y)
$$

学習のまとめ

Stochastic variational inferenceでは、各ミニバッチで近似分布$${q(\theta)}$$からパラメータ$${\theta}$$をサンプリングし、平均二乗誤差と事前分布と近似分布のKLダイバージェンスを最小化します。

$$
\mathcal{L} = \sum_{i=1}^n (y_n - f(x_n|\theta))^2 + \alpha D_{KL}[q(\theta)||p(\theta)]
$$

ここで、$${\alpha}$$は、2つの項のバランスを調整するパラメータです。

参考資料

  • D. P. Kingma and M. Welling, Auto-Encoding Variational Bayes, ICLR, 2014

  • Y. Wen et al., Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-batches, ICLR, 2018.

いいなと思ったら応援しよう!