見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第12章「12.7 2次元の空間構造」

第12章「時間や空間を扱うモデル」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第12章「時間や空間を扱うモデル」の 12.7節「2次元の空間構造」の PyMC5写経 を取り扱います。

今回は完走できず、データの一部分に限定して取り組みました。
PyMCを用いた2次元の空間構造のモデリングのベストプラクティスが分かっておりません・・・ううぅ

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


12.7 2次元の空間構造


モデリングの準備

インポート

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np

# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az

# 描画
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

データの読み込みと確認
サンプルコードのデータを読み込みます。
2次元で等間隔に隣接し合うプレート実験データです。
こちらはプレートの穴ごとの結果数値です。

### データの読み込み ◆データファイル12.1 data-2Dmesh.txt
# Y[i,j]:プレートのi行j列の穴の数値

data1 = pd.read_csv('./data/data-2Dmesh.txt', header=None)
print('data1.shape: ', data1.shape)
display(data1.head())

【実行結果】

こちらは各穴に施した処理のIDです。

### データの読み込み ◆データファイル12.2 data-2Dmesh-design.txt
# TID[i,j]:プレートのi行j列の穴に施した処理(1~96)

data2 = pd.read_csv('./data/data-2Dmesh-design.txt', header=None)
print('data2.shape: ', data2.shape)
display(data2.head())

【実行結果】

ヒートマップで可視化します。
テキスト図12.8に相当します。

### プレートの穴の配置と観測値Yの描画 ◆図12.8
plt.figure(figsize=(12, 5))
sns.heatmap(data1, cmap='Greens', annot=True, fmt='.1f',
            xticklabels=range(1, data1.shape[1]+1),
            yticklabels=range(1, data1.shape[0]+1),
            cbar_kws={'label': 'Y'})
plt.xlabel('Plate Column', fontsize=14)
plt.ylabel('Plate Row', fontsize=14)
plt.tight_layout();

【実行結果】
隣接し合う穴は似通った値に見えます。

PyMCのモデル定義

PyMCでモデル式12-13を実装します。
実直に r を配列化します。

PyMCのモデル定義 

### モデルの定義

## パラメータの設定
n_row = data1.shape[0]        # plateの行数
n_col = data1.shape[1]        # plateの列数
row_idx = list(range(n_row))  # plateの行インデックス
col_idx = list(range(n_col))  # plateの列インデックス

## モデルの定義
with pm.Model() as model:
    
    ### データ関連定義
    ## coordの定義
    # 観測データ・処理割付データのインデックス
    model.add_coord('row', values=data1.index, mutable=True)
    model.add_coord('col', values=data1.columns, mutable=True)
    # 処理パターン
    model.add_coord('treat', values=range(1, data2.max().max()+1),
                     mutable=True)
    
    ## dataの定義
    # 観測値
    y = pm.MutableData('y', value=data1.values, dims=('row', 'col'))
    # 処理割付
    T = pm.MutableData('T', value=data2.values, dims=('row', 'col'))

    ### 事前分布
    ## 標準偏差たち
    sigmaBeta = pm.Uniform('sigmaBeta', lower=0, upper=100)
    sigmaGamma = pm.Uniform('sigmaGamma', lower=0, upper=100)
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=100)
 
    ## 処理の影響 β
    beta = pm.StudentT('beta', nu=6, mu=0, sigma=sigmaBeta, dims='treat')
    
    ## プレートの位置の影響 γ
    # 16x24のリストを準備
    r =  [[0 for j in range(n_col)] for i in range(n_row)]

    # 2次階差を取れない0行,1行 x 0列,1列の事前分布を定義
    for i in range(2):
        for j in range(2):
            r[i][j] = pm.Normal(name='r' + str(i) + '.' + str(j),
                                mu=0, sigma=sigmaGamma)
    # 2次階差を取れるセルの事前分布を定義
    for i in range(n_row):
        for j in range(n_col):
            if (i<2)&(j<2):     # 2次階差を取れないケース
                pass
            elif (i<2)&(j>=2):  # 列jのみ2次階差を取れるケース
                r[i][j] = pm.Normal(
                    name='r' + str(i) + '.' + str(j),
                    mu=2*r[i][j-1] - r[i][j-2],
                    sigma=sigmaGamma)
            elif (i>=2)&(j<2):  # 行iのみ2次階差を取れるケース
                r[i][j] = pm.Normal(
                    name='r' + str(i) + '.' + str(j),
                    mu=2*r[i-1][j] - r[i-2][j],
                    sigma=sigmaGamma)
            else:               # 行iと列jの両方が2次階差を取れるケース
                                # muを「/2」していいのか不明
                r[i][j] = pm.Normal(
                    name='r' + str(i) + '.' + str(j),
                    mu=(2*r[i][j-1] - r[i][j-2] + 2*r[i-1][j] - r[i-2][j]) / 2,
                    sigma=sigmaGamma)
    # rのリストからgammaを作成
    gamma = pt.stacklists(r)

    ### 尤度
    Y = pm.Normal('Y', mu=gamma + beta[T[row_idx]-1], sigma=sigmaY, observed=y,
                  dims=('row', 'col'))

モデルの定義内容を見ます。

### モデルの表示
model.basic_RVs

【実行結果】
途中で見切れています・・・

モデルの可視化は時間がかかるので省略します。

MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 5分あたりでエラーが発生、棄権
with model:
    idata = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                      nuts_sampler='numpyro', random_seed=1234)

【実行結果】
5分経過あたりで謎のエラーが発生します。
r の配列が大きくなるとこのエラーが発生するようです。

スモールデータで再実施

r の配列が大きくならないようにして、つまりデータを少なくしてみます。

### データのi=0~4, j=0~4を取り出し
data1_s = data1.iloc[:5, :5]
data2_s = data2.iloc[:5, :5]

取り出した箇所のヒートマップを描画します。

### プレートの穴の配置と観測値Yの描画
plt.figure(figsize=(12, 5))
sns.heatmap(data1_s, cmap='Greens', annot=True, fmt='.1f',
            xticklabels=range(1, data1_s.shape[1]+1),
            yticklabels=range(1, data1_s.shape[0]+1),
            vmin=data1.min().min(), vmax=data1.max().max(),
            cbar_kws={'label': 'Y'})
plt.xlabel('Plate Column', fontsize=14)
plt.ylabel('Plate Row', fontsize=14)
plt.tight_layout();

【実行結果】
5×5に大幅縮小です。

PyMCのモデル定義

モデルを定義します。

### モデルの定義

## パラメータの設定
n_row = data1_s.shape[0]      # plateの行数
n_col = data1_s.shape[1]      # plateの列数
row_idx = list(range(n_row))  # plateの行インデックス
col_idx = list(range(n_col))  # plateの列インデックス

## モデルの定義
with pm.Model() as model_s:
    
    ### データ関連定義
    ## coordの定義
    # 観測データ・処理割付データのインデックス
    model_s.add_coord('row', values=data1_s.index, mutable=True)
    model_s.add_coord('col', values=data1_s.columns, mutable=True)
    # 処理パターン
    model_s.add_coord('treat', values=range(1, data2_s.max().max()+1),
                      mutable=True)
    
    ## dataの定義
    # 観測値
    y = pm.MutableData('y', value=data1_s.values, dims=('row', 'col'))
    # 処理割付
    T = pm.MutableData('T', value=data2_s.values, dims=('row', 'col'))

    ### 事前分布
    ## 標準偏差たち
    sigmaBeta = pm.Uniform('sigmaBeta', lower=0, upper=100)
    sigmaGamma = pm.Uniform('sigmaGamma', lower=0, upper=100)
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=100)
 
    ## 処理の影響 β
    beta = pm.StudentT('beta', nu=6, mu=0, sigma=sigmaBeta, dims='treat')
    
    ## プレートの位置の影響 γ
    # 16x24のリストを準備
    r =  [[0 for j in range(n_col)] for i in range(n_row)]

    # 2次階差を取れない0行,1行 x 0列,1列の事前分布を定義
    for i in range(2):
        for j in range(2):
            r[i][j] = pm.Normal(name='r' + str(i) + '.' + str(j),
                                mu=0, sigma=sigmaGamma)
    # 2次階差を取れるセルの事前分布を定義
    for i in range(n_row):
        for j in range(n_col):
            if (i<2)&(j<2):     # 2次階差を取れないケース
                pass
            elif (i<2)&(j>=2):  # 列jのみ2次階差を取れるケース
                r[i][j] = pm.Normal(
                    name='r' + str(i) + '.' + str(j),
                    mu=2*r[i][j-1] - r[i][j-2],
                    sigma=sigmaGamma)
            elif (i>=2)&(j<2):  # 行iのみ2次階差を取れるケース
                r[i][j] = pm.Normal(
                    name='r' + str(i) + '.' + str(j),
                    mu=2*r[i-1][j] - r[i-2][j],
                    sigma=sigmaGamma)
            else:               # 行iと列jの両方が2次階差を取れるケース
                                # muを「/2」していいのか不明
                r[i][j] = pm.Normal(
                    name='r' + str(i) + '.' + str(j),
                    mu=(2*r[i][j-1] - r[i][j-2] + 2*r[i-1][j] - r[i-2][j]) / 2,
                    sigma=sigmaGamma)
    # rのリストからgammaを作成
    gamma = pt.stacklists(r)

    ### 尤度
    Y = pm.Normal('Y', mu=gamma + beta[T[row_idx]-1], sigma=sigmaY, observed=y,
                  dims=('row', 'col'))

モデルの定義内容を見ます。

### モデルの表示
model_s

【実行結果】
r の変数がだいぶ少なくなりました。

### モデルの可視化
pm.model_to_graphviz(model_s)

【実行結果】
怪物のようなモデルです。

MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 30秒
with model_s:
    idata_s = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                        nuts_sampler='numpyro', random_seed=1234)

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata_s        # idata名
threshold = 1.03          # しきい値

# しきい値を超えるR_hatの個数を表示
display((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['sigmaBeta', 'sigmaGamma', 'sigmaY', 'beta']
pm.summary(idata_s, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata_s, compact=True, var_names=var_names)
plt.tight_layout();

【実行結果】
3つの sigma はやや乱れ気味かもです。

推論結果の解釈

推論データを用いてテキスト図12.9左に相当するグラフを描画します。

### プレートの穴の配置とrの中央値の3次元プロット ◆図12.9左

## 描画用データの作成
# 推論データからrのMCMCサンプルデータを取り出し
r_samples = np.array(
    [(idata_s.posterior['r' + str(i) + '.' +str(j)]
      .stack(sample=('chain', 'draw'))).data
      for i in range(n_row) for j in range(n_col)]).reshape(n_row, n_col, 4000)

## 描画処理
# 描画領域の設定
fig = plt.figure()
# 3次元の描画領域の設定
ax = fig.add_subplot(111, projection='3d')
# サーフェスプロットの描画
g = ax.plot_surface(X=range(n_row), Y=range(n_col),
                    Z=np.median(r_samples, axis=2), cmap='Greens')
# カラーバーの描画
fig.colorbar(g, label='z')
# 修飾
ax.set(xlabel='Plate Row', ylabel='Plate Column');

【実行結果】
なんだか寂しいです。。。

【募集】

2次元の空間構造をPyMCでモデリングするコードを教えてください!!!

12.7 節は以上です。


シリーズの記事

次の記事

前の記事

目次


ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

この記事が気に入ったらサポートをしてみませんか?