見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第12章「12.1 状態空間モデルことはじめ」

第12章「時間や空間を扱うモデル」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第12章「時間や空間を扱うモデル」の 12.1節「状態空間モデルことはじめ」の PyMC5写経 を取り扱います。
PyMCの GaussianRandomWalk() と AR() で1階差分、2階差分を表現します。

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


12.1 状態空間モデルことはじめ


モデリングの準備

インポート

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np

# PyMC
import pymc as pm
import arviz as az

# 自己相関
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf

# 描画
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

データの読み込みと確認
サンプルコードのデータを読み込みます。

### データの読み込み ◆data-ss1.txt
# X:日付, Y:イベントの来場人数[千人]

data1 = pd.read_csv('./data/data-ss1.txt')
print('data1.shape: ', data1.shape)
display(data1.head())

【実行結果】

テキスト図12.1右の時系列折れ線グラフを描画します。

# 時系列折れ線グラフの描画 ◆図12.1右
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(data1.X, data1.Y, '-o');
ax.set(xticks=range(0, 21, 5), xlabel='Time(Day)', ylabel='Y')
ax.grid(lw=0.5);

【実行結果】

自己相関、偏自己相関を描画します。

### 自己相関、偏自己相関のプロット
fig, ax = plt.subplots(1, 2, figsize=(10, 3))
plot_acf(data1.Y, ax=ax[0])
plot_pacf(data1.Y, ax=ax[1])
plt.tight_layout();

【実行結果】
Yにはラグ1、ラグ4に自己相関が見られます。

12.1.5 Stanで実装

1階差分のトレンド項のモデルです。
テキストの数式をお借りします。

$$
\begin{align*}
\mu[t] - \mu[t-1] &= \varepsilon_{\mu}[t],\ \varepsilon_{\mu}[t] \sim \text{Normal}\ (0,\ \sigma_{\mu}) \\
\mu[t] &= \mu[t-1] + \varepsilon_{\mu}[t],\ \varepsilon_{\mu}[t] \sim \text{Normal}\ (0,\ \sigma_{\mu}) \\
 \\
\mu[t] &\sim \text{Normal}\ (\mu[t-1],\ \sigma_{\mu}) \\
y[t] &\sim \text{Normal}\ (\mu[t],\ \sigma_Y)
\end{align*}
$$

テキストより引用

PyMCのモデル定義

PyMCでモデル式12-2を実装します。
$${\mu}$$の従う確率分布には GaussianRandomWalk() を用いました。

モデルの定義です。 

### モデルの定義 ◆モデル式12-2 model12-2.stan
# ※PyMCのGaussianRandomWalkでμの確率分布を表現

with pm.Model() as model1:
    
    ### データ関連定義
    ## coordの定義
    model1.add_coord('data', values=data1.index, mutable=True)
    
    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=data1['Y'].values, dims='data')

    ### 事前分布
    # 標準偏差
    sigmaMu = pm.Uniform('sigmaMu', lower=0, upper=5)
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=5)
    # mu:トレンド項
    init_dist = pm.Normal.dist(mu=data1['Y'].mean(), sigma=sigmaMu)
    mu = pm.GaussianRandomWalk('mu', mu=0, sigma=sigmaMu, init_dist=init_dist,
                               dims='data')

    ### 尤度関数
    obs = pm.Normal('obs', mu=mu, sigma=sigmaY, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model1

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model1)

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 10秒
with model1:
    idata1 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                       nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata1         # idata名
threshold = 1.04          # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
pm.summary(idata1, hdi_prob=0.95, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata1, compact=True)
plt.tight_layout();

【実行結果】

12.1.6 推論結果の解釈

事後統計量

事後分布の要約統計量を算出します。
算出関数を定義します。

### median, 10%, 90%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
    probs = [50, 10, 90]
    columns = ['median', '10%CI', '90%CI']
    quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
    quantiles.columns = columns
    return quantiles

$${\sigma_{\mu}}$$と$${\sigma_Y}$$の要約統計量を算出します。

### 要約統計量の算出・表示
vars = ['sigmaMu', 'sigmaY']
param_samples = idata1.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(3))

【実行結果】
テキストの結果(233ページ)と比べると、中央値はまあまあテキストに近く、一方で80%ベイズ信頼区間の幅はテキストよりも狭くなりました。

将来予測

ベイズ先輩の kenken さんから教わった予測方法を使ってみます!
次のサイトで紹介されています。

正しく使えているか不安ですが、やってみました。

### 予測モデルの定義と予測の実行

## 設定:予測期間
pred_period = 3

## モデルの定義
with pm.Model() as model1_f:
    # 標準偏差
    sigmaMu = pm.Flat('sigmaMu')
    sigmaY = pm.Flat('sigmaY')

    # mu
    init_dist = pm.DiracDelta.dist(data1['Y'].values[-1])
    mu = pm.GaussianRandomWalk('mu', mu=0, sigma=sigmaMu, init_dist=init_dist,
                               steps=pred_period)
    # y
    yPred = pm.Normal('yPred', mu=mu, sigma=sigmaY)
    
    # 予測
    idata1.extend(pm.sample_posterior_predictive(
        idata1, var_names=['mu', 'yPred'], predictions=True, random_seed=1234))

【実行結果】

$${\mu}$$の予測結果を可視化します。
テキスト図12.3左に相当します。

### μの事後分布・予測分布の描画 ◆図12.3左 ※予測期間の信用区間がテキストと異なる

## 描画用データの作成
# 推論データからμ(観測値に対する)のMCMCサンプルデータを取り出し
mu_obs_samples = idata1.posterior.mu.stack(sample=('chain', 'draw')).data
# 予測データからμ(予測値)のMCMCサンプルデータを取り出し
mu_pred_samples = idata1.predictions.mu.stack(sample=('chain', 'draw')).data
# μの結合
mu_samples = np.concatenate([mu_obs_samples, mu_pred_samples[1:]], axis=0)
# μの中央値の算出
mu_median = np.median(mu_samples, axis=1)
# μの80%CI, 50%CIの算出
mu_80ci = np.quantile(mu_samples, q=[0.10, 0.90], axis=1)
mu_50ci = np.quantile(mu_samples, q=[0.25, 0.75], axis=1)
# 予測期間を含めたx軸の値の算出
pred_times = range(1, len(data1) + pred_period + 1)

## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4))
# Yの観測値の描画
plt.plot(data1.X, data1.Y, 'o', color='blue')
# μの中央値の描画
plt.plot(pred_times, mu_median, color='tomato')
# μの80%CIの描画
plt.fill_between(pred_times, mu_80ci[0], mu_80ci[1], color='tomato', alpha=0.2)
# μの50%CIの描画
plt.fill_between(pred_times, mu_50ci[0], mu_50ci[1], color='tomato', alpha=0.4)
# 21期の垂直線の描画
plt.axvline(21, color='black', lw=0.8, ls='--')
# 修飾
plt.xlabel('Time(Day)')
plt.ylabel('Y')
plt.grid(lw=0.5);

【実行結果】
Yの観測値が青い点、$${\mu}$$の予測値に関しては、赤い線が中央値、濃い赤塗りが50%信用区間、薄い赤塗が80%信用区間です。
テキストの結果に近い予測ができたように思います。

12.1.7 状態の変化をなめらかにする

2階差分のトレンド項のモデルです。
テキストの数式をお借りします。

$$
\begin{align*}
\mu[t] - \mu[t-1] &= \mu[t-1] - \mu[t-2] + \varepsilon_{\mu}[t-2],\ \varepsilon_{\mu}[t] \sim \text{Normal}\ (0,\ \sigma_{\mu}) \\
\mu[t] &= 2\mu[t-1] - \mu[t-2] + \varepsilon_{\mu}[t-2],\ \varepsilon_{\mu}[t] \sim \text{Normal}\ (0,\ \sigma_{\mu}) \\
 \\
\mu[t] &\sim \text{Normal}\ (\mu[t-1],\ \sigma_{\mu}) \\
y[t] &\sim \text{Normal}\ (\mu[t],\ \sigma_Y) \\
\end{align*}
$$

テキストより引用

PyMCのモデル定義

PyMCでモデル式12-4を実装します。
$${\mu}$$の従う確率分布には AR() を用いました。
自己回帰係数 rhoに [2, -1]を設定します。
モデル式の$${\mu}$$の左辺$${2\mu[t-1] - \mu[t-2]}$$に注目します。
$${\mu[t-1],\ \mu[t-2]}$$の自己回帰係数$${2,\ -1}$$を PyMC の AR() の引数 rho に設定します。

モデルの定義です。 

### モデルの定義 ◆モデル式12-4 model12-.stan
# ※PyMCのARでμの確率分布を表現

## 設定:ARの自己相関係数ρ
rho = [2, -1]

with pm.Model() as model2:
    
    ### データ関連定義
    ## coordの定義
    model2.add_coord('data', values=data1.index, mutable=True)
    
    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=data1['Y'].values, dims='data')

    ### 事前分布
    # 標準偏差
    sigmaMu = pm.Uniform('sigmaMu', lower=0, upper=5)
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=5)
    # mu:トレンド項
    init_dist = pm.Normal.dist(mu=data1['Y'].mean(), sigma=sigmaMu)
    mu = pm.AR('mu', rho=rho, sigma=sigmaMu, ar_order=len(rho),
               init_dist=init_dist, dims='data')

    ### 尤度関数
    obs = pm.Normal('obs', mu=mu, sigma=sigmaY, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model2

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model2)

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 10秒
with model2:
    idata2 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.85,
                       nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata2        # idata名
threshold = 1.02         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
pm.summary(idata2, hdi_prob=0.95, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata2, compact=True)
plt.tight_layout();

【実行結果】

事後統計量

$${\sigma_{\mu}}$$と$${\sigma_Y}$$の要約統計量を算出します。

### 要約統計量の算出・表示
vars = ['sigmaMu', 'sigmaY']
param_samples = idata2.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(3))

【実行結果】
テキストに推論結果が掲載されていないので、このPyMCモデルの推論の適否は不明です。

将来予測

1次差分のモデルと同じ方法で将来予測をしてみます。
正しく使えているかかなり不安ですが、やってみました。

### モデルの定義 予測モデル

## 設定:予測期間
pred_period = 3

## モデルの定義
with pm.Model() as model2_f:
    # 標準偏差
    sigmaMu = pm.Flat('sigmaMu')
    sigmaY = pm.Flat('sigmaY')

    # mu
    init_dist = pm.DiracDelta.dist(data1['Y'].values[-1])
    mu = pm.AR('mu', rho=rho, sigma=sigmaMu, ar_order=len(rho),
               init_dist=init_dist, steps=pred_period)
    # y
    yPred = pm.Normal('yPred', mu=mu, sigma=sigmaY)
    
    # 予測
    idata2.extend(pm.sample_posterior_predictive(
        idata2, var_names=['mu', 'yPred'], predictions=True, random_seed=1234))

【実行結果】

$${\mu}$$の予測結果を可視化します。
テキスト図12.3右に相当します。

### μの事後分布・予測分布の描画 ◆図12.4右 ※予測期間の信用区間がテキストと異なる

## 描画用データの作成
# 推論データからμ(観測値に対する)のMCMCサンプルデータを取り出し
mu_obs_samples = idata2.posterior.mu.stack(sample=('chain', 'draw')).data
# 予測データからμ(予測値)のMCMCサンプルデータを取り出し
mu_pred_samples = idata2.predictions.mu.stack(sample=('chain', 'draw')).data
# μの結合
mu_samples = np.concatenate([mu_obs_samples, mu_pred_samples[2:]], axis=0)
# μの中央値の算出
mu_median = np.median(mu_samples, axis=1)
# μの80%CI, 50%CIの算出
mu_80ci = np.quantile(mu_samples, q=[0.10, 0.90], axis=1)
mu_50ci = np.quantile(mu_samples, q=[0.25, 0.75], axis=1)
# 予測期間を含めたx軸の値の算出
pred_times = range(1, len(data1) + pred_period + 1)

## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4))
# Yの観測値の描画
plt.plot(data1.X, data1.Y, 'o', color='blue')
# μの中央値の描画
plt.plot(pred_times, mu_median, color='tomato')
# μの80%CIの描画
plt.fill_between(pred_times, mu_80ci[0], mu_80ci[1], color='tomato', alpha=0.2)
# μの50%CIの描画
plt.fill_between(pred_times, mu_50ci[0], mu_50ci[1], color='tomato', alpha=0.4)
# 21期の垂直線の描画
plt.axvline(21, color='black', lw=0.8, ls='--')
# 修飾
plt.xlabel('Time(Day)')
plt.ylabel('Y')
plt.grid(lw=0.5);

【実行結果】
予測区間がテキストと全く異なる結果になっています。
観測値のある区間についても、テキストの滑らかさには及ばない結果になりました。

将来予測2

別の方法で将来予測を行ってみます。
将来期間のYの値を欠損値にしてモデルに与えて、将来期間の予測を含むMCMCを実行します。

モデルの定義です。

### モデルの定義 ◆モデル式12-4 model12-.stan

## 設定
# ARの自己相関係数ρ
rho = [2, -1]
# 予測期間
pred_period = 3

## Yの観測値に予測期間のNaNを追加
y_forecast = np.concatenate([data1['Y'].values, np.repeat(np.nan, pred_period)])

## モデルの定義
with pm.Model() as model2_mod:
    
    ### データ関連定義
    ## coordの定義
    model2_mod.add_coord('data', values=range(len(y_forecast)), mutable=True)
    
    ### 事前分布
    # 標準偏差
    sigmaMu = pm.Uniform('sigmaMu', lower=0, upper=5)
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=5)
    # mu:トレンド項
    init_dist = pm.Normal.dist(mu=data1['Y'].mean(), sigma=sigmaMu)
    mu = pm.AR('mu', rho=rho, sigma=sigmaMu, ar_order=len(rho),
               init_dist=init_dist, dims='data')

    ### 尤度関数
    obs = pm.Normal('obs', mu=mu, sigma=sigmaY, observed=y_forecast, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model2_mod.basic_RVs

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model2_mod)

【実行結果】
左下の obs_unobserved が Y の予測期間に相当します。

MCMCを実行します。

### 事後分布からのサンプリング 1分40秒
with model2_mod:
    idata2_mod = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                           random_seed=1234)

【実行結果】

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata2_mod    # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
pm.summary(idata2_mod, hdi_prob=0.95, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata2_mod, compact=True)
plt.tight_layout();

【実行結果】
かなり発散を含んでいるようですが、いったん見なかったことにします。

$${\sigma_{\mu}}$$と$${\sigma_Y}$$の要約統計量を算出します。

### 要約統計量の算出・表示
vars = ['sigmaMu', 'sigmaY']
param_samples = idata2_mod.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(3))

【実行結果】
1つ目の事後統計量と比べると少し変動した感じです。

(参考:1つ目の2次差分項のモデルの事後統計量)

$${\mu}$$の予測結果を可視化します。
テキスト図12.3右に相当します。

### μの事後分布・予測分布の描画 ◆図12.4右

## 描画用データの作成
# 推論データからμ(観測値に対する)のMCMCサンプルデータを取り出し
mu_samples = idata2_mod.posterior.mu.stack(sample=('chain', 'draw')).data
# μの中央値の算出
mu_median = np.median(mu_samples, axis=1)
# μの80%CI, 50%CIの算出
mu_80ci = np.quantile(mu_samples, q=[0.10, 0.90], axis=1)
mu_50ci = np.quantile(mu_samples, q=[0.25, 0.75], axis=1)
# 予測期間を含めたx軸の値の算出
pred_times = range(1, len(data1) + pred_period + 1)

## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4))
# Yの観測値の描画
plt.plot(data1.X, data1.Y, 'o', color='blue')
# μの中央値の描画
plt.plot(pred_times, mu_median, color='tomato')
# μの80%CIの描画
plt.fill_between(pred_times, mu_80ci[0], mu_80ci[1], color='tomato', alpha=0.2)
# μの50%CIの描画
plt.fill_between(pred_times, mu_50ci[0], mu_50ci[1], color='tomato', alpha=0.4)
# 21期の垂直線の描画
plt.axvline(21, color='black', lw=0.8, ls='--')
# 修飾
plt.xlabel('Time(Day)')
plt.ylabel('Y')
plt.grid(lw=0.5);

【実行結果】
こちらの方がテキストの結果に近くなりました。
予測の帯も滑らかになりました。

12.1 節は以上です。


シリーズの記事

次の記事

前の記事

目次


ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

いいなと思ったら応援しよう!

この記事が参加している募集