見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第9章「9.4.2 simplex型」

第9章「一歩進んだ文法」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第9章「一歩進んだ文法」・9.4節「パラメータの制約」の 9.4.2項「simplex型」の PyMC5写経 を取り扱います。

テキストは第9章で Stan の文法上の工夫を取り扱っています。
Stanの文法とPyMCの文法は異なっており、Stan の工夫点をPyMCに取り入れることができない、または、取り入れる方法が分からない場合があります。
したがいまして第9章では、私個人のスキルと相談して、PyMC化に意味を見いだせるテーマを選択して取り上げています。
選択の結果、9.4.1「lowerとupperによる範囲制限」、9.4.3「ordered型」、9.4.4「その他の制約」の写経は省略しました。

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


9.4.2 simplex型


Stanのsimplex型はベクトル型の特別なケースであり、各要素の値が0~1の範囲となり、かつ、各要素の合計が1となる型です。
PyMCはデータ型の概念が薄い感じがします。
simplexに関しては、確率変数の引数 transform で用いることができる「pymc.distributions.transforms.simplex」がありますが、型を指定するものではありません。
今回のPyMC写経では、引数$${a}$$が全て1のディリクレ分布で代用いたします。

インポート

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np

# PyMC
import pymc as pm
import arviz as az

# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

データの読み込み・確認

サンプルコードのデータを読み込みます。

### データの読み込み ◆データファイル9.3 data-dice.txt
# Face: サイコロの出た目

data3 = pd.read_csv('./data/data-dice.txt')
print('data3.shape: ', data3.shape)
display(data3.head())

【実行結果】

サイコロの出目の割合を見てみましょう。

### 出目の割合の算出
data3['Face'].value_counts().sort_index() / data3['Face'].sum()

【実行結果】
2と4の目が出やすいようです。

出目をヒストグラムで可視化しましょう。

### ヒストグラムの描画
# ビンズの設定
bins = np.arange(0.5, 6.6, 1)
# 描画領域の設定
plt.figure(figsize=(5, 5))
# ヒストグラムの描画
sns.histplot(data=data3, x='Face', bins=bins, ec='white', kde=True)
# 修飾
plt.xlabel('サイコロの目')
plt.grid(lw=0.5);

【実行結果】

モデル式9-4 PyMCのモデル定義

PyMCでモデル式9-4を実装します。
モデルの定義です。
theta の事前分布にディリクレ分布を用いています。引数$${a}$$は$${[1, 1, 1, 1, 1, 1]}$$です。

### モデルの定義 ◆モデル式9-4 model9-4.stan

with pm.Model() as model3:
    
    ### データ関連定義
    ## coordの定義
    model3.add_coord('data', values=data3.index, mutable=True)
    model3.add_coord('dice', values=range(1, 7), mutable=True)
    ## dataの定義
    # 目的変数「Y - 1」 0始まり化
    Y = pm.ConstantData('Y', value=data3['Face'].values - 1, dims='data')

    ### 事前分布
    theta = pm.Dirichlet('theta', a=np.ones(6), dims='dice')

    ### 尤度関数
    obs = pm.Categorical('obs', p=theta, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model3

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model3)

【実行結果】
とてもシンプルなモデルです。

モデル式9-4 MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 10秒
with model3:
    idata3  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                        nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata3        # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
pm.summary(idata3, hdi_prob=0.95, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata3, compact=False)
plt.tight_layout();

【実行結果】

モデル式9-4 推定結果の解釈

事後分布の要約統計量を算出します。
算出関数を定義します。

### mean,sd,2.5%,25%,50%,75%,97.5%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
    probs = [2.5, 25, 50, 75, 97.5]
    columns = ['mean', 'sd'] + [str(s) + '%' for s in probs]
    quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
    tmp_df = pd.concat([y.mean(axis=0), y.std(axis=0), quantiles], axis=1)
    tmp_df.columns=columns
    return tmp_df

要約統計量を算出します。

### 要約統計量の算出・表示
theta_samples_df = pd.DataFrame(
            idata3.posterior.theta.stack(sample=('chain', 'draw')).data.T,
            columns=[f'theta[{i+1}]' for i in range(6)])
display(make_stats_df(theta_samples_df).round(3))

【実行結果】
テキストに事後分布の推定値が掲載されていないので、PyMCモデルによる推論の適否は不明です。

モデル式9-5 PyMCのモデル定義

PyMCでモデル式9-5を実装します。
データの前処理を行います。
出目ごとの出現回数を算出します。

### モデル式9-5用データの作成
data4 = data3.value_counts().sort_index().to_frame().reset_index()
display(data4)

【実行結果】

モデルの定義です。
theta の事前分布にディリクレ分布を用いています。引数$${a}$$は$$[1, 1, 1, 1, 1, 1]$$です。
尤度関数が多項分布に変わります。

### モデルの定義 ◆モデル式9-5 model9-5.stan

with pm.Model() as model4:
    
    ### データ関連定義
    ## coordの定義
    model4.add_coord('data', values=data4.index, mutable=True)
    ## dataの定義
    # 目的変数 Y: サイコロの目ごとの回数
    Y = pm.ConstantData('Y', value=data4['count'].values, dims='data')

    ### 事前分布
    theta = pm.Dirichlet('theta', a=np.ones(6), dims='data')

    ### 尤度関数
    obs = pm.Multinomial('obs', n=data4['count'].sum(), p=theta, observed=Y,
                         dims='data')

モデルの定義内容を見ます。

### モデルの表示
model4

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model4)

【実行結果】
こちらもとても、シンプルなモデルです。

モデル式9-5 MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 10秒
with model4:
    idata4  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                        nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata4        # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
pm.summary(idata4, hdi_prob=0.95, round_to=3)

【実行結果】

(参考:モデル式9-4)

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata4, compact=False)
plt.tight_layout();

【実行結果】

モデル式9-5 推定結果の解釈

事後分布の要約統計量を算出します。

### 要約統計量の算出・表示
theta_samples_df = pd.DataFrame(
            idata4.posterior.theta.stack(sample=('chain', 'draw')).data.T,
            columns=[f'theta[{i+1}]' for i in range(6)])
display(make_stats_df(theta_samples_df).round(3))

【実行結果】
テキストに事後分布の推定値が掲載されていないので、PyMCモデルによる推論の適否は不明です。

(参考:モデル式9-4)

9.4.2 項は以上です。


シリーズの記事

次の記事

前の記事

目次


ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

いいなと思ったら応援しよう!

この記事が参加している募集