![見出し画像](https://assets.st-note.com/production/uploads/images/132891645/rectangle_large_type_2_97a41a82abaeb1a0d9c80da5ad1a38cd.png?width=1200)
StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第10章「10.3.1 Nealの漏斗」
第10章「収束しない場合の対処法」
書籍の著者 松浦健太郎 先生
この記事は、テキスト第10章「収束しない場合の対処法」・10.3節「再パラメータ化」の 10.3.1項「Nealの漏斗」の PyMC5写経 を取り扱います。
はじめに
StanとRでベイズ統計モデリングの紹介
この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。
テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
「アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!
テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。
引用表記
この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版
記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!
PyMC環境の準備
Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。
10.3.1 Nealの漏斗
インポート
### インポート
# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az
# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'
# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')
![](https://assets.st-note.com/img/1709543012814-zjUaddGGcH.png)
データの読み込み・確認
データは未使用です。
再パラメータ化前のモデル
PyMCでモデル式10-8を実装します。
モデルの定義です。
### モデルの定義 ◆モデル式10-8 model10-8a.stan
with pm.Model() as model1:
### 事前分布
a = pm.Normal('a', mu=0, sigma=3)
r = pm.Normal('r', mu=0, sigma=pt.exp(a/2), shape=1000)
モデルの定義内容を見ます。
### モデルの表示
model1
【実行結果】
![](https://assets.st-note.com/img/1709542310521-QkDGuNboKR.png)
### モデルの可視化
pm.model_to_graphviz(model1)
【実行結果】
![](https://assets.st-note.com/img/1709542333861-HbWD3Mzq1y.png)
MCMCを実行します。
### 事後分布からのサンプリング 15秒
with model1:
idata1 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
nuts_sampler='numpyro', random_seed=1234)
【実行結果】省略
Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。
### r_hat>1.1の確認
# 設定
idata_in = idata1 # idata名
threshold = 1.1 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
再パラメータ化前のモデルは収束していません。
![](https://assets.st-note.com/img/1709542376711-beCNxQ9fW5.png)
事後統計量を表示します。
### 推論データの要約統計情報の表示
var_names = ['a', 'r']
pm.summary(idata1, hdi_prob=0.95, var_names=var_names, round_to=3)
【実行結果】
![](https://assets.st-note.com/img/1709542411561-4cXHrHtnBc.png?width=1200)
トレースプロットを描画します。
### トレースプロットの表示 1分10秒
pm.plot_trace(idata1, compact=True, var_names=var_names)
plt.tight_layout();
【実行結果】
![](https://assets.st-note.com/img/1709542420947-yiRjE5v5s9.png?width=1200)
r[0] と a の散布図を描画します。
テキスト図10.5左に相当します。
なお、漏斗のような背景の描画は省略しました。
### r[0]とaの散布図の描画 ◆図10.5左
# 推論データよりMCMCサンプルデータの取り出し
a_samples1 = idata1.posterior.a.stack(sample=('chain', 'draw')).data
r_samples1 = idata1.posterior.r.stack(sample=('chain', 'draw')).data
# 描画領域の設定
fig, ax = plt.subplots(figsize=(5, 5))
# 散布図の描画
sns.scatterplot(x=r_samples1[0], y=a_samples1, alpha=0.5, ax=ax)
# 修飾
ax.set(xlabel='$r [0]$', ylabel='$a$')
ax.grid(lw=0.5);
【実行結果】
![](https://assets.st-note.com/img/1709542523195-fgc2ECC87k.png)
![](https://assets.st-note.com/img/1709543024203-28QqCsYbbs.png)
再パラメータ化後のモデル
PyMCでモデル式10-8の再パラメータ化版を実装します。
モデルの定義です。
aRaw と rRaw はスケールを固定した分布に従う設定にします。
aRaw と rRaw にスケールを掛けることで、元のパラメータa、r を表現しています。
### モデルの定義 ◆モデル式10-8 model10-8b.stan
with pm.Model() as model2:
### 事前分布
aRaw = pm.Normal('aRaw', mu=0, sigma=1)
rRaw = pm.Normal('rRaw', mu=0, sigma=1, shape=1000)
a = pm.Deterministic('a', 3.0 * aRaw)
r = pm.Deterministic('r', pt.exp(a/2) * rRaw)
モデルの定義内容を見ます。
### モデルの表示
model2
【実行結果】
![](https://assets.st-note.com/img/1709542724142-ShmQ3wH5us.png)
### モデルの可視化
pm.model_to_graphviz(model2)
【実行結果】
![](https://assets.st-note.com/img/1709542750891-5zOz5rWkQw.png)
MCMCを実行します。
### 事後分布からのサンプリング 10秒
with model2:
idata2 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
nuts_sampler='numpyro', random_seed=1234)
【実行結果】省略
Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。
### r_hat>1.1の確認
# 設定
idata_in = idata2 # idata名
threshold = 1.01 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
収束条件を満たしました!
![](https://assets.st-note.com/img/1709542776536-5PYvLEXlnU.png)
事後統計量を表示します。
### 推論データの要約統計情報の表示
var_names = ['aRaw', 'a', 'rRaw', 'r']
pm.summary(idata2, hdi_prob=0.95, var_names=var_names, round_to=3)
【実行結果】
![](https://assets.st-note.com/img/1709542806134-gZOTSfm16q.png?width=1200)
トレースプロットを描画します。
### トレースプロットの表示 2分10秒
pm.plot_trace(idata2, compact=True, var_names=var_names)
plt.tight_layout();
【実行結果】
![](https://assets.st-note.com/img/1709542822456-yYvpbGJVa9.png?width=1200)
r[0] と a の散布図を描画します。
テキスト図10.5右に相当します。
なお、漏斗のような背景の描画は省略しました。
### r[0]とaの散布図の描画 ◆図10.5右
# 推論データよりMCMCサンプルデータの取り出し
a_samples2 = idata2.posterior.a.stack(sample=('chain', 'draw')).data
r_samples2 = idata2.posterior.r.stack(sample=('chain', 'draw')).data
# 描画領域の設定
fig, ax = plt.subplots(figsize=(5, 5))
# 散布図の描画
sns.scatterplot(x=r_samples2[0], y=a_samples2, alpha=0.5, ax=ax)
# 修飾
ax.set(xlabel='$r [0]$', ylabel='$a$')
ax.grid(lw=0.5);
【実行結果】
漏斗の形状に見えます。
![](https://assets.st-note.com/img/1709542844531-EHR3v4i8NN.png)
再パラメータ化前・後の散布図を横並びにして見てみます。
### r[0]とaの散布図の描画 ◆図10.5
## 描画処理
# 描画領域の設定
fig, ax = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
# 散布図の描画
sns.scatterplot(x=r_samples1[0], y=a_samples1, alpha=0.5, ax=ax[0],
color='tab:blue')
# 修飾
ax[0].set(xlabel='$r\ [0]$', ylabel='$a$',
title='漏斗の先端付近からサンプリングできていない')
ax[0].grid(lw=0.5)
# 散布図の描画
sns.scatterplot(x=r_samples2[0], y=a_samples2, alpha=0.5, ax=ax[1],
color='tab:blue')
sns.scatterplot(x=r_samples1[0], y=a_samples1, alpha=0.1, ax=ax[1],
color='tab:orange')
# 修飾
ax[1].set(xlabel='$r\ [0]$', ylabel='$a$', title='再パラメータ化')
ax[1].text(s='←オレンジ\n 左図のデータ点を重ねた', x=10, y=-1, color='tab:red')
ax[1].grid(lw=0.5)
plt.tight_layout();
【実行結果】
再パラメータ化前は分布の一部からしか乱数生成できていないことが分かります。
![](https://assets.st-note.com/img/1709542895193-23E9eYlo4S.png?width=1200)
10.3.1 項は以上です。
![](https://assets.st-note.com/img/1709542994410-8GEnhUs463.png)
シリーズの記事
次の記事
前の記事
目次
ブログの紹介
note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!
1.のんびり統計
統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。
2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で
書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!
3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で
書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!
4.楽しい写経 ベイズ・Python等
ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀
5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で
書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。
6.データサイエンスっぽいことを綴る
統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。
7.Python機械学習プログラミング実践記
書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。
最後までお読みいただきまして、ありがとうございました。