【ステップ5備忘録】ゼロから作るDeep Learning ❺【生成モデル編】ステップ5 EMアルゴリズムを読む
ゼロから作るDeep Learning ❺【生成モデル編】ステップ5ではEMアルゴリズムをテーマに、KLダイバージェンスと最尤推定の関係を踏まえて解説しています。この記事は本書を読むための予備知識としてお読み頂ければと思います。なお、本記事では機械学習のライブラリ「scikit-learn」を用いて EMアルゴリズムを実装します👨🎓
EMアルゴリズムとは?
潜在変数がある場合の最尤推定を行うための反復アルゴリズムです。EステップとMステップを交互に繰り返すことで、パラメータを最適化します。混合ガウス分布など、様々なモデルで利用されています。
EMアルゴリズムのメリットとデメリット
メリット:
* 複雑なモデルのパラメータを推定できる。
* 局所解に陥る可能性はあるが、多くの場合、良い解が得られる。
デメリット:
* 収束が遅い場合がある。
* 初期値に依存して結果が変わる可能性がある。
EMアルゴリズムの応用
クラスタリング: 混合ガウス分布を用いたソフトクラスタリング
自然言語処理: 潜在変数を導入したトピックモデル
画像処理: 画像のセグメンテーション
GaussianMixtureを使ったEMアルゴリズムの実装の注意
機械学習ライブラリのscikit-learnのを使うと簡単にEMアルゴリズムを実装する事が出来ます😃ただし、GaussianMixtureでは、EMアルゴリズムのEステップとMステップが明示的に示されているわけではありません。
なぜEステップとMステップが明示的に示されていないのか
scikit-learnの抽象化
scikit-learnのGaussianMixtureクラスは、EMアルゴリズムを実装したクラスであり、内部的にEステップとMステップを繰り返す処理が実装されています。そのため、ユーザーはこれらのステップを直接記述する必要がありません。
ユーザーフレンドリーなインターフェース
ユーザーは、モデルを作成し、データにフィットさせるというシンプルな操作でEMアルゴリズムを利用できます。
EMアルゴリズムのEステップとMステップ
EMアルゴリズムは、大きく分けて次の2つのステップを繰り返すアルゴリズムです。
Eステップ: 現在のモデルのパラメータのもとで、各データ点が各クラスタに属する確率(責任度)を計算します。
Mステップ: Eステップで計算された責任度に基づいて、各クラスタのパラメータ(平均、共分散行列など)を更新します。
scikit-learnのGaussianMixtureクラスの構造
fitメソッドの内部で、これらのステップが繰り返し実行されます。
ユーザーは、means_属性やcovariances_属性で、最終的に得られた各クラスタのパラメータにアクセスできます。
GaussianMixtureを使ったEMアルゴリズムの実装
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
# データの生成
np.random.seed(0)
data = np.concatenate([np.random.normal(loc=-1, scale=0.5, size=300), np.random.normal(loc=1, scale=0.5, size=700)])
# ガウス混合モデルによる密度推定
gmm = GaussianMixture(n_components=2)
gmm.fit(data.reshape(-1, 1))
# 密度関数のプロット
x = np.linspace(-3, 3, 1000)
density = np.exp(gmm.score_samples(x.reshape(-1, 1)))
plt.figure(figsize=(8, 6))
plt.hist(data, bins=30, density=True, alpha=0.3)
plt.plot(x, density, color='red')
plt.show()
まとめ
EMアルゴリズムでは、KLダイバージェンスを最小化することで、モデルのパラメータを最適化していきます。本記事では、KLダイバージェンスの説明は割愛しています。なお「ゼロから作るDeep Learning ❺【生成モデル編】」では、KLダイバージェンスについての解説や、EMアルゴリズムのEステップとMステップの明示的な実装について解説がされています。ぜひ、詳しくは、「ゼロから作るDeep Learning ❺【生成モデル編】」をお読みになる事をお勧めします😃
この記事のコードは、下記のnoteの記事を参考にさせて頂きました。なお記事の内容、コーディングはGoogle Geminiにもアドバイスを頂いています😃