StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第8章「8.1 階層モデルの導入」
第8章「階層モデル」
書籍の著者 松浦健太郎 先生
この記事は、テキスト第8章「階層モデル」の8.1節「階層モデルの導入」の PyMC5写経 を取り扱います。
線形モデルの切片と傾きに「会社別」というグループ(くくり)を導入します。
はじめに
StanとRでベイズ統計モデリングの紹介
この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。
テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
「アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!
テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。
引用表記
この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版
記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!
PyMC環境の準備
Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。
8.1 階層モデルの導入
インポート
### インポート
# 数値・確率計算
import pandas as pd
import numpy as np
import scipy.stats as stats
# PyMC
import pymc as pm
import arviz as az
# 描画
import matplotlib.pyplot as plt
# import matplotlib.cm as cm
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'
# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')
データの読み込み
サンプルコードのデータを読み込みます。
### データの読み込み ◆データファイル8.1 data-salary-2.txt
# X:年齢-23, Y:年収, KID:勤務会社ID
data = pd.read_csv('./data/data-salary-2.txt')
print('data.shape: ', data.shape)
display(data.head())
【実行結果】
8.1.1 解析の目的とデータの分布の確認
散布図で可視化します。
テキスト図8.1の左右のグラフに相当します。
まず左側の「すべてのデータをまとめてプロット」です。
単回帰の傾き・切片は scipy.stats の linregress で求めます。
### 散布図の描画 ◆図8.1左
## 単回帰分析の実行
#傾きと切片を取得
slope, intercept, _, _, _ = stats.linregress(x=data.X, y=data.Y)
print(f'傾き: {slope:.3f}, 切片: {intercept:.3f}')
# 回帰直線描画用のxとyの算出
x_lm = np.linspace(data.X.min() - 1, data.X.max() + 1, 2)
y_lm = intercept + slope * x_lm
## 描画処理
# マーカーの設定
markers = {1: 'o', 2: '^', 3: 'X', 4: 'd'}
# 描画領域の設定
plt.figure(figsize=(5, 4))
ax = plt.subplot()
# 回帰直線の描画
ax.plot(x_lm, y_lm, color='black', lw=3, alpha=0.4)
# 散布図の描画
sns.scatterplot(ax=ax, data=data, x='X', y='Y', hue='KID', style='KID',
s=100, markers=markers, palette='tab10', alpha=0.8)
# 修飾
ax.set(xlabel='年齢 $X$ [-23歳]', ylabel='年収 $Y$ [万円]',
title='年齢 $X$ と年収 $Y$ の散布図')
ax.legend(bbox_to_anchor=(1, 1), title='会社ID')
ax.grid(lw=0.5)
plt.show()
【実行結果】
上記グラフの凡例を会社別のグラフで使う目的で、凡例データを保存します。
### 会社別グラフで用いる凡例データを保存
handles, labels = ax.get_legend_handles_labels()
続いて図8.1右の「会社別ごとに分割してプロット」です。
### 会社別散布図の描画 ◆図8.1右
# 描画領域の指定
fig, axes = plt.subplots(2, 2, figsize=(6, 6), sharex=True, sharey=True)
# 会社IDごとに繰り返し描画処理(処理的にはaxesごとに繰り返し処理)
for i, ax in enumerate(axes.ravel()):
## 描画用データの作成
# 会社を1つ取り出す
tmp_df = data[data['KID'] == i + 1]
# 当該会社のデータで回帰直線の傾きと切片を取得
slope2, intercept2, _, _ ,_ = stats.linregress(x=tmp_df.X, y=tmp_df.Y)
# 当該会社の回帰直線描画用のxとyの算出
x_lm2 = np.linspace(tmp_df.X.min(), tmp_df.X.max(), 2)
y_lm2 = intercept2 + slope2 * x_lm2
## 描画処理
# 全会社の回帰直線の描画
ax.plot(x_lm, y_lm, color='black', lw=3, alpha=0.4)
# 当該会社の回帰直線の描画
ax.plot(x_lm2, y_lm2, color='red', lw=2, ls='--')
# 散布図の描画
sns.scatterplot(ax=ax, data=tmp_df, x='X', y='Y', style='KID',
s=100, markers=markers, color=plt.cm.tab10(i/10), alpha=0.8,
legend=None)
# 修飾
ax.set(xlabel=None, ylabel=None, title=f'会社ID: {i+1}')
ax.grid(lw=0.3)
# 図8.1左のグラフから取得した凡例を表示
fig.legend(handles=handles, labels=labels, bbox_to_anchor=(1.1, 0.9),
title='会社ID')
# 全体修飾
fig.supxlabel('年齢 $X$ [-23歳]')
fig.supylabel('年収 $Y$ [万円]')
fig.suptitle('年齢 $X$ と年収 $Y$ の散布図: 会社別')
plt.tight_layout();
【実行結果】
赤い点線は各会社の単回帰直線です。
8.1.2 グループ差を考えない場合 モデル式8-1
モデルの定義です。
### モデルの定義 ◆モデル式8-1 model8-1.stan
with pm.Model() as model1:
### データ関連定義
## coordの定義
model1.add_coord('data', values=data.index, mutable=True)
## dataの定義
# 目的変数 Y
Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
# 説明変数 X
X = pm.ConstantData('X', value=data['X'].values, dims='data')
### 事前分布
a = pm.Uniform('a', lower=-1000, upper=1000)
b = pm.Uniform('b', lower=-1000, upper=1000)
sigma = pm.Uniform('sigma', lower=0, upper=1000)
### 線形予測子
mu = pm.Deterministic('mu', a + b * X, dims='data')
### 尤度関数
obs = pm.Normal('obs', mu=mu, sigma=sigma, observed=Y, dims='data')
モデルの定義内容を見ます。
### モデルの表示
model1
【実行結果】
### モデルの可視化
pm.model_to_graphviz(model1)
【実行結果】
MCMCを実行します。
### 事後分布からのサンプリング 20秒
with model1:
idata1 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
nuts_sampler='numpyro', random_seed=1234)
【実行結果】省略
Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。
### r_hat>1.1の確認
# 設定
idata_in = idata1 # idata名
threshold = 1.01 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
収束条件を満たしています。
事後統計量を表示します。
### 推論データの要約統計情報の表示
var_names = ['a', 'b', 'sigma']
pm.summary(idata1, hdi_prob=0.95, var_names=var_names, round_to=3)
【実行結果】
トレースプロットを描画します。
### トレースプロットの表示
var_names = ['a', 'b', 'sigma', 'mu']
pm.plot_trace(idata1, compact=True, var_names=var_names)
plt.tight_layout();
【実行結果】
パラメータの事後統計量の要約を算出します。
要約統計量算出関数を定義します。
### mean,sd,2.5%,25%,50%,75%,97.5%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
probs = [2.5, 25, 50, 75, 97.5]
columns = ['mean', 'sd'] + [str(s) + '%' for s in probs]
quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
tmp_df = pd.concat([y.mean(axis=0), y.std(axis=0), quantiles], axis=1)
tmp_df.columns=columns
return tmp_df
要約統計量を算出します。
### 要約統計量の算出・表示
vars = ['a', 'b', 'sigma']
param_samples = idata1.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(1))
【実行結果】
$${Y}$$の観測値と予測値のプロットを描画します。
### Yの観測値と予測値をプロット
## 描画用データの作成
# 推論データからパラメータa, b, sigmaを取り出し
a_samples = idata1.posterior.a.stack(sample=('chain', 'draw')).data
b_samples = idata1.posterior.b.stack(sample=('chain', 'draw')).data
sigma_samples = idata1.posterior.sigma.stack(sample=('chain', 'draw')).data
# Yの予測値算出に用いるXの値の算出
xvals = np.linspace(data.X.min() - 1, data.X.max() + 1, 101)
# Yの事後予測データの算出
y_pred_samples = np.array([stats.norm.rvs(loc=a_sample + b_sample * xvals,
scale=sigma_sample)
for (a_sample, b_sample, sigma_sample)
in zip(a_samples, b_samples, sigma_samples)])
# Yの事後予測データの中央値、95%CI、50%CIの算出
y_pred_median = np.median(y_pred_samples, axis=0)
y_pred_95ci = np.quantile(y_pred_samples, q=[0.025, 0.975], axis=0)
y_pred_50ci = np.quantile(y_pred_samples, q=[0.250, 0.750], axis=0)
## 描画処理
# 描画領域の設定
plt.figure(figsize=(5, 4))
ax = plt.subplot()
# 観測値の散布図の描画
sns.scatterplot(ax=ax, data=data, x='X', y='Y', hue='KID', style='KID',
s=100, markers=markers, palette='tab10', alpha=0.8)
# 予測値の中央値の描画
ax.plot(xvals, y_pred_median, color='tomato')
# 予測値の95%CIと50%CIの描画
ax.fill_between(xvals, y_pred_95ci[0], y_pred_95ci[1], color='tomato', alpha=0.2)
ax.fill_between(xvals, y_pred_50ci[0], y_pred_50ci[1], color='tomato', alpha=0.5)
# 修飾
ax.set(xlabel='年齢 $X$ [-23歳]', ylabel='年収 $Y$ [万円]',
title='年齢 $X$ と年収 $Y$ の散布図')
ax.legend(bbox_to_anchor=(1, 1), title='会社ID')
ax.grid(lw=0.5)
plt.show()
【実行結果】
8.1.3 グループごとに切片と傾きを持つ場合 モデル式8-2
会社ごとの切片$${a}$$と傾き$${b}$$を設定するモデルです。
0始まりの会社インデックスを作成します。
### 会社インデックスの作成(0始まりにする)
k_idx = data.KID - 1
モデルの定義です。
### モデルの定義 ◆モデル式8-2 model8-2.stan
with pm.Model() as model2:
### データ関連定義
## coordの定義
model2.add_coord('data', values=data.index, mutable=True)
model2.add_coord('kaisha', values=sorted(data.KID.unique()), mutable=True)
## dataの定義
# 目的変数 Y
Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
# 説明変数 X
X = pm.ConstantData('X', value=data['X'].values, dims='data')
# 説明変数 KIdx 会社インデックス
KIdx = pm.ConstantData('KIdx', value=k_idx.values, dims='data')
### 事前分布
a = pm.Uniform('a', lower=-10000, upper=10000, dims='kaisha')
b = pm.Uniform('b', lower=-10000, upper=10000, dims='kaisha')
sigma = pm.Uniform('sigma', lower=0, upper=10000)
### 線形予測子
mu = pm.Deterministic('mu', a[KIdx] + b[KIdx] * X, dims='data')
### 尤度関数
obs = pm.Normal('obs', mu=mu, sigma=sigma, observed=Y, dims='data')
モデルの定義内容を見ます。
### モデルの表示
model2
【実行結果】
### モデルの可視化
pm.model_to_graphviz(model2)
【実行結果】
MCMCを実行します。
### 事後分布からのサンプリング 10秒
with model2:
idata2 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
nuts_sampler='numpyro', random_seed=1234)
【実行結果】省略
Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。
### r_hat>1.1の確認
# 設定
idata_in = idata2 # idata名
threshold = 1.01 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
収束条件を満たしています。
事後統計量を表示します。
### 推論データの要約統計情報の表示
var_names = ['a', 'b', 'sigma']
pm.summary(idata2, hdi_prob=0.95, var_names=var_names, round_to=3)
【実行結果】
トレースプロットを描画します。
### トレースプロットの表示
var_names = ['a', 'b', 'sigma', 'mu']
pm.plot_trace(idata2, compact=True, var_names=var_names)
plt.tight_layout();
【実行結果】
会社4(赤色)のバラツキが大きいです。
パラメータの事後統計量の要約を算出します。
### パラメータの要約を確認
# パラメータのMCMCサンプリングデータをデータフレーム化
a_samples_df2 = pd.DataFrame(
idata2.posterior.a.stack(sample=('chain', 'draw')).data.T,
columns=[f'a[{i+1}]'for i in range(data.KID.nunique())])
b_samples_df2 = pd.DataFrame(
idata2.posterior.b.stack(sample=('chain', 'draw')).data.T,
columns=[f'b[{i+1}]' for i in range(data.KID.nunique())])
sigma_samples_df2 = pd.DataFrame(
idata2.posterior.sigma.stack(sample=('chain', 'draw')).data.T,
columns=['sigma'])
param_samples = pd.concat([a_samples_df2, b_samples_df2, sigma_samples_df2],
axis=1)
# 要約統計量の算出・表示
display(make_stats_df(param_samples).round(1))
【実行結果】
8.1.4 階層モデル モデル式8-3
切片$${a}$$と傾き$${b}$$を「すべての会社に共通する全体平均」と「会社ごとの差分」に分けて設定するモデルです。
後者の会社ごとの差分パラメータは平均0、標準偏差$${\sigma_a}$$等の正規分布に従うというゆるい制約を入れます。
テキスト冒頭のシミュレーションを行います。
### シミュレーション sim-model8-3.R
## 初期値設定
seed = 0 # 乱数シード
N = 40 # 標本サイズ
K = 4 # 会社数
N_k = [15, 12, 10, 3] # 会社IDごとの社員数
a0 = 350 # 切片aの全体平均
b0 = 12 # 傾きbの全体平均
s_a = 60 # 切片aの会社差が従う正規分布の標準偏差
s_b = 4 # 傾きbの会社差が従う正規分布の標準偏差
s_Y = 25 # 年収Yの従う正規分布の標準偏差
## シミュレーションデータの生成
# 乱数生成器の設定
rng = np.random.default_rng(seed=seed)
# 説明変数X:年齢
X = rng.choice(np.arange(0, 36), size=N, replace=True)
# 説明変数KID:会社ID
KID = np.repeat(np.arange(1, 5), N_k)
# 切片a 会社差(正規分布乱数)+全体平均
a = rng.normal(loc=0, scale=s_a, size=K) + a0
# 傾きb 会社差(正規分布乱数)+全体平均
b = rng.normal(loc=0, scale=s_b, size=K) + b0
# 目的変数Y:年収 ※モデル式8-3の式8.1(Y[n])と同じ
Y = rng.normal(loc=a[KID-1] + b[KID-1] * X, scale=s_Y, size=N)
## データのまとめ
# データフレーム化
d = pd.DataFrame({'X': X, 'KID': KID, 'a': a[KID-1], 'b': b[KID-1], 'Y': Y})
# データフレームの表示
print('d.shape: ', d.shape)
display(d.head())
display(d.tail())
【実行結果】
生成したデータの先頭5つ、末尾5つです。
会社別散布図を描画します。
テキスト図8.2に相当します。
### 会社別散布図の描画 ◆図8.2
# 描画領域の指定
fig, axes = plt.subplots(2, 2, figsize=(6, 6), sharex=True, sharey=True)
# 会社IDごとに繰り返し描画処理(処理的にはaxesごとに繰り返し処理)
for i, ax in enumerate(axes.ravel()):
## 描画用データの作成
# 会社を1つ取り出す
tmp_df = d[d['KID'] == i + 1]
# 当該会社のデータで回帰直線の傾きと切片を取得
slope2, intercept2, _, _ ,_ = stats.linregress(x=tmp_df.X, y=tmp_df.Y)
# 当該会社の回帰直線描画用のxとyの算出
x_lm2 = np.linspace(tmp_df.X.min(), tmp_df.X.max(), 2)
y_lm2 = intercept2 + slope2 * x_lm2
## 描画処理
# 当該会社の回帰直線の描画
ax.plot(x_lm2, y_lm2, color='red', lw=2, ls='--')
# 散布図の描画
sns.scatterplot(ax=ax, data=tmp_df, x='X', y='Y', style='KID',
s=100, markers=markers, color=plt.cm.tab10(i/10), alpha=0.8,
legend=None)
# 修飾
ax.set(xlabel=None, ylabel=None, title=f'会社ID: {i+1}')
ax.grid(lw=0.3)
# 図8.1左のグラフから取得した凡例を表示
fig.legend(handles=handles, labels=labels, bbox_to_anchor=(1.1, 0.9),
title='会社ID')
# 全体修飾
fig.supxlabel('年齢 $X$ [-23歳]')
fig.supylabel('年収 $Y$ [万円]')
fig.suptitle('年齢 $X$ と年収 $Y$ の散布図: 会社別')
plt.tight_layout();
【実行結果】
乱数生成結果がテキストと異なるので、このグラフもテキストと相違します。
モデルの定義です。
$${a0,\ b0}$$が「すべての会社に共通する全体平均」、$${ak,\ bk}$$が「会社ごとの差分」です。
### モデルの定義 ◆モデル式8-3 model8-3.stan
with pm.Model() as model3:
### データ関連定義
## coordの定義
model3.add_coord('data', values=data.index, mutable=True)
model3.add_coord('kaisha', values=sorted(data.KID.unique()), mutable=True)
## dataの定義
# 目的変数 Y
Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
# 説明変数 X
X = pm.ConstantData('X', value=data['X'].values, dims='data')
# 説明変数 KIdx 会社インデックス
KIdx = pm.ConstantData('KIdx', value=k_idx.values, dims='data')
### 事前分布
a0 = pm.Uniform('a0', lower=-10000, upper=10000)
b0 = pm.Uniform('b0', lower=-10000, upper=10000)
sigmaA = pm.Uniform('sigmaA', lower=0, upper=1000)
sigmaB = pm.Uniform('sigmaB', lower=0, upper=1000)
sigmaY = pm.Uniform('sigmaY', lower=0, upper=10000)
ak = pm.Normal('ak', mu=0, sigma=sigmaA, dims='kaisha')
bk = pm.Normal('bk', mu=0, sigma=sigmaB, dims='kaisha')
### 線形予測子
a = pm.Deterministic('a', a0 + ak, dims='kaisha')
b = pm.Deterministic('b', b0 + bk, dims='kaisha')
mu = pm.Deterministic('mu', a[KIdx] + b[KIdx] * X, dims='data')
### 尤度関数
obs = pm.Normal('obs', mu=mu, sigma=sigmaY, observed=Y, dims='data')
モデルの定義内容を見ます。
### モデルの表示
model3
【実行結果】
### モデルの可視化
pm.model_to_graphviz(model3)
【実行結果】
会社別パラメータ$${ak,\ bk}$$の従う正規分布と、その正規分布の標準偏差パラメータ$${\sigma_a,\ \sigma_b}$$が従う一様分布の2つの事前分布が、階層的に設定されています。
MCMCを実行します。
### 事後分布からのサンプリング 15秒
with model3:
idata3 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
nuts_sampler='numpyro', random_seed=1234)
【実行結果】省略
Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。
### r_hat>1.1の確認
# 設定
idata_in = idata3 # idata名
threshold = 1.01 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
収束条件を満たしています。
事後統計量を表示します。
### 推論データの要約統計情報の表示
var_names = ['a0', 'b0', 'sigmaA', 'sigmaB', 'sigmaY', 'ak', 'bk']
pm.summary(idata3, hdi_prob=0.95, var_names=var_names, round_to=3)
【実行結果】
トレースプロットを描画します。
### トレースプロットの表示
pm.plot_trace(idata3, compact=True, var_names=var_names)
plt.tight_layout();
【実行結果】
パラメータの事後統計量の要約を算出します。
2つに分けて実行します。
### 要約統計量の算出・表示
vars = ['a0', 'b0', 'sigmaA', 'sigmaB', 'sigmaY']
param_samples = idata3.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(1))
【実行結果】
### 要約統計量の算出・表示
# パラメータのMCMCサンプリングデータをデータフレーム化
ak_samples_df3 = pd.DataFrame(
idata3.posterior.ak.stack(sample=('chain', 'draw')).data.T,
columns=[f'ak[{i+1}]'for i in range(data.KID.nunique())])
bk_samples_df3 = pd.DataFrame(
idata3.posterior.bk.stack(sample=('chain', 'draw')).data.T,
columns=[f'bk[{i+1}]' for i in range(data.KID.nunique())])
param_samples = pd.concat([ak_samples_df3, bk_samples_df3], axis=1)
# 要約統計量の算出・表示
display(make_stats_df(param_samples).round(1))
【実行結果】
8.1.5 モデルの比較
モデル式8-2と8-3で推定したパラメータ$${a[k]}$$の差異をプロットします。
テキスト図8.4左に相当します。
### モデル式8-2と8-3で推定したa[k]の差異 ◆図8.4左
## 中央値、95%CIと中央値の差を算出する関数の定義
def calc_stats(x):
med = np.median(x, axis=1)
ci95 = np.quantile(x, q=[0.025, 0.975], axis=1)
return med, med - ci95[0], ci95[1] - med
## 描画用データの作成
# model8-1のαの中央値
med_8_1 = idata1.posterior.a.median().data
# model8-2のαの中央値, 95%CI
med_8_2, low_8_2, high_8_2 = calc_stats(idata2.posterior.a
.stack(sample=('chain', 'draw')).data)
# model8-3のαの中央値, 95%CI
med_8_3, low_8_3, high_8_3 = calc_stats(idata3.posterior.a
.stack(sample=('chain', 'draw')).data)
## 描画処理
# 描画用パラメータの設定
width = 0.1
xticks = np.arange(1, 5)
# 描画領域の設定
plt.figure(figsize=(5, 4))
ax = plt.subplot()
# model8-1のαの中央値の水平線の描画
ax.axhline(med_8_1, color='gray', lw=2, label='モデル8-1')
# model8-2のαのエラーバーの描画
ax.errorbar(xticks - width, med_8_2, yerr=[low_8_2, high_8_2], fmt='o',
ms=12, color='tab:red', elinewidth=2, label='モデル8-2')
# model8-3のαのエラーバーの描画
ax.errorbar(xticks + width, med_8_3, yerr=[low_8_3, high_8_3], fmt='o',
ms=12, color='tab:blue', elinewidth=2, label='モデル8-3')
# 修飾
ax.set(xlabel='会社ID', ylabel=r'パラメータ $\alpha$', title=r'$\alpha$ の差異',
xticks=xticks)
ax.grid(lw=0.5)
ax.legend(bbox_to_anchor=(1.33, 1), title='Model');
plt.show()
【実行結果】
続いて会社別散布図を描画します。
テキスト図8.4右に相当します。
### 会社別散布図の描画 ◆図8.4右
## model8-1の描画用データの作成
a_1 = idata1.posterior.a.data.flatten()
b_1 = idata1.posterior.b.data.flatten()
x_1 = np.linspace(data.X.min() - 1, data.X.max() + 1, 101)
mu_1_med = np.median(np.array([a + b * x_1 for (a, b) in zip(a_1, b_1)]), axis=0)
## model8-2の描画用データの作成
a_2 = idata2.posterior.a.stack(sample=('chain', 'draw')).data
b_2 = idata2.posterior.b.stack(sample=('chain', 'draw')).data
## model8-3の描画用データの作成
a_3 = idata3.posterior.a.stack(sample=('chain', 'draw')).data
b_3 = idata3.posterior.b.stack(sample=('chain', 'draw')).data
# 描画領域の指定
fig, axes = plt.subplots(2, 2, figsize=(6, 6), sharex=True, sharey=True)
# 会社IDごとに繰り返し描画処理(処理的にはaxesごとに繰り返し処理)
for i, ax in enumerate(axes.ravel()):
## 描画用データの作成
# 会社を1つ取り出す
tmp_df = data[data['KID'] == i + 1]
# 当該会社の回帰直線描画用のxとyの算出
x_23 = np.linspace(tmp_df.X.min(), tmp_df.X.max(), 101)
mu_2_med = np.median(np.array([a + b * x_23 for (a, b)
in zip(a_2[i], b_2[i])]), axis=0)
mu_3_med = np.median(np.array([a + b * x_23 for (a, b)
in zip(a_3[i], b_3[i])]), axis=0)
## 描画処理
# 散布図の描画
sns.scatterplot(ax=ax, data=tmp_df, x='X', y='Y', style='KID',
s=100, markers=markers, color=plt.cm.tab10(i/10), alpha=0.8,
legend=None)
# model8-1の中央値の描画
ax.plot(x_1, mu_1_med, color='gray', lw=2)
# model8-2の当該会社の中央値の描画
ax.plot(x_23, mu_2_med, color='tab:red', lw=1.5)
# 修飾
# model8-3の当該会社の中央値の描画
ax.plot(x_23, mu_3_med, color='tab:green', lw=3, ls=':')
ax.set(xlabel=None, ylabel=None, title=f'会社ID: {i+1}')
ax.grid(lw=0.3)
# 全体修飾
fig.supxlabel('年齢 $X$ [-23歳]')
fig.supylabel('年収 $Y$ [万円]')
fig.suptitle('年齢 $X$ と年収 $Y$ の散布図: 会社別')
# Modelの凡例の作成・表示
plt.plot([None], [None], color='gray', lw=2, label='8-1')
plt.plot([None], [None], color='tab:red', lw=1.5, label='8-2')
plt.plot([None], [None], color='tab:green', lw=3, ls=':', label='8-3')
fig.legend(bbox_to_anchor=(1.01, 0.72), title='Model')
# 図8.1左のグラフから取得した会社IDの凡例を表示
plt.legend(handles=handles, labels=labels, bbox_to_anchor=(1, 1.2),
title='会社ID')
plt.tight_layout();
【実行結果】
会社ID=4のケースでは、傾き(緑の点線)がデータ全体から推定した傾き(グレイの線)に近くなっています。
会社4データへの過学習を抑えているようです。
8.1.6 階層モデルの等価な表現 モデル式8-4
モデル式8-3と等価な別モデル・モデル式8-4を試します。
モデルの定義です。
$${a,\ b}$$の箇所に変更があります。
### モデルの定義 ◆モデル式8-4 model8-4.stan
with pm.Model() as model4:
### データ関連定義
## coordの定義
model4.add_coord('data', values=data.index, mutable=True)
model4.add_coord('kaisha', values=sorted(data.KID.unique()), mutable=True)
## dataの定義
# 目的変数 Y
Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
# 説明変数 X
X = pm.ConstantData('X', value=data['X'].values, dims='data')
# 説明変数 KIdx 会社インデックス
KIdx = pm.ConstantData('KIdx', value=k_idx.values, dims='data')
### 事前分布
a0 = pm.Uniform('a0', lower=-10000, upper=10000)
b0 = pm.Uniform('b0', lower=-10000, upper=10000)
sigmaA = pm.Uniform('sigmaA', lower=0, upper=1000)
sigmaB = pm.Uniform('sigmaB', lower=0, upper=1000)
sigmaY = pm.Uniform('sigmaY', lower=0, upper=10000)
### 線形予測子
a = pm.Normal('a', mu=a0, sigma=sigmaA, dims='kaisha')
b = pm.Normal('b', mu=b0, sigma=sigmaB, dims='kaisha')
mu = pm.Deterministic('mu', a[KIdx] + b[KIdx] * X, dims='data')
### 尤度関数
obs = pm.Normal('obs', mu=mu, sigma=sigmaY, observed=Y, dims='data')
モデルの定義内容を見ます。
### モデルの表示
model4
【実行結果】
### モデルの可視化
pm.model_to_graphviz(model4)
【実行結果】
MCMCを実行します。
### 事後分布からのサンプリング 15秒
with model4:
idata4 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.988,
nuts_sampler='numpyro', random_seed=1234)
【実行結果】省略
Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。
### r_hat>1.1の確認
# 設定
idata_in = idata4 # idata名
threshold = 1.01 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
収束条件を満たしています。
事後統計量を表示します。
### 推論データの要約統計情報の表示
var_names = ['a0', 'b0', 'sigmaA', 'sigmaB', 'sigmaY']
pm.summary(idata4, hdi_prob=0.95, var_names=var_names, round_to=3)
【実行結果】
トレースプロットを描画します。
### トレースプロットの表示
var_names = ['a0', 'b0', 'sigmaA', 'sigmaB', 'sigmaY', 'a', 'b', 'mu']
pm.plot_trace(idata4, compact=True, var_names=var_names)
plt.tight_layout();
【実行結果】
発散(バーコード)が見られます。
パラメータの事後統計量の要約を算出します。
### 要約統計量の算出・表示
vars = ['a0', 'b0', 'sigmaA', 'sigmaB', 'sigmaY']
param_samples = idata4.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(1))
【実行結果】
(参考:モデル式8-3の要約統計量)
8.1 節は以上です。
シリーズの記事
次の記事
前の記事
目次
ブログの紹介
note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!
1.のんびり統計
統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。
2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で
書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!
3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で
書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!
4.楽しい写経 ベイズ・Python等
ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀
5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で
書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。
6.データサイエンスっぽいことを綴る
統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。
7.Python機械学習プログラミング実践記
書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。
最後までお読みいただきまして、ありがとうございました。
この記事が参加している募集
この記事が気に入ったらサポートをしてみませんか?