見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第10章「10.1.3 ラベルスイッチング」

第10章「収束しない場合の対処法」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第10章「収束しない場合の対処法」・10.1節「パラメータの識別可能性」の 10.1.3項「ラベルスイッチング」の PyMC5写経 を取り扱います。

テキストは第10章で 収束に向けた工夫を取り扱っています。
座学や数式モデルのみの掲載項があれば、Stanコードの掲載項もあります。
PyMC化は主にStanコードが明示されているモデル式を対象にして実施します。

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


10.1.3 ラベルスイッチング


Stanは変数 ordered 型を用いてラベルスイッチングを回避する方法が一般的のようです。
PyMCの類似概念には、確率変数の引数 transform で用いることができる「pymc.distributions.transforms.ordered」があります。
ただし中々の気難し屋さんなので、私はうまく使いこなすことができません。
今回のPyMC写経では、順序を保ちたい変数間の「差分変数」を使って、ラベルスイッチング回避を行います。

インポート

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np
import scipy.stats as stats

# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az

# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

データの作成・確認

次の2つの確率分布の混合分布から乱数データ mixed_norm を作成します。

$$
\begin{cases}
0.3 \times \text{Normal}\ (0,\ 2) \\
0.7 \times \text{Normal}\ (-4,\ 1) \\
\end{cases}
$$

### 混合正規分布の例

## データの作成
# 乱数生成機の設定
rng = np.random.default_rng(seed=1234)
# 2つの正規分布の乱数を300個と700個生成
norm1 = rng.normal(loc=0, scale=2, size=300)
norm2 = rng.normal(loc=-4, scale=1, size=700)
# 2つの正規分布乱数を結合して混合正規分布乱数にする
mixed_norm = np.concatenate([norm1, norm2])
# 生成した乱数 mixed_norm の表示
print('mixed_norm.shape: ', mixed_norm.shape)
print('mixed_norm[:20]:')
print(mixed_norm[:20])

【実行結果】

乱数データのヒストグラムを描画します。
テキスト図10.1に相当します。

## 描画処理 ◆図10.1

# x軸の値の設定
x_vals = np.linspace(-8, 8, 1001)

# 描画領域の設定
fig, ax = plt.subplots(figsize=(5, 4))
# 混合正規分布のKDE曲線付きヒストグラムの描画
sns.histplot(mixed_norm, kde=True, ec='white', stat='density', ax=ax,
             line_kws={'lw': 3})
ax.set(xlabel='$y$')
ax.grid(lw=0.5)
# 正規分布1の確率密度関数の描画
twinx1 = ax.twinx()
twinx1.plot(x_vals, stats.norm.pdf(x_vals, loc=0, scale=2),color='tab:red',
            lw=2, ls='--', label='$Normal(0,2)$')
twinx1.set(ylim=(0, 1), yticks=[])
# 正規分布2の確率密度関数の描画
twinx2 = ax.twinx()
twinx2.plot(x_vals, stats.norm.pdf(x_vals, loc=-4, scale=1), color='tab:orange',
            lw=2, ls='--', label='$Normal(-4,1)$')
twinx2.set(ylim=(0, 0.48), yticks=[])
# 全体修飾
fig.legend(bbox_to_anchor=(1.22, 0.9))
fig.suptitle('$Mixed\ Normal\ (y \mid 0.3,0,2,-4,1)$');

【実行結果】

テキストではStanモデルを作成しないで終了しています。
せっかくですので、PyMCでモデリングしてみましょう。

PyMCのモデル定義

PyMCでラベルスイッチングを回避するモデルを実装します。
モデルの定義です。
mu0 + muDelta = mu1 とし、muDelta を0 以上の値を取るように一様分布の下限を設定することで、mu0 ≦ mu1 を実現します。
また、尤度関数には混合正規分布 NormalMixture() を用いています。
パラメータ w が混合比率であり、推論してみましょう。

### モデルの定義

# 混合数
n_components = 2

# モデルの定義
with pm.Model() as model:
    
    ### データ関連定義
    ## coordの定義
    model.add_coord('data', values=range(len(mixed_norm)), mutable=True)

    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=mixed_norm, dims='data')
    
    ### 事前分布
    # 混合分布の混合割合
    w = pm.Dirichlet('w', a=np.ones(n_components), shape=n_components)
    # 2つの正規分布の平均μ ※ラベルスイッチング防止にmuDeltaを採用
    mu0 = pm.Uniform('mu0', lower=-10, upper=10)
    muDelta = pm.Uniform('muDelta', lower=0, upper=10)
    mu1 = mu0 + muDelta
    mu = pm.Deterministic('mu', pt.stack([mu0, mu1]))
    # 2つの正規分布の標準偏差σ
    sigma = pm.Uniform('sigma', lower=0, upper=10, shape=n_components)

    ### 尤度関数 混合正規分布
    obs = pm.NormalMixture('obs', w=w, mu=mu, sigma=sigma, observed=Y,
                           dims='data')

モデルの定義内容を見ます。

### モデルの表示
model

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model)

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 15秒
with model:
    idata  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                       nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata         # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['w', 'mu', 'sigma', 'mu0', 'muDelta']
pm.summary(idata, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata, compact=True, var_names=var_names)
plt.tight_layout();

【実行結果】
パラメータ w, mu, sigma には2つの分布があります。
それぞれ「同じ色」=配列の順番が同じの chain(サンプルデータ)で構成されています。
これはラベルスイッチングが回避されていることを示しています。

推定結果の解釈

事後分布の要約統計量を算出します。
算出関数を定義します。

### mean,sd,2.5%,25%,50%,75%,97.5%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
    probs = [2.5, 25, 50, 75, 97.5]
    columns = ['mean', 'sd'] + [str(s) + '%' for s in probs]
    quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
    tmp_df = pd.concat([y.mean(axis=0), y.std(axis=0), quantiles], axis=1)
    tmp_df.columns=columns
    return tmp_df

要約統計量を算出します。

### 要約統計量の算出・表示
w_samples_df = pd.DataFrame(
            idata.posterior.w.stack(sample=('chain', 'draw')).data.T,
            columns=[f'w[{i+1}]' for i in range(2)])
mu_samples_df = pd.DataFrame(
            idata.posterior.mu.stack(sample=('chain', 'draw')).data.T,
            columns=[f'mu[{i+1}]' for i in range(2)])
sigma_samples_df = pd.DataFrame(
            idata.posterior.sigma.stack(sample=('chain', 'draw')).data.T,
            columns=[f'sigma[{i+1}]' for i in range(2)])
param_df = pd.concat([w_samples_df, mu_samples_df, sigma_samples_df], axis=1)
display(make_stats_df(param_df).round(2))

【実行結果】

パラメータの事後分布の平均値を読み取ってみましょう。
配列1については、混合比率が約$${0.7}$$、正規分布の平均パラメータ mu が約$${-4}$$、標準偏差パラメータ sigma が約$${1.0}$$です。
また、配列2については、混合比率が約$${0.3}$$、正規分布の平均パラメータ mu が約$${0.1}$$、標準偏差パラメータ sigma が約$${2.0}$$です。
乱数生成に用いた次の混合分布とほぼ一致しています。

$$
\begin{cases}
0.3 \times \text{Normal}\ (0,\ 2) \\
0.7 \times \text{Normal}\ (-4,\ 1) \\
\end{cases}
$$

パラメータの事後分布を描画しましょう。

### パラメータの事後分布の描画
pm.plot_posterior(idata, hdi_prob=0.95, round_to=3,
                  var_names=['w', 'mu', 'sigma'], grid=(3, 2), figsize=(5, 7))
plt.tight_layout();

【実行結果】

10.1.3 項は以上です。


シリーズの記事

次の記事

前の記事

目次

ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

いいなと思ったら応援しよう!

この記事が参加している募集