実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で⑫状態空間モデルとベイズ決定
「Stan入門」章
テキスト「Stan入門」の執筆者
松浦健太郎 先生
この記事は、テキストの「Stan入門」章の例題2「状態空間モデルとベイズ決定」の実践を取り扱います。
状態空間モデルです。記事第7回・8回に通じます!
今回は季節調整項を取り扱います!
ベイズ事後予測を意思決定に連携します!
たのしくPyMCモデリングを進めましょう!
この章を執筆された松浦先生は「StanとRでベイズ統計モデリング (Wonderful R 2)」を執筆されています。
こちらもお楽しみくださいませ。
はじめに
岩波データサイエンスVol.1の紹介
この記事は書籍「岩波データサイエンス vol.1」(岩波書店、以下「テキスト」と呼びます)の特集記事「ベイズ推論とMCMCのフリーソフト」のベイズモデルを用いて、PyMC Ver.5で「実験的」に実装する様子を描いた統計ドキュメンタリーです。
テキストは、2015年10月に発売され、ベイズモデリングの様々なソフトウェアを用いたモデリング事例を多数掲載し、ベイズモデリングの楽しさを紹介する素晴らしい書籍です。
入門的なモデルから2次階差を取り扱う空間モデルまで、幅広い難易度のモデルを満喫できます!
このシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。
引用表記
この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「岩波データサイエンス vol.1」
第9刷、編者 岩波データサイエンス刊行委員会 岩波書店
記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!
PyMC環境の準備
Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。
PythonとPyMC
テキストで利用するツールは、R、Stan です。
この記事では、PythonとPyMCを用いたコードに変換してベイズモデリングを実践いたします。
イントロ
1. モデルのイメージ
■ トレンド(レベル)項と季節調整項を含む状態空間モデル
季節ものの四半期ごとの予約販売数の時系列データを取り扱います。
四半期サイクル=周期4ということで季節調整を考慮します。
また、最後には状態空間モデルの推論データを用いた意思決定支援「ベイズ決定」に取り組みます。
■ モデル数式
テキストより数式を引用いたします。
1行目は観測モデルです。トレンド項$${\mu_t}$$(中身はレベル)と季節調整項$${\s_t}$$と観測ノイズで構成されています。
2行目と3行目がシステムモデルです。
2行目はトレンド項。1期前の状態$${\mu_t}$$と自己相関するランダムウォークです。
3行目は季節調整項。当期を含めた直近4四半期の$${s}$$合計がノイズと等しくなるような式になっています。これで周期4の季節性(周期性)を表現できるそうです。
2. インポート
「Stan入門」章で利用するパッケージをインポートします。
### インポート
# ユーティリティ
import pickle
# 数値・確率計算
import pandas as pd
import numpy as np
# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az
# 最適化
from scipy.optimize import minimize_scalar
# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'
# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')
3. データの読み込み
CSVファイルをpandasのデータフレームに読み込みます。
### データの読み込み
data2 = pd.read_csv('./data/data-season.txt')
print('data2.shape: ', data2.shape)
display(data2.head())
【実行結果】
44期の予約販売数 Y データです。
4. データの外観の確認
予約販売数の時系列プロットを描画します。
### 観測データの可視化
plt.figure(figsize=(6, 3))
ax = plt.subplot()
ax.plot(data2)
ax.set(xlabel='Time[四半期]', ylabel='販売個数[千個]', title='Yの観測値')
ax.grid();
【実行結果】
峰が4四半期ごとにできている様子が分かります。
20期あたりから上昇基調が続いているようにも見えます。
では、PyMCモデリングに進みます。
ベイズモデル
1. モデルの定義
トレンド項$${\mu_t}$$はランダムウォークなのでpm.GaussianRandomWalk() を利用します。
季節調整項$${s_t}$$は次数$${3}$$、自己回帰係数$${-1}$$のAR過程ですので、pm.AR() を利用します。
### モデルの定義 ※事前分布でパラメータの縛りを強くしている
## ARパラメータの設定
period = 4 # 季節成分の周期
order = period - 1 # ARの次数
rhos = np.ones(order) * -1 # ARのρパラメータ
## モデルの定義
with pm.Model() as model2:
### データ関連定義
## coordの定義
# 観測データのインデックス
model2.add_coord('data', values=data2.index, mutable=True)
## dataの定義
# 観測値
y = pm.MutableData('y', value=data2['Y'].values, dims='data')
### 事前分布
sigmaY = pm.HalfNormal('sigmaY', sigma=10)
sigmaMu = pm.HalfNormal('sigmaMu', sigma=10)
sigmaS = pm.HalfNormal('sigmaS', sigma=10)
### 状態空間モデル
# トレンド項:ローカルレベル・ランダムウォーク
init_dist_mu = pm.Normal.dist(mu=0, sigma=10)
mu = pm.GaussianRandomWalk('mu', mu=0, sigma=sigmaMu, init_dist=init_dist_mu,
dims='data')
# 季節調整項目:AR(3)
init_dist_s = pm.Normal.dist(mu=0, sigma=10)
s = pm.AR('s', rho=rhos, sigma=sigmaS, init_dist=init_dist_s, ar_order=order,
dims='data')
### 尤度
Y = pm.Normal('Y', mu=mu + s, sigma=sigmaY, observed=y, dims='data')
モデルの内容を表示・可視化してみましょう。
### モデルの表示
model2
【実行結果】
線形回帰モデルの切片$${a}$$と傾き$${b}$$は正規分布(平均0、標準偏差100)を仮定しています。2つのパラメータに正規分布の縛りを入れているようです。
続いてモデルを可視化します。
### モデルの可視化
pm.model_to_graphviz(model2)
【実行結果】
トレンド項と季節調整項が生成され、観測モデル$${Y}$$に渡っていく様子がよく分かります。
2. MCMCの実行と収束の確認
■ MCMC
マルコフ連鎖=chainsを4本、バーンイン期間=tuneを1000、利用するサンプル=drawを1chainあたり1000に設定して、合計4000個の事後分布からのサンプリングデータを生成します。テキストの指定数とは異なっています。
サンプル方法に numpyro を指定しています。処理速度が早くなります。
numpyro を使わない場合は「nuts_sampler='numpyro',」を削除します。
### 事後分布からのサンプリング 15秒
# テキスト:iter=10200, warmup=200, thin=10, chains=3
with model2:
idata2 = pm.sample(draws=1000, tune=1000, chains=4, random_seed=123,
nuts_sampler='numpyro', target_accept=0.95)
【実行結果】
処理時間は15秒です。
■ $${\hat{R}}$$で収束確認
収束の確認をします。
指標$${\hat{R}}$$(あーるはっと)の値が$${1.1}$$以下のとき、収束したこととします。
次のコードでは$${\hat{R}>1.05}$$のパラメータの個数をカウントします。
個数が0ならば、$${\hat{R} \leq1.1}$$なので収束したとみなします。
### r_hat>1.1の確認
# 設定
idata_in = idata2 # idata名
threshold = 1.05 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
$${\hat{R}>1.05}$$のパラメータは0個でした。
$${\hat{R} \leq1.1}$$なので収束したとみなします。
■ $${\hat{R}}$$の値把握と事後統計量の表示
### 事後統計量の表示
pm.summary(idata2, hdi_prob=0.95)
【実行結果】
■ トレースプロットで収束確認
トレースプロットを描画します。
こちらの図でも収束の確認を行えます。
### トレースプロットの描画
pm.plot_trace(idata2)
plt.tight_layout();
【実行結果】
左プロットの線の重なり具合、右プロットの偏りのない乱雑さより、収束している感じがします。
グラフ下部にバーコードのような発散の印が少し見られるのが気になります。。。
このモデルで予測に進みましょう。
3. 将来期間の予測
上記の同じモデル構成にして、観測値$${Y_t}$$に予測期間8の欠損値を付け足した$${Y_p}$$を尤度関数に与えて、将来期間の販売予測数を推論します。
### 予測用モデルの定義
## 予測データの作成
# 予測期間
pred_period = 8
# Yに予測期間の未知データNaNを追加
Yp = np.concatenate([data2['Y'].values, np.repeat(np.nan, pred_period)])
## ARパラメータの設定
period = 4 # 季節成分の周期
order = period - 1 # ARの次数
rhos = np.ones(order) * -1 # ARのρパラメータ
## モデルの定義
with pm.Model() as model2p:
### データ関連定義
## coordの定義
# 観測データのインデックス
model2p.add_coord('data', values=range(len(Yp)), mutable=True)
### 事前分布
sigmaY = pm.HalfNormal('sigmaY', sigma=10)
sigmaMu = pm.HalfNormal('sigmaMu', sigma=10)
sigmaS = pm.HalfNormal('sigmaS', sigma=10)
### 状態空間モデル
# トレンド項:ローカルレベル・ランダムウォーク
init_dist_mu = pm.Normal.dist(mu=0, sigma=10)
mu = pm.GaussianRandomWalk('mu', mu=0, sigma=sigmaMu, init_dist=init_dist_mu,
dims='data')
# 季節調整項目:AR(3)
init_dist_s = pm.Normal.dist(mu=0, sigma=10)
s = pm.AR('s', rho=rhos, sigma=sigmaS, init_dist=init_dist_s, ar_order=order,
dims='data')
### 尤度
Y = pm.Normal('Y', mu=mu + s, sigma=sigmaY, observed=Yp, dims='data')
予測用モデルを可視化しましょう。
### 予測用モデルの可視化
pm.model_to_graphviz(model2p)
【実行結果】
左の色付き楕円が$${Y}$$の観測値、その隣の白楕円が$${Y}$$の予測値(=データ上は欠損値)です。
では予測=MCMCサンプリングを実践しましょう。
観測値に欠損値を含む場合はPyMC標準のサンプラーを利用します。
### 事後分布からのサンプリング 3分10秒
# テキスト:iter=10200, warmup=200, thin=10, chains=3
with model2p:
idata2p = pm.sample(draws=1000, tune=1000, chains=4, random_seed=123,
target_accept=0.95)
【実行結果】
4. 分析
テキストの図2と同様に、①観測値だけ、②将来予測値付き、③トレンド項、④季節調整項の4つを描画します。
### 予測データの可視化 ※図2に相当
## 予測データの取り出し・データ作成
# 目的変数Y
y_pred = idata2p.posterior.Y.stack(sample=('chain', 'draw')).data
y_pred_50hdi = az.hdi(y_pred.T, hdi_prob=0.50).T
y_pred_80hdi = az.hdi(y_pred.T, hdi_prob=0.80).T
# トレンド項μ
mu_pred = idata2p.posterior.mu.stack(sample=('chain', 'draw')).data
mu_pred_50hdi = az.hdi(mu_pred.T, hdi_prob=0.50).T
mu_pred_80hdi = az.hdi(mu_pred.T, hdi_prob=0.80).T
# 季節調整項s
s_pred = idata2p.posterior.s.stack(sample=('chain', 'draw')).data
s_pred_50hdi = az.hdi(s_pred.T, hdi_prob=0.50).T
s_pred_80hdi = az.hdi(s_pred.T, hdi_prob=0.80).T
## Yの観測値の描画
plt.figure(figsize=(6, 3))
ax = plt.subplot()
ax.plot(data2, color='dodgerblue', lw=5, label='観測値')
ax.axvline(len(data2)-1, color='gray', ls='--')
ax.set(xlabel='Time[四半期]', ylabel='販売個数[千個]', title='Yの観測値',
xlim=(0, len(Yp)))
ax.grid()
ax.legend();
## Yの予測値の描画
plt.figure(figsize=(6, 3))
ax = plt.subplot()
ax.plot(data2, color='dodgerblue', lw=5, label='観測値')
ax.plot(range(len(Yp)), np.median(y_pred, axis=1), color='red', label='予測値')
ax.fill_between(range(len(Yp)), y_pred_50hdi[0], y_pred_50hdi[1],
color='tomato', alpha=0.5, label='50%HDI')
ax.fill_between(range(len(Yp)), y_pred_80hdi[0], y_pred_80hdi[1],
color='tomato', alpha=0.2, label='80%HDI')
ax.axvline(len(data2)-1, color='gray', ls='--')
ax.set(xlabel='Time[四半期]', ylabel='販売個数[千個]', title='Yの予測値',
xlim=(0, len(Yp)))
ax.grid()
ax.legend();
## trendの予測値の描画
plt.figure(figsize=(6, 3))
ax = plt.subplot()
ax.plot(data2, color='dodgerblue', lw=5, label='観測値')
ax.plot(range(len(Yp)), np.median(y_pred, axis=1), color='red', label='予測値')
ax.fill_between(range(len(Yp)), mu_pred_50hdi[0], mu_pred_50hdi[1],
color='tomato', alpha=0.5, label='トレンド50%HDI')
ax.fill_between(range(len(Yp)), mu_pred_80hdi[0], mu_pred_80hdi[1],
color='tomato', alpha=0.2, label='トレンド80%HDI')
ax.axvline(len(data2)-1, color='gray', ls='--')
ax.set(xlabel='Time[四半期]', ylabel='販売個数[千個]', title='トレンド項 $\mu_t$',
xlim=(0, len(Yp)))
ax.grid()
ax.legend();
## 季節調整の予測値の描画
plt.figure(figsize=(6, 3))
ax = plt.subplot()
ax.fill_between(range(len(Yp)), s_pred_50hdi[0], s_pred_50hdi[1],
color='tomato', alpha=0.5, label='季節調整50%HDI')
ax.fill_between(range(len(Yp)), s_pred_80hdi[0], s_pred_80hdi[1],
color='tomato', alpha=0.2, label='季節調整80%HDI')
ax.axvline(len(data2)-1, color='gray', ls='--')
ax.set(xlabel='Time[四半期]', ylabel='販売個数[千個]', title='季節調整項 $s_t$',
xlim=(0, len(Yp)))
ax.grid()
ax.legend();
【実行結果】
2番目の将来予測は上昇基調にあるように見えますが、3番目のトレンド項(中身はレベルのみ)を見ると将来予測は横ばいです。
また4番目の季節調整項では周期4のサイクルが抽出されています。時間経過とともに上下幅が広がっている感じがします。
5. ベイズ決定
テキストは「1期先の予約販売個数$${x_{next}}$$を今の時点で決定して発注しなければならない状況を考える」と投げかけます。
在庫不足の場合には販売機会損失$${2(x_{next}-x)}$$が発生し、在庫余剰の場合には廃棄コスト$${1-\exp[-(x-x_{next})]}$$が発生します。
これらの「損失関数」を最小にする$${x_{next}}$$は、MCMCサンプリングデータを利用して算出できるのです!
Pythonの最適化計算は scipy.optimize の minimize_scalar を利用しましょう。
### ベイズ決定
## データ準備
# 実際の予約販売数量x_smp:Yの予測値から1期先の販売数量サンプリングデータを取得
x_smp = y_pred[44]
## 損失関数の定義:xは1期先の予約販売数量の予測値
def loss_func(x):
return np.sum(np.where(x < x_smp, 2*(x_smp - x), 1-np.exp(-(x-x_smp))))
## 最適化の実行・結果表示:scipy.optim.minimize_scalar
result = minimize_scalar(loss_func, method='brent', bracket=(5, 50))
print(f'ロス最小の発注数量: {result.x:.3f} 千個')
【実行結果】
サクッと計算できました。
テキストには計算結果が掲載されていないので、答え合わせができませぬ・・・
損失関数を可視化してみましょう。
### 最適化数量のプロット
# データ準備:x_smpを昇順並び替え
x_smp_sorted = sorted(x_smp)
# 描画領域の指定
plt.figure(figsize=(7,3))
ax = plt.subplot()
# 販売ロスの描画
ax.plot(x_smp_sorted, 2*(x_smp_sorted-result.x), label='販売ロス')
# 廃棄ロスの描画
ax.plot(x_smp_sorted, 1-np.exp(-(result.x-x_smp_sorted)), label='廃棄ロス')
# 最適化数量の垂直線の描画
ax.axvline(result.x, color='red', ls='--', label=f'最適化数量 {result.x:.3f}')
# y=0(ロス0)の描画
ax.axhline(0, color='gray', lw=0.5)
# 修飾
ax.set(xlabel='1期先の予約販売数量の予測値 x_smp [千個]', ylabel='ロス金額')
ax.legend();
【実行結果】
青線の販売機会ロスとオレンジの廃棄ロスの交差するx軸が約26千個であり、このときのロス金額は0です。
ベイズモデルによる推論結果(事後予測)を用いて、事業経営に利用可能な指標を算出して意思決定に活かせる、このサイクルはとてもありがたいですね!
ベイズと事業と経営の循環関係、またしてもベイズのエコシステムの発見です。
6. 推論データの保存
pickle で MCMCサンプリングデータ idata2 と idata2p を保存できます。
### 推論データの保存
file = r'idata2_ch4.pkl'
with open(file, 'wb') as f:
pickle.dump(idata2, f)
file = r'idata2p_ch4.pkl'
with open(file, 'wb') as f:
pickle.dump(idata2p, f)
次のコードで保存したファイルを読み込みできます。
### 推論データの読み込み
file = r'idata2_ch4.pkl'
with open(file, 'rb') as f:
idata2_ch4_load = pickle.load(f)
file = r'idata2p_ch4.pkl'
with open(file, 'rb') as f:
idata2p_ch4_load = pickle.load(f)
ベイズ記事は以上です。
次回予告
「Stan入門」章
例題3「空間構造のあるベイズモデル」
結び
時系列データの状態空間モデルで(たぶん)よく用いられるシステムモデルをまとめてみます。
1行目のレベル項はランダムウォーク、2行目のトレンド項は2次(2階)のトレンドモデル、3行目の季節調整項は周期$${S}$$の合計が0に近い値になるモデルです。
後から何度見返すことになるでしょうか・・・。
ちなみに pm.GaussianRandomWalk() を利用する場合、引数 mu にトレンド項(ドリフト項)を設定することでレベル・トレンド項を表現できる(できたはず???)ようです。
おわり
シリーズの記事
次の記事
前の記事
目次
ブログの紹介
note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!
1.のんびり統計
統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。
2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で
書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!
3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で
書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!
4.楽しい写経 ベイズ・Python等
ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀
5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で
書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。
6.データサイエンスっぽいことを綴る
統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。
7.Python機械学習プログラミング実践記
書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。
最後までお読みいただきまして、ありがとうございました。