StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第12章「練習問題」
第12章「時間や空間を扱うモデル」
書籍の著者 松浦健太郎 先生
この記事は、テキスト第12章「時間や空間を扱うモデル」の 「練習問題」の PyMC5写経 を取り扱います。
なお、次の練習問題は写経を省略いたしました。
練習問題(3):2次元の空間構造のモデリングが難しいため
練習問題(4):2次元の空間構造のモデリングが難しいため
はじめに
StanとRでベイズ統計モデリングの紹介
この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。
テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
「アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!
テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。
引用表記
この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版
記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!
PyMC環境の準備
Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。
12章 練習問題
インポート
### インポート
# 数値・確率計算
import pandas as pd
import numpy as np
# PyMC
import pymc as pm
import arviz as az
# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'
# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')
練習問題(1)
モデル式12-2のシミュレーションのコードを書きます。
Yのシミュレーション値を算出してグラフを描画する関数を定義します。
### Yを算出して折れ線グラフを描画する関数の定義
def calc_plot_Y(N, T, sigmaMu, sigmaY, seed=1234):
## 設定
# 乱数生成器の設定
rng = np.random.default_rng(seed=seed)
# Yの乱数サンプルを格納する一時リストの準備
Ys = []
## T×NのY乱数サンプルを作成
for n in range(N):
## muの算出
# muのリストの初期化
mu = [0] * T
# mu0の正規分布乱数の作成
mu[0] = rng.normal(loc=10, scale=sigmaMu, size=1)
# mu1~の正規分布乱数の作成
for t in range(1, T):
mu[t] = rng.normal(loc=mu[t-1], scale=sigmaMu, size=1)
## Yの算出
# Yの正規分布乱数の作成
Y = rng.normal(loc=mu, scale=sigmaY)
# 一時リストに追加
Ys.append(Y)
## Ysをpandasのデータフレーム化
Y_df = pd.DataFrame(np.array(Ys).squeeze().T, index=range(1, T+1),
columns=range(1, N+1))
## Yの描画
# 描画領域の設定
plt.figure(figsize=(7, 3))
# Yの描画
sns.lineplot(Y_df, lw=0.8)
# 修飾
plt.title(rf'$\sigma_{{\mu}}$={sigmaMu}, $\sigma_Y$={sigmaY}')
plt.xlabel('Time')
plt.ylabel('Y')
plt.grid(lw=0.5)
plt.legend(bbox_to_anchor=(1, 1), title='Trial')
plt.show()
《解答》
$${\sigma_{\mu}>\sigma_Y}$$のケースをシミュレーションします。
### sigmaMuが大、sigmaYが小のケース
# 設定
sigmaMu = 2
sigmaY = 0.1
N = 5
T = 50
# Yの算出とグラフ描画
calc_plot_Y(N, T, sigmaMu, sigmaY)
【実行結果】
5つのシミュレーションの線はランダムに広がっています。
《解答》
$${\sigma_{\mu}<\sigma_Y}$$のケースをシミュレーションします。
### sigmaMuが小、sigmaYが大のケース
# 設定
sigmaMu = 0.1
sigmaY = 2
N = 5
T = 50
# Yの算出とグラフ描画
calc_plot_Y(N, T, sigmaMu, sigmaY)
【実行結果】
5つのシミュレーションの線は類似する値をとって同じような軌跡を描いています。
練習問題(2)
model12-6.stanに利用したデータを読み込みます。
### データの読み込み ◆data-ss2.txt
# X:日付(四半期), Y:季節ものの販売数[千個]
data = pd.read_csv('./data/data-ss2.txt')
print('data.shape: ', data.shape)
display(data.head())
【実行結果】
モデルの定義です。
8期先までの予測期間の Y の値に欠損値を設定して、このモデルで欠損値を含めた予測を行います。
### モデルの定義 ◆モデル式12-6 model12-6.stan
# 季節調整項のARパラメータの設定
period = 4 # 季節成分の周期
order = period - 1 # ARの次数
rhos = np.ones(order) * -1 # ARのρパラメータ
# 予測期間の設定
pred_periods = 8
# 予測に用いるYの作成(観測値+予測期間の欠損値)
y4pred = np.concatenate([data['Y'].values, np.repeat(np.nan, pred_periods)])
# モデルの定義
with pm.Model() as model:
### データ関連定義
# coordの定義
model.add_coord('data', values=range(len(y4pred)), mutable=True)
### 事前分布
# 標準偏差
sigmaMu = pm.Uniform('sigmaMu', lower=0, upper=10) # トレンド項
sigmaSeason = pm.Uniform('sigmaSeason', lower=0, upper=10) # 季節調整項
sigmaY = pm.Uniform('sigmaY', lower=0, upper=10) # 観測ノイズ
# トレンド項
init_dist_trend = pm.Normal.dist(mu=data['Y'].mean(), sigma=10)
mu = pm.AR('mu', rho=1, sigma=sigmaMu, constant=False,
init_dist=init_dist_trend, ar_order=1, dims='data')
# 季節調整項
init_dist_season = pm.Normal.dist(mu=0, sigma=5)
season = pm.AR('season', rho=rhos, sigma=sigmaSeason, constant=False,
init_dist=init_dist_season, ar_order=order, dims='data')
# yMean
yMean = pm.Deterministic('yMean', mu + season, dims='data')
### 尤度
obs = pm.Normal('obs', mu=yMean, sigma=sigmaY, observed=y4pred, dims='data')
モデルの定義内容を見ます。
### モデルの表示
model.basic_RVs
【実行結果】
### モデルの可視化
pm.model_to_graphviz(model)
【実行結果】
MCMCを実行します。
### 事後分布からのサンプリング 2分50秒
with model:
idata = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.9,
random_seed=1234)
【実行結果】
Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。
### r_hat>1.1の確認
# 設定
idata_in = idata # idata名
threshold = 1.05 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
収束条件を満たしています。
事後統計量を表示します。
### 推論データの要約統計情報の表示
pm.summary(idata, hdi_prob=0.95, round_to=3)
【実行結果】
トレースプロットを描画します。
### トレースプロットの表示
pm.plot_trace(idata, compact=True)
plt.tight_layout();
【実行結果】
欠損値を含めたことが原因なのか、発散しています。
ただ、12.2節「季節調整項」で実践した欠損値を含めないモデルでは発散しなかったので、MCMCサンプルデータは適切だと仮定して、分析を続行します。
事後統計量を算出します。
事後統計量算出関数の定義から。
### median, 10%, 90%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
probs = [50, 10, 90]
columns = ['median', '10%CI', '90%CI']
quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
quantiles.columns = columns
return quantiles
事後統計量を算出します。
### 要約統計量の算出・表示
vars = ['sigmaMu', 'sigmaSeason', 'sigmaY']
param_samples = idata.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(3))
【実行結果】
《解答》
Yの観測値と推論値を描画します。
### Yの事後分布・予測分布の描画
## 描画用データの作成
# 推論データからobsのMCMCサンプルデータを取り出し
y_samples = idata.posterior.obs.stack(sample=('chain', 'draw')).data
# yMeanの中央値の算出
y_median = np.median(y_samples, axis=1)
# yMeanの80%CI, 50%CIの算出
y_80ci = np.quantile(y_samples, q=[0.10, 0.90], axis=1)
y_50ci = np.quantile(y_samples, q=[0.25, 0.75], axis=1)
# 予測期間を含めたx軸の値の算出
pred_times = range(1, len(y4pred) + 1)
## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4))
# Yの観測値の描画
plt.plot(data.X, data.Y, 'o', color='royalblue', label='$Y$ の観測値')
# yMeanの中央値の描画
plt.plot(pred_times, y_median, color='tab:red', label='$Y$の予測値:中央値')
# yMeanの50%CIの描画
plt.fill_between(pred_times, y_50ci[0], y_50ci[1], color='tomato',
alpha=0.4, label='$Y$の予測値:50%CI')
# yMeanの80%CIの描画
plt.fill_between(pred_times, y_80ci[0], y_80ci[1], color='tomato',
alpha=0.2, label='$Y$の予測値:80%CI')
# 観測値の最後の期の垂直線の描画
plt.axvline(len(data), color='black', lw=0.8, ls='--')
# 修飾
plt.xlabel('Time(Quarter)')
plt.ylabel('Y')
plt.title(r'$Y$ の観測値と推論値')
plt.legend(bbox_to_anchor=(1, 1))
plt.grid(lw=0.5);
【実行結果】
直前四半期の観測値と同じような傾向になったように見えます。
Yの観測値と推論値を描画します。
推論値には Yの従う正規分布の平均パラメータ yMean を使います。
### yMeanの事後分布・予測分布の描画
## 描画用データの作成
# 推論データからyMeanのMCMCサンプルデータを取り出し
yMean_samples = idata.posterior.yMean.stack(sample=('chain', 'draw')).data
# yMeanの中央値の算出
yMean_median = np.median(yMean_samples, axis=1)
# yMeanの80%CI, 50%CIの算出
yMean_80ci = np.quantile(yMean_samples, q=[0.10, 0.90], axis=1)
yMean_50ci = np.quantile(yMean_samples, q=[0.25, 0.75], axis=1)
# 予測期間を含めたx軸の値の算出
pred_times = range(1, len(y4pred) + 1)
## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4))
# Yの観測値の描画
plt.plot(data.X, data.Y, 'o', color='royalblue', label='$Y$ の観測値')
# yMeanの中央値の描画
plt.plot(pred_times, yMean_median, color='tab:red', label='$Y_{mean}$:中央値')
# yMeanの50%CIの描画
plt.fill_between(pred_times, yMean_50ci[0], yMean_50ci[1], color='tomato',
alpha=0.4, label='$Y_{mean}$:50%CI')
# yMeanの80%CIの描画
plt.fill_between(pred_times, yMean_80ci[0], yMean_80ci[1], color='tomato',
alpha=0.2, label='$Y_{mean}$:80%CI')
# 観測値の最後の期の垂直線の描画
plt.axvline(len(data), color='black', lw=0.8, ls='--')
# 修飾
plt.xlabel('Time(Quarter)')
plt.ylabel('Y')
plt.title(r'$Y$ の観測値とYの平均値$Y_{mean}$の推論値')
plt.legend(bbox_to_anchor=(1, 1))
plt.grid(lw=0.5);
【実行結果】
直前四半期の観測値と同じような傾向になったように見えます。
トレンド項$${\mu}$$を描画します。
### μの描画
## 描画用データの作成
# 推論データからμのMCMCサンプルデータを取り出し
mu_samples = idata.posterior.mu.stack(sample=('chain', 'draw')).data
# μの中央値の算出
mu_median = np.median(mu_samples, axis=1)
# μの80%CI, 50%CIの算出
mu_80ci = np.quantile(mu_samples, q=[0.10, 0.90], axis=1)
mu_50ci = np.quantile(mu_samples, q=[0.25, 0.75], axis=1)
## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4))
# Yの観測値の描画
plt.plot(data.X, data.Y, '-o', color='royalblue', ms=5, lw=0.9,
label='$Y$ の観測値')
# μの中央値の描画
plt.plot(pred_times, mu_median, color='tab:red', label='$\mu$:中央値')
# μの50%CIの描画
plt.fill_between(pred_times, mu_50ci[0], mu_50ci[1], color='tomato', alpha=0.5,
label='$\mu$:50%CI')
# μの80%CIの描画
plt.fill_between(pred_times, mu_80ci[0], mu_80ci[1], color='tomato', alpha=0.2,
label='$\mu$:80%CI')
# 観測値の最後の期の垂直線の描画
plt.axvline(len(data), color='black', lw=0.8, ls='--')
# 修飾
plt.xlabel('Time(Quarter)')
plt.ylabel('Y')
plt.title(r'$Y$ の観測値とトレンド項 $\mu$ の推論値')
plt.legend(bbox_to_anchor=(1, 1))
plt.grid(lw=0.5);
【実行結果】
トレンドは横ばいの予測です。
季節調整項$${season}$$を描画します。
### 季節調整項の描画
## 描画用データの作成
# 推論データからseasonのMCMCサンプルデータを取り出し
season_samples = idata.posterior.season.stack(sample=('chain', 'draw')).data
# seasonの中央値の算出
season_median = np.median(season_samples, axis=1)
# seasonの80%CI, 50%CIの算出
season_80ci = np.quantile(season_samples, q=[0.10, 0.90], axis=1)
season_50ci = np.quantile(season_samples, q=[0.25, 0.75], axis=1)
## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4))
# μの中央値の描画
plt.plot(pred_times, season_median, color='tab:red', label='中央値')
# seasonの50%CIの描画
plt.fill_between(pred_times, season_50ci[0], season_50ci[1], color='tomato',
alpha=0.5, label='50%CI')
# seasonの80%CIの描画
plt.fill_between(pred_times, season_80ci[0], season_80ci[1], color='tomato',
alpha=0.2, label='80%CI')
# 観測値の最後の期の垂直線の描画
plt.axvline(len(data), color='black', lw=0.8, ls='--')
# 修飾
plt.xlabel('Time(Quarter)')
plt.ylabel('Y')
plt.title(r'季節調整項の推論値')
plt.legend(bbox_to_anchor=(1, 1))
plt.grid(lw=0.5);
【実行結果】
予測期間の周期の振れ幅が大きくなったような印象です。
第12章 練習問題は以上です。
全写経を完了しました。
ご清聴ありがとうございました。
シリーズの記事
次の記事
前の記事
目次
ブログの紹介
note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!
1.のんびり統計
統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。
2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で
書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!
3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で
書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!
4.楽しい写経 ベイズ・Python等
ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀
5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で
書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。
6.データサイエンスっぽいことを綴る
統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。
7.Python機械学習プログラミング実践記
書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。
最後までお読みいただきまして、ありがとうございました。
この記事が気に入ったらサポートをしてみませんか?