GANを実装してみる


GAN (Generative Adversarial Network) とは?

敵対的生成ネットワークとも言います。
GANは、Ian Goodfellow氏らによって2014年に提案された新しい形式の生成モデルです。このモデルは、生成器(Generator)と識別器(Discriminator)という2つのネットワークから構成されています。

  • 生成器 (Generator): ランダムなノイズからデータを「生成」する役割を持ちます。

  • 識別器 (Discriminator): 提供されたデータが本物(実際のデータセットからのもの)か、生成器が生成した偽物かを「識別」する役割を持ちます。

GANの特性

  1. 非監視学習: 従来のディープラーニングはラベリング済みの大量のデータによる学習(教師あり学習)が必要ですが、GANは学習データを自ら作り出して学習する、教師なし学習で使用されるアルゴリズムのひとつです。

  2. 強力な生成能力: GANは、特に画像生成の分野でその力を発揮しており、非常に高解像度で現実的な画像を生成することが可能です。

基本的なGAN(Vanilla GAN)

+---------------+             +---------------+             +---------------+
|               | Noise       |               | Fake Data   |               |
|   Generator   +------------->   GAN Model   +-------------> Discriminator |
|               |             | (untrainable) |             |               |
+---------------+             +-------^-------+             +-------+-------+
                                |    |                             |
                                |    |                             |
                                |    | Real Data                   |
                                |    |                             |
                                |    +-----------------------------+
                                |
                                |
                           Real or Fake

GANの基本的なモデル。2つのネットワーク、Generator(生成器)とDiscriminator(判別器)を使って、新しいデータを生成する。

  • Generator: ランダムノイズを受け取り、データ(たとえば画像)を生成します。

  • Discriminator: 提供されたデータが本物(実際のデータセットからのもの)か偽物(Generatorが生成したもの)かを判断する。

  • 目的: GeneratorはDiscriminatorを騙そうと努力し、Discriminatorは本物と偽物を正確に識別しようとします。

StyleGAN

半導体大手の米NVIDIAの研究チームが2018年に発表した手法です。顔や特定のオブジェクトの高解像度の画像を生成するためにデザインされたGAN。独自のスタイル制御メカニズムを持つ。
StyleGAN2、StyleGAN2-ADA、StyleGAN3と性能が向上した手法が開発されています。

  • 特徴: スタイル制御、アダプティブインスタンスノーマリゼーション、マッピングレイヤーなどの技術を使用。

  • 目的: さまざまな「スタイル」の影響を各レベルの詳細で制御できるようにする。

CycleGAN

画像のスタイル変換が得意な手法。例えば、夏の風景を冬の風景に変換するなど。

  • 特徴: 片方向の変換(A -> B)だけでなく、逆方向の変換(B -> A)も学習します。従って、変換の「サイクル」を保持することができる。

  • 目的: 一方のドメインから他方のドメインへの変換を学習し、元のドメインに戻ることで変換が一貫していることを確認します。

Conditional GAN (cGAN)

条件付きGANとも呼ばれます。
2014年にarXivで公開された論文 Conditional Generative Adversarial Netsで提案された生成手法です。
生成されるデータの種類や特性を制御するための条件(例: ラベルや情報)をGANに提供します。

  • 特徴: GeneratorとDiscriminatorの両方に条件を供給します。

  • 目的: 条件に基づいて特定のタイプのデータを生成する。例えば、数字「3」のラベルを条件として提供すると、Generatorは数字「3」の画像を生成しようとします。

注意点

GANは学習が難しく、多くの場合、ハイパーパラメータの調整や特定のアーキテクチャの使用が必要です。モデルが収束しない場合やモード崩壊(常に同じようなデータを生成する)などの問題が発生することがあります。

この講義では、GANの基本的な概念を紹介し、シンプルな1Dデータを生成するためのGANの実装を通じて、GANの動作原理を理解することを目的としています。

GANを実装してみる

使用するMNISTデータセット

GANのデモンストレーションや教育的な利用の際によく使用されるデータセットは、MNIST です。MNISTは、手書き数字の画像(0から9まで)からなるデータセットで、それぞれの画像は28x28ピクセルのグレースケール画像です。

MNISTは以下の理由から、特にGANの初学者向けに人気があります。

  1. シンプル: 画像がグレースケールであり、解像度も低いため、学習が速く、必要な計算リソースも比較的少ないです。

  2. 標準的: 他の多くの機械学習の研究やチュートリアルで使用されており、結果を比較しやすいです。

  3. 利用が容易: 多くの機械学習のライブラリやフレームワークで、MNISTをダウンロードして使うための関数やツールが提供されています。

実装したGANがMNISTデータセットを使用して学習する場合、生成器はランダムなノイズから始めて、徐々に真のMNISTの数字の画像に似た画像を生成するようになります。一方、識別器は、真のMNISTの画像と生成器が生成した偽の画像を見分ける能力を向上させるように学習します。

具体的なイメージとしては、生成器は初めはランダムなピクセルの塊を生成するだけかもしれませんが、学習が進むと、それが3や7などの手書き数字に似た画像になっていきます。

MNISTの画像がどんなものなのか、表示してみましょう

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

# MNISTのデータセットをロードする
(train_images, train_labels), (_, _) = mnist.load_data()

# 最初の25枚を表示する
plt.figure(figsize=(10, 10))
for i in range(25):
  plt.subplot(5, 5, i+1)
  plt.imshow(train_images[i])
  plt.axis('off')
plt.tight_layout()
plt.show()

上記のコードを実行すると、こんな画像が表示されます。
これがMNISTの手書き数字画像です。解像度低めですね。

基本的なGANでの実装

Vanilla GANと呼ばれる、基本的なGANで実装していきます。

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Flatten, Dense, Conv2D, LeakyReLU, Dropout, Reshape, Conv2DTranspose
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam, legacy
import numpy as np

# MNISTデータセットをロードする
(x_train, _), (_, _) = mnist.load_data()
# 画像のピクセルの値を[-1, 1]の範囲にスケーリング
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
# チャネルの次元追加
x_train = np.expand_dims(x_train, axis=-1)

# ジェネレータを構築する関数
def build_generator():
  # ノイズの設定
  noise_shape = (100,)
  model = Sequential()

  # ノイズから特徴マップを生成
  model.add(Dense(128 * 7 * 7, activation='relu', input_shape=noise_shape))
  # 1Dの出力を2Dの特徴マップに変換
  model.add(Reshape((7, 7, 128)))
  # 画像のサイズを増加させる
  model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Conv2D(1, kernel_size=3, padding='same', activation='tanh'))

  return model

# ディスクリミネータを構築する関数
def build_discriminator():
  # 入力画像の形状を指定
  img_shape = (28, 28, 1)
  model = Sequential()

  # 画像の特徴を学習
  model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=img_shape, padding='same'))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Dropout(0.25))
  model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Dropout(0.25))
  model.add(Flatten())
  model.add(Dense(1, activation='sigmoid'))

  return model

# オプティマイザを初期化する
optimizer = legacy.Adam(learning_rate=0.0002, beta_1=0.5)

# ディスクリミネータをビルドする。損失関数を指定してコンパイル
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# ジェネレータをビルドする。
generator = build_generator()

# ジェネレータのノイズを設定
z = Input(shape=(100,))
img = generator(z)

# 学習中にディスクリミネータの重みを更新しないようにする。
discriminator.trainable = False

# 画像の真偽を判定するディスクリミネータ
validity = discriminator(img)

# ジェネレータとディスクリミネータを結合したモデルを作成
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

# 学習のパラメータ設定
batch_size = 128
epochs = 1000
half_batch = batch_size // 2

# 学習
for epoch in range(epochs):
  # 真の画像のランダムなサブセット
  idx = np.random.randint(0, x_train.shape[0], half_batch)
  real_imgs = x_train[idx]

  # ノイズから偽の画像を生成
  noise = np.random.normal(0,1, (half_batch, 100))
  fake_imgs = generator.predict(noise)

  # 真と偽のラベル作成
  real_labels = np.ones((half_batch, 1))
  fake_labels = np.zeros((half_batch, 1))

  # ディスクリミネータをトレーニング
  d_loss_real = discriminator.train_on_batch(real_imgs, real_labels)
  d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_labels)
  d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

  # ジェネレータをトレーニング
  noise = np.random.normal(0, 1, (batch_size, 100))
  valid_labels = np.ones((batch_size, 1))
  g_loss = combined.train_on_batch(noise, valid_labels)

  print(f"{epoch}/{epochs} [D loss: {d_loss[0]} | D Accuracy: {100 * d_loss[1]}] [G loss: {g_loss}]")

学習させたジェネレータを使用し、画像を生成

学習させたジェネレータに、手書き数字の画像を生成してみてもらいます。

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model

# 保存されたモデルをロードする
loaded_generator = load_model('generator_model.h5')

# 画像を生成する関数を作成する
def generate_images(generator, num_images=10):

  # 画像の生成に必要なノイズを設定
  # 平均0, 標準偏差1の正規分布から、ランダムなノイズを生成する。
  noise = np.random.normal(0, 1, (num_images, 100))

  # ノイズをジェネレータに入力、画像を生成
  generated_imgs = generator.predict(noise)

  # 生成された画像のピクセル値が[-1, 1]の範囲になるので、[0, 1]に変換する
  generated_imgs = 0.5 * generated_imgs + 0.5

  return generated_imgs

# 画像を10枚生成してみる
images = generate_images(loaded_generator, 10)

fig, axs = plt.subplots(1, 10, figsize=(20, 2))
for i in range(10):
  # グレースケールなので、cmap="gray"
  axs[i].imshow(images[i, :, :, 0], cmap="gray")
  axs[i].axis('off')
plt.show()

Epoch1000なので、何となく数字っぽいなという程度の画像が生成できました。


いいなと思ったら応援しよう!