見出し画像

【ステップ7備忘録】ゼロから作るDeep Learning ❺【生成モデル編】ステップ7 変分オートエンコーダー(VAE)を読む

ゼロから作るDeep Learning ❺【生成モデル編】ステップ7では、変分オートエンコーダー(VAE)を解説しています。この記事では、本書を読むための予備知識のご紹介と、本書とは別のデータセット(Fashion-MNISTを使用して、変分オートエンコーダー(VAE)を実装します👨‍🎓


変分オートエンコーダー(VAE)について

変分オートエンコーダー(VAE)は、生成モデルの一種で、深層学習を用いて複雑なデータを生成したり、潜在空間を学習したりするモデルです。

VAEの仕組み

VAEは、大きく分けてエンコーダとデコーダの2つの部分から構成されています。

  • エンコーダ: 入力データを潜在空間に圧縮する。この潜在空間は、データの潜在的な特徴を捉えていると考えられます。

  • デコーダ: 潜在空間から元のデータに近いものを生成する。

特徴

  • 潜在空間: 潜在空間は、連続的な空間であることが多く、データの多様性を捉えることができます。

  • 生成モデル: VAEは生成モデルであり、新しいデータ点を生成することができます。

  • 確率的: VAEは、潜在変数に確率分布を仮定するため、より柔軟な表現が可能になります。

VAEの利点

  • 生成: 既存のデータセットに似た新しいデータを生成できる。

  • 表現学習: データの潜在的な特徴を学習できる。

  • 欠損値補完: 欠損しているデータを補完できる。

  • 異常検知: 異常なデータを検出できる。

VAEの目的

VAE(変分オートエンコーダ)の根本的な目的は、観測データの生成過程をモデル化し、その生成モデルから新たなデータを生成することです。言い換えれば、観測データの背後にある潜在的な構造(潜在変数)を推定し、その潜在変数から元のデータを再構成することを目指します。

ELBO (Evidence Lower Bound) とは

ELBOは、VAEの学習において非常に重要な役割を果たす変分下限と呼ばれるものです。VAEでは、観測データの対数尤度を直接最大化することが難しいという問題があります。そこで、この対数尤度を下から抑えることができるELBOを導入し、代わりにELBOを最大化することで、間接的に対数尤度を最大化するというアプローチを取ります。

なぜELBOを最大化するのか

  • 対数尤度の近似: ELBOは、観測データの対数尤度の下限となります。つまり、ELBOを最大化することで、対数尤度も間接的に大きくなることが期待できます。

  • 計算可能性: 対数尤度を直接計算するのは困難ですが、ELBOは計算可能な形で表すことができます。

  • VAEの目的との整合性: ELBOを最大化することは、VAEの目的である「観測データの生成過程のモデル化」と「潜在変数の推定」という二つの目標を達成することにつながります。

ELBOの構成

ELBOは、大きく分けて再構成誤差KLダイバージェンスの2つの項から構成されます。

  • 再構成誤差: デコーダが出力したデータと、元の観測データとの間の誤差を表します。この項が小さいほど、デコーダが元のデータをより正確に再構成できていることを意味します。

  • KLダイバージェンス: エンコーダが出力した潜在変数の分布と、事前分布(通常は標準正規分布)との間の距離を表します。この項が小さいほど、潜在変数が事前分布に近い、より自然な分布になっていることを意味します。

VAEの学習

VAEの学習は、このELBOを最大化することによって行われます。具体的には、勾配法を用いて、エンコーダとデコーダのパラメータを更新していきます。

  • 再構成誤差の減少: デコーダが元のデータをより正確に再構成できるように、デコーダのパラメータを更新します。

  • KLダイバージェンスの減少: 潜在変数が事前分布により近づくように、エンコーダのパラメータを更新します。

VAEによるFashion MNISTの画像生成

VAE(変分オートエンコーダー) を用いて、Fashion MNIST データセットから画像を生成するPyTorchのコードを以下に示します。

インポート

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

ハイパーパラメータ設定とデータセットの準備

# ハイパーパラメータ設定
latent_dim = 20  # 潜在空間の次元数
batch_size = 128
epochs = 10

# データセットの準備
transform = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

VAEモデル定義

# VAEモデル定義
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        # エンコーダ
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )

        self.mean = nn.Linear(128, latent_dim)
        self.log_var = nn.Linear(128, latent_dim)

        # デコーダ
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.encoder(x)
        mu = self.mean(x)
        logvar = self.log_var(x)

        # 再標本化
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std

        return self.decoder(z), mu, logvar

訓練ループ

# 訓練ループ
for epoch in range(epochs):
    for data in train_loader:
        images, _ = data
        images = images.view(-1, 784)

        # 順伝搬
        outputs, mu, logvar = model(images)

        # 再構成誤差とKLダイバージェンスの計算
        recon_loss = F.binary_cross_entropy(outputs, images, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        # 損失の計算とバックプロパゲーション
        loss = recon_loss + kl_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, loss.item()))

Epoch
[1/10], Loss: 23727.6172 Epoch
[2/10], Loss: 23949.9434 Epoch
[3/10], Loss: 23564.4570 Epoch
[4/10], Loss: 22671.5273 Epoch
[5/10], Loss: 23062.0684 Epoch
[6/10], Loss: 22448.5000 Epoch
[7/10], Loss: 23727.3340 Epoch
[8/10], Loss: 23657.2227 Epoch
[9/10], Loss: 23683.1855 Epoch
[10/10], Loss: 23490.1328

訓練の過程

潜在空間からサンプリングして画像生成

# 潜在空間からサンプリングして画像生成
with torch.no_grad():
    # 潜在空間からランダムな点をサンプリング
    z = torch.randn(64, latent_dim)
    # デコーダで画像を生成
    generated_imgs = model.decoder(z)
    # 画像を表示
    plt.figure(figsize=(10, 10))
    for i in range(64):
        plt.subplot(8, 8, i+1)
        plt.imshow(generated_imgs[i].reshape(28, 28), cmap='gray')
        plt.axis('off')
    plt.show()
VAEにより生成されたFashion-MNIST

まとめ

VAEは、深層学習を用いた強力な生成モデルであり、様々な分野で応用されています。潜在空間を学習することで、データの構造を深く理解することができ、新たなデータ生成や表現学習が可能になります。なお「ゼロから作るDeep Learning ❺【生成モデル編】」では、丁寧にELBOについて解説し、具体的なVAEの実装方法が示されています。ぜひ、詳しくは、「ゼロから作るDeep Learning ❺【生成モデル編】」をお読みになる事をお勧めします😃

いいなと思ったら応援しよう!