見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第11章「11.1 離散パラメータを扱うテクニック」

第11章「離散値をとるパラメータを使う」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第11章「離散値をとるパラメータを使う」の 11.1節「離散パラメータを扱うテクニック」の PyMC5写経 を取り扱います。
Stanは周辺化消去という手法で離散パラメータに対応しますが、PyMCは離散パラメータを扱えるので、周辺化消去をする必要が無さそうです。

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


11.1 離散パラメータを扱うテクニック


インポート

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np

# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az

# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

11.1.1 log_sum_exp関数

PyMCにも同様の関数 logsumexp() があります。
周辺化消去に使えるかどうかは分からない(勉強不足…)のですが、動作確認してみましょう。

x = pm.LogNormal.dist(mu=0, sigma=1)
xLogSumExp = pm.logsumexp(pt.log(0.5) + x)
xLogSumExp.eval()

【実行結果】

11.1.2 周辺化消去 ベルヌーイ分布に従う離散パラメータ

PyMCモデルでは周辺化消去を行わず、離散パラメータを通常のモデルで実装します。

コイントス試行のルール

ある学校に通う高校生を対象にした喫煙経験のアンケートをイメージします。
コイントスをして、表の場合に正直に喫煙経験をYes/Noで回答し、裏の場合に常にYesと回答します。

データの読み込み・確認

サンプルコードのデータを読み込みます。

### データの読み込み ◆データファイル11.1 data-coin.txt
# Y: 回答(1:Yes) コインを投げて表:喫煙経験を答える、裏:常にYesを答える

data1 = pd.read_csv('./data/data-coin.txt')
print('data1.shape: ', data1.shape)
display(data1.head())

【実行結果】
値 0 は No、値 1 は Yes です。

データの外観を確認します。
YesとNoを集計します。

### データの集計
data1.value_counts().sort_index().to_frame()

【実行結果】
半数の50がコイン裏のケースと仮定すると、喫煙率(予想)は$${20\%}$$($${=10 / (40+10)}$$)でしょうか。

円グラフで可視化します。

### データの可視化
plt.pie(data1.value_counts().sort_index(), startangle=90, counterclock=False,
        explode=[0.1, 0], labels=['No', 'Yes'], autopct='%1.1f%%',
        colors=['lightblue', 'lightpink']);

【実行結果】

PyMCのモデル定義

PyMCでモデル式11-1を実装します。周辺化消去はしません。
モデルの定義です。

### モデルの定義 ◆モデル式11-1 model11-1.stanと異なるモデル(周辺化消去しない)

with pm.Model() as model1:
    
    ### データ関連定義
    ## coordの定義
    model1.add_coord('data', values=data1.index, mutable=True)
    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=data1['Y'].values, dims='data')

    ### 事前分布
    # コイン投げの結果 表:0, 裏:1
    coin = pm.Bernoulli('coin', p=0.5, dims='data')
    # 喫煙確率 q
    q = pm.Uniform('q', lower=0, upper=1)
    # Yesと回答する確率 θ coinが表:q, 裏:1
    theta = pt.stack([q, 1])

    ### 尤度関数
    obs = pm.Bernoulli('obs', p=theta[coin], observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model1

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model1)

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。
離散パラメータを含む場合、(私が知らないだけかもしれませんが)NUTSサンプラーはPyMC標準のサンプラーを使う必要があります。
また、変数 coin の初期値「オール0」を設定しました(設定しないとエラーになりました)。

### 事後分布からのサンプリング 50秒

# coinの初期値設定
initvals = {'coin': np.zeros(len(data1))}

# MCMCの実行 ※PyMC標準のNUTSサンプラーを使用
with model1:
    idata1  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.95,
                        initvals=initvals, random_seed=1234)

【実行結果】

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata1        # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
pm.summary(idata1, hdi_prob=0.95, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata1, compact=True)
plt.tight_layout();

【実行結果】

推定結果の解釈

事後分布の要約統計量を算出します。
算出関数を定義します。

### mean,sd,2.5%,25%,50%,75%,97.5%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
    probs = [2.5, 25, 50, 75, 97.5]
    columns = ['mean', 'sd'] + [str(s) + '%' for s in probs]
    quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
    tmp_df = pd.concat([y.mean(axis=0), y.std(axis=0), quantiles], axis=1)
    tmp_df.columns=columns
    return tmp_df

喫煙確率 q の要約統計量を算出します。

### 要約統計量の算出・表示
vars = ['q']
param_samples = idata1.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(3))

【実行結果】
テキストの$${q}$$の推定値の中央値と95%ベイズ信頼区間は$${0.20\ [0.03,\ 0.38]}$$とのことです。
このモデルでは$${0.20\ [0.04,\ 0.39]}$$となり、テキストとほぼ同じ結果になりました。

q の事後分布プロットを描画します。

### 事後分布プロットの描画
pm.plot_posterior(idata1, var_names=['q'], hdi_prob=0.95,
                  point_estimate='median', figsize=(4, 3), round_to=3);

【実行結果】

【別解】11.1.2 周辺化消去 ベルヌーイ分布に従う離散パラメータ

別解として、式(11.2)に示された2つのベルヌーイ分布の混合分布のモデルを試してみます。

PyMCのモデル定義

PyMCで式(11.2)を実装します。周辺化消去はしません。
モデルの定義です。

### モデルの定義(別解) ◆モデル式11-1 式(11.2)の混合分布を利用

with pm.Model() as model1_2:
    
    ### データ関連定義
    ## coordの定義
    model1_2.add_coord('data', values=data1.index, mutable=True)
    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=data1['Y'].values, dims='data')

    ### 事前分布
    # 喫煙確率 q
    q = pm.Uniform('q', lower=0, upper=1)
    # 混合分布の混合割合 [表が出る確率, 裏が出る確率]
    w = [1/2, 1/2]
    # 混合分布の構成要素
    components  = [
        pm.Bernoulli.dist(p=q),  # 表が出たとき:喫煙確率のベルヌーイ分布
        pm.Bernoulli.dist(p=1),  # 裏が出たとき:確率=1のベルヌーイ分布
    ]

    ### 尤度関数
    obs = pm.Mixture('obs', w=w, comp_dists=components, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model1_2

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model1_2)

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。
離散パラメータを含まなくなったので、NUTSサンプラーには numpyroを利用できます。
また、変数 coin の初期値は不要になりました(設定しなくてもエラーにならないです)。

### 事後分布からのサンプリング 10秒
with model1_2:
    idata1_2  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.9,
                          nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata1_2      # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
pm.summary(idata1_2, hdi_prob=0.95, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata1_2, compact=True)
plt.tight_layout();

【実行結果】

推定結果の解釈

喫煙確率 q の事後分布の要約統計量を算出します。

### 要約統計量の算出・表示
vars = ['q']
param_samples = idata1_2.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(3))

【実行結果】
テキストの$${q}$$の推定値の中央値と95%ベイズ信頼区間は$${0.20\ [0.03,\ 0.38]}$$とのことです。
このモデルでは$${0.20\ [0.04,\ 0.39]}$$となり、テキストとほぼ同じ結果になりました。

(参考:前のモデルによる q の事後分布の要約統計量)

q の事後分布プロットを描画します。

### 事後分布プロットの描画
pm.plot_posterior(idata1_2, hdi_prob=0.95, point_estimate='median',
                  figsize=(4, 3), round_to=3);

【実行結果】

(参考:前のモデルによる q の事後分布)

11.1.2 周辺化消去 ポアソン分布に従う離散パラメータ

PyMCモデルでは周辺化消去を行わず、離散パラメータを通常のモデルで実装します。

コイントス試行のルール

ポアソン分布で生成した乱数の数だけコインを一度に投げて、表が出た数をカウントします。
この一連の乱数生成→コイン投げ→表の枚数のカウントの試行を100回実施します。

データの読み込み・確認

サンプルコードのデータを読み込みます。

### データの読み込み ◆データファイル11.2 data-poisson-binomial.txt
# Y: 表が出た数

data2 = pd.read_csv('./data/data-poisson-binomial.txt')
print('data2.shape: ', data2.shape)
display(data2.head())

【実行結果】

データの外観を確認します。
ヒストグラムを描画します。

### ヒストグラムの描画

# ビンの設定
bins = np.linspace(-0.5, 9.5, 11)
# ヒストグラム(KDE曲線付き)の描画
sns.histplot(data=data2, x='Y', bins=bins, ec='white', kde=True)
# 修飾
plt.xticks(range(0, 10))
plt.yticks(range(0, 22, 2))
plt.grid(lw=0.5)

【実行結果】

PyMCのモデル定義

PyMCでモデル式11-1を実装します。周辺化消去はしません。
モデルの定義です。

### モデルの定義 ◆モデル式11-2 model11-2.stanと異なるモデル(周辺化消去しない)

with pm.Model() as model2:
    
    ### データ関連定義
    ## coordの定義
    model2.add_coord('data', values=data2.index, mutable=True)
    ## dataの定義
    # 目的変数 Y coinの表が出た枚数
    Y = pm.ConstantData('Y', value=data2['Y'].values, dims='data')

    ### 事前分布
    # 乱数を発生するポアソン分布の未知のパラメータλ
    lam = pm.Uniform('lam', lower=0, upper=100)
    # ポアソン分布から発生させた乱数 m
    m = pm.Poisson('m', mu=lam, dims='data')
    
    ### 尤度関数 m回のコイントス(表の出る確率0.5)で表が出た枚数
    obs = pm.Binomial('obs', n=m, p=0.5, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model2

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model2)

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。
離散パラメータを含む場合、(私が知らないだけかもしれませんが)NUTSサンプラーはPyMC標準のサンプラーを使う必要があります。

### 事後分布からのサンプリング 50秒 ※PyMC標準のNUTSサンプラーを使用
with model2:
    idata2  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                        random_seed=1234)

【実行結果】

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata2        # idata名
threshold = 1.02         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
pm.summary(idata2, hdi_prob=0.95, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata2, compact=True)
plt.tight_layout();

【実行結果】

推定結果の解釈

乱数の平均値 lam($${\lambda}$$)の事後分布の要約統計量を算出します。

### 要約統計量の算出・表示
vars = ['q']
param_samples = idata1.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(3))

【実行結果】
テキストの$${\lambda}$$の推定値の中央値と95%ベイズ信頼区間は$${9.57\ [8.77,\ 10.45]}$$とのことです。
このモデルでは$${9.61\ [8.77,\ 10.48]}$$となり、テキストとほぼ同じ結果になりました。

lam($${\lambda}$$)の事後分布プロットを描画します。

### 事後分布プロットの描画
pm.plot_posterior(idata2, var_names=['lam'], hdi_prob=0.95,
                  point_estimate='median', figsize=(4, 3), round_to=3);

【実行結果】

11.1.3 公式の活用

11.1.2 のポアソン分布の例題の変形モデルです。

PyMCのモデル定義

PyMCでモデル式11-4を実装します。
モデルの定義です。

### モデルの定義 ◆モデル式11-4 model11-2b.stan

with pm.Model() as model3:
    
    ### データ関連定義
    ## coordの定義
    model3.add_coord('data', values=data2.index, mutable=True)
    ## dataの定義
    # 目的変数 Y coinの表が出た枚数
    Y = pm.ConstantData('Y', value=data2['Y'].values, dims='data')

    ### 事前分布
    # 乱数を発生するポアソン分布の未知のパラメータλ
    lam = pm.Uniform('lam', lower=0, upper=100)
    
    ### 尤度関数 m回のコイントス(表の出る確率0.5)で表が出た枚数
    obs = pm.Poisson('obs', mu=lam * 0.5, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model3

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model3)

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。
離散パラメータを含まなくなったので、NUTSサンプラーには numpyroを利用できます。

### 事後分布からのサンプリング 5秒
with model3:
    idata3  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                        nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata3        # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
pm.summary(idata3, hdi_prob=0.95, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata3, compact=True)
plt.tight_layout();

【実行結果】

推定結果の解釈

乱数の平均値 lam($${\lambda}$$)の事後分布の要約統計量を算出します。

### 要約統計量の算出・表示
vars = ['lam']
param_samples = idata3.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(3))

【実行結果】
テキストのモデル式11-2の$${\lambda}$$の推定値の中央値と95%ベイズ信頼区間は$${9.57\ [8.77,\ 10.45]}$$とのことです。
このモデルでは$${9.58\ [8.73,\ 10.50]}$$となり、テキストとほぼ同じ結果になりました。

(参考:11.1.2のモデルによる$${\lambda}$$の事後分布の要約統計量)

lam($${\lambda}$$)の事後分布プロットを描画します。

### 事後分布プロットの描画
pm.plot_posterior(idata3, var_names=['lam'], hdi_prob=0.95,
                  point_estimate='median', figsize=(4, 3), round_to=3);

【実行結果】

(参考:11.1.2のモデルによる$${\lambda}$$の事後分布)

11.1 節は以上です。


シリーズの記事

次の記事

前の記事

目次


ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

いいなと思ったら応援しよう!

この記事が参加している募集