見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第12章「12.8 地図を使った空間構造」

第12章「時間や空間を扱うモデル」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第12章「時間や空間を扱うモデル」の 12.8節「地図を使った空間構造」の PyMC5写経 を取り扱います。
1次元の空間構造で対応可能であり、PyMC の ICAR() を用いました。
たぶん、テキストのモデルと異なっています。

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


12.8 地図を使った空間構造


モデリングの準備

インポート
日本地図の描画には japanmap パッケージを利用します。

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np
import scipy.stats as stats

# PyMC
import pymc as pm
import arviz as az

# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'

# 日本地図
# !pip install japanmap
from japanmap import picture

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

データの読み込みと確認
サンプルコードのデータを読み込みます。
2013年の都道府県別年平均気温データです。
都道府県コードと気温のデータです。

### データの読み込み ◆データファイル12.3 data-map-temperature.txt
# prefID:都道府県コード(JIS), Y:2013年の年平均気温[℃]

data1 = pd.read_csv('./data/data-map-temperature.txt', index_col=0)
print('data1.shape: ', data1.shape)
display(data1.head())

【実行結果】

こちらは隣り合う都道府県のコードペアデータです。

### データの読み込み ◆データファイル12.4 data-map-neighbor.txt
# 隣り合う都道府県について、都道府県コードの組み合わせ(From < To)

data2 = pd.read_csv('./data/data-map-neighbor.txt')
print('data2.shape: ', data2.shape)
display(data2.head())

【実行結果】

要約統計量を算出します。

### 要約統計量の表示
data1.describe().round(1)

【実行結果】
最低 9.2 ℃、最高 23.3 ℃です。

年平均気温データを日本地図に描画します。
テキスト図10.10左に相当します。

### 年平均気温の日本地図への描画 ◆図12.10左

# 色分けするカラーマップの設定
cmap = plt.get_cmap('Reds')
# 正規化のための気温の最大値・最小値を設定
norm = plt.Normalize(vmin=data1['Y'].min(), vmax=data1['Y'].max())
# 気温の値を色の値に変換するlambda関数の設定
fcol = lambda x: '#' + bytes(cmap(norm(x), bytes=True)[:3]).hex()
# カラーバーを描画
plt.colorbar(plt.cm.ScalarMappable(norm, cmap), ax=plt.gca(), label='℃')
# 日本地図に年平均気温を描画
plt.imshow(picture(data1['Y'].apply(fcol)))
plt.xticks([])
plt.yticks([]);

【実行結果】
長野県など、お隣と大きな差がある都道府県があって、赤色がまだら模様になっています。

PyMCのモデル定義

PyMCでモデル式12-14を実装します。

隣接行列の作成
ICARで利用する隣接行列を作成します。
data2 の隣接都道府県データを元にして、隣り合う都道府県コードのインデックスに1を立てます。

### 隣接行列Wの作成
W = np.zeros((len(data1), len(data1)))
for (fm, to) in data2.values:
    W[fm-1, to-1] = 1
    W[to-1, fm-1] = 1
print('W.shape: ', W.shape)
print(W)

【実行結果】

PyMCのモデル定義 
切片$${\beta}$$+ICARを用いたランダム効果のモデルです。

### モデルの定義 ◆モデル式12-14 model12-14.stan ※ICAR利用

with pm.Model() as model:

    ### データ関連定義
    # coordの定義
    model.add_coord('data', values=data1.index, mutable=True)
    # dataの定義
    Y = pm.ConstantData('Y', value=data1.Y.values, dims='data')

    ### 事前分布
    # 切片
    beta = pm.StudentT('beta', nu=4, mu=data1['Y'].mean(), sigma=100)
    # 標準偏差
    sigmaR = pm.Uniform('sigmaR', lower=0, upper=1)
    sigmaY = pm.Normal('sigmaY', mu=0, sigma=0.1)
    # 切片β+空間構造を取り入れたランダム効果Sのモデル
    # 非中心配置のパラメータ化:ICARの標準偏差を1にして、結果のSにsigmaRを乗ずる
    S = pm.ICAR('S', W=W, sigma=1, dims='data')
    r = pm.Deterministic('r', beta + S*sigmaR, dims='data')

    ### 尤度
    obs = pm.Normal('obs', mu=r, sigma=sigmaY, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model)

【実行結果】

MCMCの実行と収束確認

MCMCを実行します。

### 事後分布からのサンプリング 2分10秒
with model:
    idata = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                      nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata          # idata名
threshold = 1.01          # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['sigmaR', 'sigmaY', 'beta', 'r']
pm.summary(idata, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata, compact=True, var_names=var_names)
plt.tight_layout();

【実行結果】
sigmaRの上限1には無理があったかもしれません。
(ただし上限1にしないと収束できませんでした)

推論結果の解釈

推論データを用いてテキスト図12.10右に相当する日本地図を描画します。

### 年平均気温の日本地図への描画 ◆図12.10右

## 描画用データの作成
# 推論データからrのMCMCサンプルデータを取り出し
r_samples = idata.posterior.r.stack(sample=('chain', 'draw')).data
# rの中央値の算出
r_median = np.median(r_samples, axis=1)
# rの中央値をpandasのシリーズ化(都道府県コードをインデックスに設定)
r_series = pd.Series(r_median, name='r', index=range(1, len(r_median)+1))

## 描画処理
# 色分けするカラーマップの設定
cmap = plt.get_cmap('Reds')
# 正規化のための気温の最大値・最小値を設定
norm = plt.Normalize(vmin=data1['Y'].min(), vmax=data1['Y'].max())
# 気温の値を色の値に変換するlambda関数の設定
fcol = lambda x: '#' + bytes(cmap(norm(x), bytes=True)[:3]).hex()
# カラーバーを描画
plt.colorbar(plt.cm.ScalarMappable(norm, cmap), ax=plt.gca(), label='推定値 ℃')
# 日本地図に年平均気温を描画
plt.imshow(picture(r_series.apply(fcol)))
plt.xticks([])
plt.yticks([]);

【実行結果】
気温データそのものを日本地図に描画したグラフと比べると、こちらの方が、色のグラデーションが自然であり、隣接する都道府県の気温が滑らかに表現されています。

テキスト253~254ページの「その他の影響($${Y[n]-r[n]}$$)」の中央値を算出します。

### その他の影響y-rの中央値の算出
# Y-rの算出
diff_samples = np.array([data1['Y'].values[i] - r_samples[i]
                         for i in range(len(r_samples))])
# Y-rの中央値の算出
diff_median = np.median(diff_samples, axis=1)
# 中央値の絶対値の大きい順にインデックスを表示
abs(diff_median).argsort()[::-1]

【実行結果】
絶対値の大きい順に都道府県コード(1少ない値です)を並べました。

上位3県を表示します。

### 上位3県のY-rの値の表示 ◆テキスト253ページ ※テキストと大きく異なる
print(f'20 長野県: {diff_median[19]:5.2f}') # 20: 長野県
print(f'22 静岡県: {diff_median[21]:5.2f}') # 22: 静岡県
print(f'21 岐阜県: {diff_median[20]:5.2f}') # 21: 岐阜県

【実行結果】
テキストは 長野:-2.49℃、静岡:+1.34℃、東京:+1.20℃ですので、このPyMCモデルは第3位の都道府県が異なることと、気温の値が異なる(絶対値で小さい)ことが相違点です。

観測値と予測値のプロットを描画します。
テキスト図12.11左に相当します。

### 観測値と予測値のプロット ◆図12.11左

# サンプリングデータの2.5%,50%,97.5%パーセンタイル点を算出してデータフレーム化
y_pred_df = pd.DataFrame(
    np.quantile(r_samples, q=[0.025, 0.5, 0.975], axis=1).T,
    columns=['2.5%', 'median', '97.5%'],
    index=range(1, len(r_samples)+1))
y_pred_df = pd.concat([data1, y_pred_df], axis=1)
# 中央値と.25%点の差、97.5%点と中央値の差を算出: errorbarで利用
y_pred_df['err_lower'] = y_pred_df['median'] - y_pred_df['2.5%'] 
y_pred_df['err_upper'] = y_pred_df['97.5%']  - y_pred_df['median']

## 描画処理
# 描画領域の指定
plt.figure(figsize=(6, 6))
ax = plt.subplot()
# Yの観測値とrの予測値(中央値)のエラーバー付き散布図の描画
ax.errorbar(y_pred_df['Y'], y_pred_df['median'],
            yerr=[y_pred_df['err_lower'], y_pred_df['err_upper']],
            color='tab:blue', alpha=0.5, marker='o', ms=8, linestyle='none')
# その他の影響Y-rの絶対値の大きい3県の表示
top_dict = {19: '長野', 21: '静岡', 20: '岐阜'}
for k, v in top_dict.items():
    ax.text(s=f'{k+1}:{v}', x=y_pred_df.iloc[k, 0], y=y_pred_df.iloc[k, 2])
# 赤い対角線の描画
ax.plot([9, 24], [9, 24], color='red', ls='--')
# 修飾
ax.set(xlabel='Observed: $Y$の観測値', ylabel='Predicted: $r$の予測値(中央値)',
       title='$Y$ の観測値と $r$の予測値(中央値)のプロット\n95%区間バー付き')
ax.grid(lw=0.5);

【実行結果】
テキストよりも赤い点線に近くなっている=予測値が観測値と近い状況になりました。

推定されたノイズの分布を描画します。
テキスト図12.11右に相当します。
まずMAP推定値(のような値)などの描画に必要なデータを作成します。
なお、MAP推定値(のような値)は seaborn の KDEプロットで算出された値に基づいています。

### 推定されたノイズY-rの分布のプロット ◆図12.11右

## 描画用データの作成1:各都道府県のY-rのMAP推定値の算出
# MAP推定値を格納する一時リストの準備
map_plot = []
# 47都道府県ごとにMAP推定値らしき値をKDEプロットにて算出する処理を繰り返す
for i, diff in enumerate(diff_samples):
    # KDEプロットを仮想的に描画
    kde_plot = sns.kdeplot(diff)
    # KDEプロットで描画した線から、x軸,y軸の値を取得
    kde_data = kde_plot.get_lines()[i].get_data()
    # KDEプロットのyの最大値のインデックスを取得
    kde_max_idx = kde_data[1].argmax()
    # KDEプロットのyの最大値に対応するxの値=MAP推定値を一時リストに格納
    map_plot.append(kde_data[0][kde_max_idx])
# KDEプロットを表示しないで終了する
plt.close()
# Y-rデータをnumpy配列化
diff_map = np.array(map_plot)

## 描画用データの作成2:sigmaYのMAP推定値の算出
# 推論データからsigmaYのMCMCサンプルデータを取り出し
sigmaY_samples = idata.posterior.sigmaY.data.flatten()
# KDEプロットを仮想的に描画
sigmaY_plot = sns.kdeplot(sigmaY_samples)
# KDEプロットで描画した線から、x軸,y軸の値を取得
sigmaY_kde_data = sigmaY_plot.get_lines()[0].get_data()
# KDEプロットのyの最大値に対応するxの値=MAP推定値を算出
sigmaY_map = sigmaY_kde_data[0][sigmaY_kde_data[1].argmax()]
# KDEプロットを表示しないで終了する
plt.close()

## 描画用データの作成3:平均0, 標準偏差sigmaYの正規分布の確率密度関数の算出
# 正規分布のx軸の値を設定
xvals = np.linspace(-2.5, 2, 1001)
# 正規分布の確率密度関数の算出
yvals = stats.norm.pdf(xvals, loc=0, scale=sigmaY_map)

続いて描画します。

### 描画処理
# 描画領域の設定
fig, ax = plt.subplots()  # axにy-rのMAP推定値のヒストグラムを描画
twinx1 = ax.twinx()       # twinx1にy-rのMAP推定値のKDEプロットを描画
twinx2 = ax.twinx()       # twinx2に正規分布の確率密度関数を描画
# Y-rのMAP推定値のヒストグラムの描画
sns.histplot(map_plot, bins=15, color='tab:blue', ec='white', alpha=0.7,
             label='Y[n]-r[n]のMAP推定値', ax=ax)
ax.grid(lw=0.5)
# Y-rのMAP推定値のKDEプロットの描画
sns.kdeplot(np.array(map_plot), color='tab:blue', fill=True, alpha=0.2,
            ax=twinx1, label='Y[n]-r[n]のMAP推定値のKDE')
twinx1.set(yticks=[], ylabel='')
# sigmaYのMAP推定値を標準偏差とする正規分布の確率密度関数の描画
twinx2.plot(xvals, yvals, color='tab:red', ls='--',
            label='正規分布(0, $\sigma_Y$)')
twinx2.set(yticks=[], ylabel='', ylim=(0, 0.96))
# 修飾
fig.suptitle('Y[n]-r[n]のMAP推定値のヒストグラム・KDE,\n'
             '$\sigma_Y$ のMAP推定値を用いた正規分布の確率密度関数')
fig.supxlabel('value')
fig.legend(bbox_to_anchor=(1.3, 0.9));

【実行結果】
横軸-2付近のデータは長野です。
テキストに掲載の通り、その他の影響 sigmaY は長野の外れ値をうまく扱える「Studentのt分布」を用いてもいいのかもしれません。

おまけ

KDEの計算について、次のサイトの調査を参考にして、scipy.stats の gaussian_kde を動かしてみました。
ありがとうございます!

sigmaYのKDEを計算して、最大値をとるsigmaYの値≒MAP推定値を計算してみます。

### 参考:sigmaYのMAP推定値の算出~scipyのgaussian_kde利用~
# https://vaaaaaanquish.hatenablog.com/entry/2017/10/29/181949

# 変数xの値の設定
test_x = np.linspace(sigmaY_samples.min(), sigmaY_samples.max(), 10001)
# sigmaYのKDEモデルの設定
test_model = stats.gaussian_kde(sigmaY_samples)
# xの値からKDEの値を算出
test_y = test_model(test_x)
# KDEが最大となる変数xの値≒MAP推定値を算出
test_x_map = test_x[test_y.argmax()]
print(f'sigmaYのMAP推定値: {test_x_map:.3f}')

【実行結果】

KDE曲線を描画しましょう。

### sigaYのKDE曲線の描画
# KDE曲線の描画
plt.plot(test_x, test_y)
# MAP推定値の垂直線の描画
plt.vlines(x=test_x_map, ymin=0, ymax=test_y[test_y.argmax()], color='red',
           ls='--')
# 修飾
plt.title(f'sigmaYのMAP推定値: {test_x_map:.3f}')
plt.grid(lw=0.5);

【実行結果】

12.8 節は以上です。


シリーズの記事

次の記事

前の記事

目次

ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

いいなと思ったら応援しよう!