見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第7章「7.5 交絡」

第7章「回帰分析の悩みどころ」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第7章「回帰分析の悩みどころ」の7.5節「交絡」の PyMC5写経 を取り扱います。
テキストを引用しますと、交絡とは「モデルの外側に応答変数と説明変数の両方に影響を与える変数が存在すること」です。
交絡変数を取り入れるモデルに取り組みます。

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


7.5 交絡


インポート

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np
import scipy.stats as stats

# PyMC
import pymc as pm
import arviz as az

# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

データの読み込み

サンプルコードのデータを読み込みます。

### データの読み込み ◆data-50m.txt
# Y:50m走の平均秒速(m/秒), Weight:体重(kg), Age:年齢

data = pd.read_csv('./data/data-50m.txt')
print('data.shape: ', data.shape)
display(data.head())

【実行結果】

データの要約統計量と相関係数を算出します。

### 要約統計量の表示
data.describe().round(2)

【実行結果】

### 相関係数の表示
data.corr().round(3)

【実行結果】
体重$${Weight}$$と年齢$${Age}$$の相関係数は$${0.892}$$で、強い正の相関関係があります。

散布図を描画します。
テキスト図7.8に相当します。

### 散布図の描画 ◆図7.8

# 描画領域の設定
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
# 左の散布図の描画
sns.scatterplot(ax=ax1, data=data, x='Weight', y='Y', s=100, alpha=0.5)
ax1.set_title('体重と平均秒速の散布図')
ax1.grid(lw=0.5)
# 右の散布図の描画
sns.scatterplot(ax=ax2, data=data, x='Weight', y='Y', hue='Age',
                palette='tab10', s=100, alpha=0.5)
ax2.set_title('体重と平均秒速の散布図:年齢別')
ax2.grid(lw=0.5)
# 全体修飾
ax2.legend(bbox_to_anchor=(1, 1), title='$Age$')
plt.tight_layout()
plt.show()

【実行結果】
年齢と平均秒速との間には正の相関関係が見られます。
右のグラフより、体重と年齢、年齢と速さの関係が明らかになります。

テキスト図7.9をお借りして、3つの変数の関係を表現しました。

モデルの構築

Y の予測に用いる X の値を設定します。

### Yの予測分布に用いるXの値の設定
X_new = np.linspace(-3, 32, 60)

モデルの定義です。

### モデルの定義 ◆モデル式7-5

with pm.Model() as model:
    
    ### データ関連定義
    ## coordの定義
    model.add_coord('data', values=data.index, mutable=True)
    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
    # 説明変数 Weight
    Weight = pm.ConstantData('Weight', value=data['Weight'].values, dims='data')
    # 説明変数 Age
    Age = pm.ConstantData('Age', value=data['Age'].values, dims='data')

    ### 事前分布
    b1 = pm.Uniform('b1', lower=-100, upper=100)
    b2 = pm.Uniform('b2', lower=-100, upper=100)
    b3 = pm.Uniform('b3', lower=-100, upper=100)
    c1 = pm.Uniform('c1', lower=-100, upper=100)
    c2 = pm.Uniform('c2', lower=-100, upper=100)
    sigmaW = pm.Uniform('sigmaW', lower=0, upper=100)
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=100)

    ### Weight
    muW = pm.Deterministic('muW', c1 + c2 * Age, dims='data')
    obsW = pm.Normal('obsW', mu=muW, sigma=sigmaW, observed=Weight, dims='data')

    ### 尤度関数
    muY = pm.Deterministic('muY', b1 + b2 * Age + b3 * Weight, dims='data')
    obsY = pm.Normal('obsY', mu=muY, sigma=sigmaY, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model)

【実行結果】

MCMCを実行します。

### 事後分布からのサンプリング 20秒
with model:
    idata  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                       nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata         # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['b1', 'b2', 'b3', 'c1', 'c2', 'sigmaW', 'sigmaY']
pm.summary(idata, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

### 推論データの要約統計情報の表示
var_names = ['muW', 'muY']
pm.summary(idata, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
var_names = ['b1', 'b2', 'b3', 'c1', 'c2', 'sigmaW', 'sigmaY', 'muW', 'muY']
pm.plot_trace(idata, compact=True, var_names=var_names)
plt.tight_layout();

【実行結果】

パラメータの事後統計量の要約を算出します。

### パラメータの要約を確認

# mean,sd,2.5%,25%,50%,75%,97.5%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
    probs = [2.5, 25, 50, 75, 97.5]
    columns = ['mean', 'sd'] + [str(s) + '%' for s in probs]
    quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
    tmp_df = pd.concat([y.mean(axis=0), y.std(axis=0), quantiles], axis=1)
    tmp_df.columns=columns
    return tmp_df

# 要約統計量の算出・表示
vars = ['b1', 'b2', 'b3', 'c1', 'c2', 'sigmaW', 'sigmaY']
param_samples = idata.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(2))

【実行結果】

事後分布プロットを描画します。

### 事後分布プロットの描画
var_names = ['b1', 'b2', 'b3']
pm.plot_posterior(idata, hdi_prob=0.95, var_names=var_names, round_to=3,
                  figsize=(10, 3))
plt.tight_layout();

【実行結果】

### 事後分布プロットの描画
var_names = ['c1', 'c2']
pm.plot_posterior(idata, hdi_prob=0.95, var_names=var_names, round_to=3,
                  figsize=(6, 2.8))
plt.tight_layout();

【実行結果】

### 事後分布プロットの描画
var_names = ['sigmaW', 'sigmaY']
pm.plot_posterior(idata, hdi_prob=0.95, var_names=var_names, round_to=3,
                  figsize=(6, 2.8))
plt.tight_layout();

【実行結果】

フォレストプロットを描画します。

### フォレストプロットの描画
var_names = ['b1', 'b2', 'b3', 'c1', 'c2', 'sigmaW', 'sigmaY']
pm.plot_forest(idata, combined=True, hdi_prob=0.95, var_names=var_names,
               figsize=(5, 3))
plt.axvline(0, color='tab:red', ls='--')
plt.grid(lw=0.3);

【実行結果】

事後予測サンプリングを実行して$${Y}$$の予測値を描画します。

### 事後予測サンプリングデータの作成
with model:
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=1234))

【実行結果】

### 事後予測プロットの描画
pm.plot_ppc(idata, num_pp_samples=100);

【実行結果】

$${Y}$$の観測値と予測値のプロットを描画します。

### Yの観測値と予測値のプロット

## 描画用データの作成 yPredの中央値と80%区間を算出
# MCMCサンプリングデータからyPredを取り出し
y_pred_samples = (idata.posterior_predictive.obsY
                  .stack(sample=('chain', 'draw')).data)
# サンプリングデータの10%,50%,90%パーセンタイル点を算出してデータフレーム化
y_pred_df = pd.DataFrame(
    np.quantile(y_pred_samples, q=[0.1, 0.5, 0.9], axis=1).T,
    columns=['10%', 'median', '90%'])
y_pred_df = pd.concat([data, y_pred_df], axis=1)
# 中央値と10%点の差、90%点と中央値の差を算出: errorbarで利用
y_pred_df['err_lower'] = y_pred_df['median'] - y_pred_df['10%'] 
y_pred_df['err_upper'] = y_pred_df['90%'] - y_pred_df['median']

## 描画処理
# 描画領域の指定
plt.figure(figsize=(6, 6))
ax = plt.subplot()
# 描画(エラーバー付き散布図)
ax.errorbar(y_pred_df['Y'], y_pred_df['median'],
            yerr=[y_pred_df['err_lower'], y_pred_df['err_upper']],
            color='tab:blue', alpha=0.5, marker='o', ms=10, linestyle='none')
# 赤い対角線の描画
ax.plot([2.3, 5.7], [2.3, 5.7], color='red', ls='--')
# 修飾
ax.set(xlabel='Observed: $Y$の観測値', ylabel='Predicted: $Y$の予測値(中央値)',
       title='$Y$ の観測値と予測値(中央値)のプロット\n80%区間, 通常スケール')
ax.grid(lw=0.5);

【実行結果】

MCMCサンプルの散布図行列を描画します。

### MCMCサンプルの散布図行列の描画

## 描画用データの作成
# MCMCサンプリングデータからmu1, mu66を取り出し
mu1_samples = (idata.posterior['muY'].to_dataframe().reset_index()
              .query('data==0').rename({'muY': 'mu1'}, axis=1))
mu66_samples = (idata.posterior['muY'].to_dataframe().reset_index()
               .query('data==65').rename({'muY': 'mu66'}, axis=1))
# 描画対象パラメータをデータフレーム化
plot_df = pd.concat([param_samples,
                     mu1_samples.reset_index(drop=True)['mu1'],
                     mu66_samples.reset_index(drop=True)['mu66']], axis=1)

## 描画処理
# 相関行列プロットの描画
g = sns.pairplot(plot_df, diag_kws={'kde': True, 'ec': 'white'})
# スピアマンの順位相関係数の表示のためのaxフラット化
ax = g.axes.ravel()

## スピアマンの順位相関係数を上三角のaxesに表示
# 列名をリスト化
cols = plot_df.columns
# 列名の組み合わせ行i, 列j ごとにテキスト表示を繰り返す
for i, col1 in enumerate(cols):
    for j, col2 in enumerate(cols):
        # 上三角の位置は 行i < 列j のとき
        if i < j:
            # axesの番号を取得
            pos = i * len(cols) + j
            # スピアマンの順位相関係数を算出
            corr, pval = stats.spearmanr(plot_df[col1], plot_df[col2])
            # テキスト表示:中央表示に関連する引数: x,y,va,ha,transform
            ax[pos].text(x=0.5, y=0.5, s=round(corr * 100), fontsize=30,
                         va='center', ha='center', transform=ax[pos].transAxes,
                         bbox=dict(boxstyle='round', facecolor='white'))

【実行結果】

7.5 節は以上です。


シリーズの記事

次の記事

前の記事

目次

ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

いいなと思ったら応援しよう!

この記事が参加している募集