見出し画像

VQ-VAE

VQ-VAE(Vector Quantized Variational AutoEncoder)は、従来のVAE(Variational AutoEncoder)を拡張した深層学習モデルです。VAEは知っていたけどVQとついた瞬間に嫌悪感…😨ということで勉強してみました。応用分野は広く画像生成、音声合成、テキスト生成、音素や話者の識別など多岐に渡るのでこれはやらねば…


主な特徴

  1. 離散的な潜在表現: VQ-VAEは、連続的な潜在変数ではなく、離散的な潜在変数を使用します

  2. ベクトル量子化: VQとは連続空間に存在するベクトルを、有限個の代表ベクトルへ離散化する操作です。連続的なデータを離散的な表現に変換することで、データの効率的な圧縮が可能になります。

  3. posterior collapseの回避: 通常のVAEで発生しやすい「後部崩壊(posterior collapse)」問題を回避できます

  4. 高速な処理: 圧縮された潜在空間でのサンプリングにより、特に大規模な画像生成において処理速度が大幅に向上します

そもそもベクトル量子化とは?

Vector Quantizationは深層学習モデルで重要な役割を果たし、データの効率的な圧縮と高品質な生成を実現する革新的な技術として注目されています。簡単に言えば、VQは、たくさんの情報を少ない情報で表す方法です。

VQの手順

  1. グループ分け: たくさんの点(情報)があるとき、似ている点をグループにまとめます。

  2. 代表を選ぶ: 各グループから1つの代表的な点を選びます。これを「コードワード」と呼びます。

  3. 置き換え: 元の点を、一番近い代表点に置き換えます。


イメージ(引用元

イメージ

  • 教室に30人の生徒がいて、みんな違う身長です。先生は3つのグループ(小さい、中くらい、大きい)に分けたいと思っています。

  • 先生は各グループの真ん中くらいの身長の子を代表に選びます。

  • そして、他の子たちは自分に一番近い身長の代表と同じグループになります。

これがVQの基本的な考え方です。たくさんの情報(30人の生徒)を、少ない情報(3つのグループ)で表しているのです。

数式で表すと…

要はK-meansでやっていることと同じですが、下記の計算を行います。

  1. コードブックの定義
    コードブック$${C}$$は$${K}$$個の代表ベクトル(コードワード)cici​から構成されます:$${C=\{c_1,c_2,...,c_K\},\ c_i\in\mathbb{R}^D}$$
    ここで、$${D}$$はベクトルの次元数です。

  2. 量子化プロセス
    入力ベクトル$${x\in\mathbb{R}^D}$$に対して、最も近い代表ベクトルを選択します:$${q(x)=\mathop{\rm arg~max}_⁡{c_i \in C}\|x−c_i\|^2}$$ここで、$${\|\cdot\|^2}$$はユークリッド距離を表します。

  3. 逆量子化
    量子化されたインデックスiiから元のベクトル空間への逆写像は以下のように定義されます:$${Q(x)=c_i}$$, where $${i=q(x)}$$

  4. 再構成誤差
    量子化による再構成誤差は以下のように計算されます:
    $${E=\|x−Q(x)\|^2}$$

  5. コードブックの最適化(学習過程)
    コードブックは通常、K-means法などを用いて最適化されます。目的関数は以下のようになります:
    $${\mathop{\rm min}⁡_{C}\sum _{x\in X}\|x−Q(x)\|^2}$$
    ここで、$${X}$$は訓練データセットです。この最適化は反復的に行われ、各反復で以下の2つのステップを繰り返します:

    1. 各データポイント $${x}$$ を最も近いコードワードに割り当てる

    2. 各コードワード $${c_i}$$​ を、それに割り当てられたデータポイントの平均で更新する。

スクリプトで見ると

これをスクリプトにしてみると以下のようなことをしています。


import numpy as np

class VectorQuantizer:
    def __init__(self, num_codewords, dim):
        self.num_codewords = num_codewords
        self.dim = dim
        self.codebook = np.random.randn(num_codewords, dim)

    def quantize(self, vectors):
        distances = np.sum((self.codebook[:, np.newaxis] - vectors) ** 2, axis=2)
        indices = np.argmin(distances, axis=0)
        quantized = self.codebook[indices]
        return quantized, indices

    def train(self, vectors, num_iterations=100, learning_rate=0.01):
        for _ in range(num_iterations):
            quantized, indices = self.quantize(vectors)
            for i in range(self.num_codewords):
                mask = (indices == i)
                if np.any(mask):
                    self.codebook[i] += learning_rate * (vectors[mask].mean(axis=0) - self.codebook[i])

# 使用例
num_codewords = 16
dim = 3
vq = VectorQuantizer(num_codewords, dim)

# トレーニングデータ
data = np.random.randn(1000, dim)

# トレーニング
vq.train(data)

# 量子化
test_data = np.random.randn(10, dim)
quantized, indices = vq.quantize(test_data)

print("元のデータ:")
print(test_data)
print("\n量子化されたデータ:")
print(quantized)
print("\nインデックス:")
print(indices)

代表点だけで表現力は落ちない?

VQでは、代表点(コードワード)を決めるだけでも、以下の理由により表現力をある程度維持することができます:

  • 密度マッチング特性
    VQは、データの密度分布に合わせて代表点を配置します。これにより:

    • 頻出するデータパターンには多くの代表点が割り当てられます。

    • まれなデータパターンには少ない代表点が割り当てられます。

この特性により、データの重要な特徴や構造を効率的に捉えることができます。

  • 最適な代表点の選択
    VQでは、通常K-meansなどのクラスタリングアルゴリズムを使用して、データ分布を最もよく表す代表点を選択します。これにより:

    • データ空間を効果的に分割し、各領域の特徴を捉えます。

    • 量子化誤差を最小化し、元のデータとの類似性を保ちます。

  • 階層的・反復的アプローチ
    より高度なVQ手法では、以下のような工夫により表現力を向上させています:

    • 残差量子化(RQ): 複数の量子化段階を使用して、ベクトルの近似を段階的に改善します

    • プロダクト量子化(PQ): ベクトルを部分ベクトルに分割し、それぞれを独立に量子化します

これらの手法により、より細かい特徴や構造を捉えることが可能になります。

  • 適切なコードブックサイズの選択
    コードブック(代表点の集合)のサイズを適切に選択することで、圧縮率と表現力のバランスを取ることができます。

このようにVQの表現力を維持する工夫が多くあります。ただし、VQにも限界があり、完全に表現力を維持することは困難です。特に高次元データや複雑なパターンを持つデータでは、一定の情報損失は避けられません。そのため、アプリケーションの要件に応じて、圧縮率と表現力のトレードオフを考慮する必要があります。

VQ-VAEの仕組み

やっと本題のVQ-VAEへ。ここまで来たらあとは普通のVAEと一緒です。
まずは原論文のアーキテクチャを眺めて見ましょう。

アーキテクチャ(原論文より引用)

スクリプトでみると

import tensorflow as tf
from tensorflow import keras
import numpy as np

class VectorQuantizer(keras.layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.beta = beta

    def build(self, input_shape):
        # コードブックの初期化
        self.embeddings = self.add_weight(
            shape=(self.embedding_dim, self.num_embeddings),
            initializer="random_normal",
            trainable=True,
            name="embeddings"
        )

    def call(self, x):
        # 入力を平坦化
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # 最も近い埋め込みベクトルを見つける
        distances = tf.reduce_sum(flattened ** 2, 1, keepdims=True) - 2 * tf.matmul(flattened, self.embeddings) + tf.reduce_sum(self.embeddings ** 2, 0, keepdims=True)
        encoding_indices = tf.argmin(distances, axis=1)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, tf.transpose(self.embeddings))

        # 量子化ベクトルを元の形状に戻す
        quantized = tf.reshape(quantized, input_shape)

        # Straight-through estimatorを使用
        quantized = x + tf.stop_gradient(quantized - x)

        # コードブック損失とコミットメント損失を計算
        # コードブック損失(埋め込みベクトルを入力に近づける)
        e_latent_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
        # コミットメント損失(エンコーダー出力を埋め込みベクトルに近づける)
        q_latent_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        loss = q_latent_loss + self.beta * e_latent_loss

        # 損失を更新
        self.add_loss(loss)

        return quantized

# エンコーダーの定義
def get_encoder():
    return keras.Sequential([
        keras.layers.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(32, 3, strides=2, padding="same", activation="relu"),
        keras.layers.Conv2D(64, 3, strides=2, padding="same", activation="relu"),
        keras.layers.Conv2D(16, 1, padding="same")
    ])

# デコーダーの定義
def get_decoder():
    return keras.Sequential([
        keras.layers.Input(shape=(7, 7, 16)),
        keras.layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu"),
        keras.layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu"),
        keras.layers.Conv2DTranspose(1, 3, padding="same")
    ])

# VQ-VAEモデルの構築
def get_vqvae(latent_dim=16, num_embeddings=64):
    vq_layer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer")
    encoder = get_encoder()
    decoder = get_decoder()

    inputs = keras.Input(shape=(28, 28, 1))
    encoder_outputs = encoder(inputs)
    quantized_latents = vq_layer(encoder_outputs)
    reconstructions = decoder(quantized_latents)

    return keras.Model(inputs, reconstructions, name="vq_vae")

# モデルのコンパイルと学習
vqvae = get_vqvae()
vqvae.compile(optimizer=keras.optimizers.Adam(), loss="mse")

# MNISTデータセットの読み込みと前処理
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.astype("float32") / 255.0
x_test = x_test.reshape(-1, 28, 28, 1)

# モデルの学習
vqvae.fit(x_train, x_train, epochs=10, batch_size=128, validation_data=(x_test, x_test))

VQ-VAEは、データの離散的な性質を活かしつつ、効率的な圧縮と高品質な生成を実現する革新的なモデルとして、機械学習の分野で重要な役割を果たしています。


いいなと思ったら応援しよう!