見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第12章「12.3 変化点検出」

第12章「時間や空間を扱うモデル」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第12章「時間や空間を扱うモデル」の 12.3節「変化点検出」の PyMC5写経 を取り扱います。
全般的にモデリングはうまくいかなかったです(汗)

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


12.3 変化点検出


モデリングの準備

インポート

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np

# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az

# 描画
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

データの読み込みと確認
サンプルコードのデータを読み込みます。

### データの読み込み ◆data-changepoint.txt
# X:日付(四半期), Y:季節ものの販売数[千個]

data = pd.read_csv('./data/data-changepoint.txt')
print('data.shape: ', data.shape)
display(data.head())

【実行結果】

テキスト図12.5左の時系列折れ線グラフを描画します。

# 時系列折れ線グラフの描画 ◆図12.5左
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(data.X, data.Y, lw=0.8)
ax.axvline(75, color='tab:red', ls='--')
ax.axvline(124, color='tab:red', ls='--')
ax.set(xlabel='Time(Second)', ylabel='Y')
ax.grid(lw=0.5);

【実行結果】
200期前後で大きな変化がありそうです。
また、100期前後と300期前後にも変化がありそうです。

赤い点線は予告です!
全期間を通したモデリングをうまくできなかったので、赤い点線の期間に限定して、モデリングを行いました。

PyMCのモデル定義

75期からの50期間の推論を行います。
以下のPyMCモデルで全期間を通してMCMCを実行するとエラーになってしまうからです。
他に良いモデリング方法がありましたら、ぜひ教えてください。

PyMCでモデル式12-2を実装します。
モデルの定義です。 

### モデルの定義 ◆モデル式12-7 model12-7.stan

# データを短くして実行
start = 75
length = 50
data_s = data.iloc[start: start+length]

# 初期値設定
T = len(data_s)

# モデルの定義
with pm.Model() as model:
    
    ### データ関連定義
    # coordの定義
    model.add_coord('data', values=data_s.index, mutable=True)
    # dataの定義
    Y = pm.ConstantData('Y', value=data_s['Y'].values, dims='data')
    
    ### 事前分布
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=10)
    sigmaMu = pm.Uniform('sigmaMu', lower=0, upper=10)
    muZero = pm.Uniform('muZero', lower=-np.pi, upper=np.pi)
    muRaw = pm.Uniform('muRaw', lower=-np.pi/2, upper=np.pi/2, shape=T-1)

    # mu mu0~muTの変数を1つづつ作成
    mu = [0] * T
    mu[0] = pm.Deterministic('mu0', muZero)
    for t in range(1, T):
        mu[t] = pm.Deterministic('mu'+str(t),
                                 mu[t-1] + sigmaMu * pt.tan(muRaw[t-1]))
    
    ### 尤度
    obs = pm.Normal('obs', mu=mu, sigma=sigmaY, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model.basic_RVs

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 50件:2分50秒 ※件数が多い場合エラーが発生する
# UnicodeDecodeError: 'utf-8' codec can't decode byte 0x93 in position 9: invalid start byte

with model:
    idata = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                      nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata          # idata名
threshold = 1.01          # しきい値

# しきい値を超えるR_hatの個数を表示
display((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
pm.summary(idata, hdi_prob=0.95, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata, compact=True)
plt.tight_layout();

【実行結果】
発散を示す黒いバーコードが出ています。
いったん目を瞑ります。

推論結果の解釈

事後統計量

事後分布の要約統計量を算出します。
算出関数を定義します。

### median, 10%, 90%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
    probs = [50, 10, 90]
    columns = ['median', '10%CI', '90%CI']
    quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
    quantiles.columns = columns
    return quantiles

要約統計量を算出します。

### 要約統計量の算出・表示
vars = ['sigmaMu', 'sigmaY']
param_samples = idata.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(3))

【実行結果】
テキストに事後統計量の記載が無いため、PyMCモデルの推論結果の妥当性は分かりません。

モデルの推定結果の可視化

テキスト図12.5右の一部分に相当するグラフを描画します。

### μの描画 ◆図12.4右の一部

## 描画用データの作成
# 推論データからμのMCMCサンプルデータを取り出し
mu_samples = np.array(
    [idata.posterior['mu'+str(i)].stack(sample=('chain', 'draw')).data
     for i in range(length)])
# μの中央値の算出
mu_median = np.median(mu_samples, axis=1)
# μの80%CI, 50%CIの算出
mu_80ci = np.quantile(mu_samples, q=[0.10, 0.90], axis=1)
mu_50ci = np.quantile(mu_samples, q=[0.25, 0.75], axis=1)

## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4))
# Yの観測値の描画
plt.plot(range(start, start+length), data.Y.values[start:start+length],
         '-o', color='tab:blue', label='$Y$ の観測値')
# μの中央値の描画
plt.plot(range(start, start+length), mu_median, color='tab:red',
               label='$\mu$:中央値')
# μの50%CIの描画
plt.fill_between(range(start, start+length), mu_50ci[0], mu_50ci[1],
                 color='tomato', alpha=0.5, label='$\mu$:50%CI')
# μの80%CIの描画
plt.fill_between(range(start, start+length), mu_80ci[0], mu_80ci[1],
                 color='tomato', alpha=0.2, label='$\mu$:80%CI')
# 修飾
plt.xlabel('Time(Second)')
plt.ylabel('Y')
plt.title(r'$Y$ の観測値と $\mu$ の推論値')
plt.legend(bbox_to_anchor=(1, 1))
plt.grid(lw=0.5);

【実行結果】
100期前後の変化点を捉えられている感じがします。

アディショナルタイム1

変化点前後で平均的に推移するモデルで推論してみます。
まずは変化点が1つのケースです。

PyMCのモデル定義

with pm.Model() as model2:
    
    # 変化点前の平均の数
    n = 2

    ### データ関連定義
    # coordの定義
    model2.add_coord('data', values=data.index, mutable=True)
    # dataの定義
    Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
    timeIdx = pm.ConstantData('timeIdx', value=range(data['X'].max()), dims='data')

    ### 事前分布
    # 変化点 cp1~cp3
    cp1 = pm.Uniform('cp1', lower=1, upper=len(data))
    # mu
    mus = pm.Uniform('mus', lower=-3, upper=3, shape=n)
    mu = pm.Deterministic('mu',
                          pt.switch(pt.le(timeIdx, cp1), mus[0], mus[1]),
                          dims='data')
    
    ### 尤度:観測モデル
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=3)
    obs = pm.Normal('obs', mu=mu, sigma=sigmaY, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model2

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model2)

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 20秒

with model2:
    idata2 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                       nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata2         # idata名
threshold = 1.08          # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['cp1', 'sigmaY', 'mu']
pm.summary(idata2, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata2, compact=True, var_names=var_names)
plt.tight_layout();

【実行結果】

推論結果の解釈

モデルの推定結果の可視化

テキスト図12.5右に準拠してグラフを描画します。

### μの描画 ◆図12.4右に準拠

## 描画用データの作成
# 推論データからμのMCMCサンプルデータを取り出し
mu_samples = idata2.posterior.mu.stack(sample=('chain', 'draw')).data
# μの中央値の算出
mu_median = np.median(mu_samples, axis=1)
# μの95%CI, 50%CIの算出
mu_95ci = np.quantile(mu_samples, q=[0.025, 0.975], axis=1)
mu_50ci = np.quantile(mu_samples, q=[0.250, 0.750], axis=1)

## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4))
# Yの観測値の描画
plt.plot(data.X, data.Y, lw=0.8, color='tab:blue', label='$Y$ の観測値')
# μの中央値の描画
plt.plot(data.X, mu_median, color='tab:red', label='$\mu$:中央値')
# μの50%CIの描画
plt.fill_between(data.X, mu_50ci[0], mu_50ci[1],
                 color='tomato', alpha=0.5, label='$\mu$:50%CI')
# μの95%CIの描画
plt.fill_between(data.X, mu_95ci[0], mu_95ci[1],
                 color='tomato', alpha=0.2, label='$\mu$:95%CI')
# 修飾
plt.xlabel('Time(Second)')
plt.ylabel('Y')
plt.title(r'$Y$ の観測値と $\mu$ の推論値')
plt.legend(bbox_to_anchor=(1, 1))
plt.grid(lw=0.5);

【実行結果】
変化点を1つに限定する場合、200期前後の変化を捉えるようです。

アディショナルタイム2

つぎは変化点が3つのケースです。

PyMCのモデル定義

with pm.Model() as model3:
    
    # 変化点前の平均の数
    n = 4

    ### データ関連定義
    # coordの定義
    model3.add_coord('data', values=data.index, mutable=True)
    # dataの定義
    Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
    timeIdx = pm.ConstantData('timeIdx', value=range(data['X'].max()), dims='data')

    ### 事前分布
    # 変化点 cp1~cp3
    cpDiff = pm.Uniform('cpDiff', lower=0, upper=200, shape=2)
    cp1 = pm.Uniform('cp1', lower=1, upper=200)
    cp2 = pm.Deterministic('cp2', cp1 + cpDiff[0])
    cp3 = pm.Deterministic('cp3', cp2 + cpDiff[1])
    # mu
    mus = pm.Uniform('mus', lower=-3, upper=3, shape=n)
    mu = pm.Deterministic('mu',
                          pt.switch(pt.le(timeIdx, cp1), mus[0],
                          pt.switch(pt.le(timeIdx, cp2), mus[1], 
                          pt.switch(pt.le(timeIdx, cp3), mus[2], mus[3]))),
                          dims='data')
    
    ### 尤度:観測モデル
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=3)
    obs = pm.Normal('obs', mu=mu, sigma=sigmaY, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model2

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model2)

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 30秒

with model3:
    idata3 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                       nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata3         # idata名
threshold = 1.1           # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしていません・・・

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['cp1', 'cp2', 'cp3', 'sigmaY', 'mu']
pm.summary(idata3, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata3, compact=True, var_names=var_names)
plt.tight_layout();

【実行結果】
確かに収束している感じがいたしません。

推論結果の解釈

モデルの推定結果の可視化

収束していませんので、以降は分析には用いることが出来ず、単にコード実行例として捉えてください。
テキスト図12.5右に準拠してグラフを描画します。

### μの描画 ◆図12.4右の一部

## 描画用データの作成
# 推論データからμのMCMCサンプルデータを取り出し
mu_samples = idata3.posterior.mu.stack(sample=('chain', 'draw')).data
# μの中央値の算出
mu_median = np.median(mu_samples, axis=1)
# μの95%CI, 50%CIの算出
mu_95ci = np.quantile(mu_samples, q=[0.025, 0.975], axis=1)
mu_50ci = np.quantile(mu_samples, q=[0.250, 0.750], axis=1)

## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4))
# Yの観測値の描画
plt.plot(data.X, data.Y, lw=0.8, color='tab:blue', label='$Y$ の観測値')
# μの中央値の描画
plt.plot(data.X, mu_median, color='tab:red', label='$\mu$:中央値')
# μの50%CIの描画
plt.fill_between(data.X, mu_50ci[0], mu_50ci[1],
                 color='tomato', alpha=0.5, label='$\mu$:50%CI')
# μの95%CIの描画
plt.fill_between(data.X, mu_95ci[0], mu_95ci[1],
                 color='tomato', alpha=0.2, label='$\mu$:95%CI')
# 修飾
plt.xlabel('Time(Second)')
plt.ylabel('Y')
plt.title(r'$Y$ の観測値と $\mu$ の推論値')
plt.legend(bbox_to_anchor=(1, 1))
plt.grid(lw=0.5);

【実行結果】
100期付近、200期付近、300期付近で変化を捉えたような図になりました。
100期付近と300期付近では、95%CIが広くなっています。

【募集】

テキストのモデル式12-7の変化点検出について、PyMCでモデリングするコードを教えてください!!!

12.3 節は以上です。


シリーズの記事

次の記事

前の記事

目次


ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

この記事が気に入ったらサポートをしてみませんか?