見出し画像

実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で⑪最小二乗法のベイズ版

「Stan入門」章

テキスト「Stan入門」の執筆者
松浦健太郎 先生


この記事は、テキストの「Stan入門」章の例題1「最小二乗法のベイズ版」の実践を取り扱います。
線形回帰モデルです。記事第1回・5回に通じます!予測もあります!
そしてテキストの章タイトルはStan入門ですが、記事ではStanを用いません。
すべてPyMCに書き換えます!
たのしくPyMCモデリングを進めましょう!

この章を執筆された松浦先生は「StanとRでベイズ統計モデリング (Wonderful R 2)」を執筆されています。
こちらもお楽しみくださいませ。


はじめに


岩波データサイエンスVol.1の紹介

この記事は書籍「岩波データサイエンス vol.1」(岩波書店、以下「テキスト」と呼びます)の特集記事「ベイズ推論とMCMCのフリーソフト」のベイズモデルを用いて、PyMC Ver.5で「実験的」に実装する様子を描いた統計ドキュメンタリーです。

テキストは、2015年10月に発売され、ベイズモデリングの様々なソフトウェアを用いたモデリング事例を多数掲載し、ベイズモデリングの楽しさを紹介する素晴らしい書籍です。
入門的なモデルから2次階差を取り扱う空間モデルまで、幅広い難易度のモデルを満喫できます!

このシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「岩波データサイエンス vol.1」
第9刷、編者 岩波データサイエンス刊行委員会  岩波書店

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


PythonとPyMC
テキストで利用するツールは、R、Stan です。
この記事では、PythonとPyMCを用いたコードに変換してベイズモデリングを実践いたします。

イントロ


1. モデルのイメージ

■ 線形回帰モデル
テキストはベイズツール「Stan」の入門編ということで、次式の線形回帰モデルからスタートします。

$$
Y =a+b\ T+e, \quad e \sim \text{Normal}(0,\ \sigma)
$$

■ 分析データ
上記の線形回帰モデルを適用した乱数を生成して分析に用います。
パラメータは$${a=0.5,\ b=3,\ \sigma=1}$$です。

2. インポート

「Stan入門」章で利用するパッケージをインポートします。

### インポート

# ユーティリティ
import pickle

# 数値・確率計算
import pandas as pd
import numpy as np

# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az

# 最適化
from scipy.optimize import minimize_scalar

# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

3. データの作成

乱数データを生成してpandasのデータフレームに読み込みます。
テキストの乱数生成結果と一致しないので、以後の分析結果もテキストと一致しないことを予めご了承くださいませ。

### データの作成 y = a + bT + e,  e ~ Normal(0, sd)

# 乱数生成器の作成
rng = np.random.default_rng(seed=123)

# 初期値設定
N = 20     # 標本サイズ
a = 0.5    # 真の切片
b = 3      # 真の傾き
sd = 1     # 真の標準偏差

# データの作成
T = np.linspace(0.1, 2.0, 20)                 # 説明変数:たぶん時間
Y = rng.normal(loc=a+b*T, scale=sd, size=N)   # 目的変数:線形モデルy=a+bT+e

# データフレーム化
data1 = pd.DataFrame({'T': T, 'Y': Y})
display(data1)

【実行結果】
説明変数$${T}$$と目的変数$${Y}$$、全20行のデータです。

予測に用いる説明変数$${T_{pred}}$$を51個、等間隔で作成します。

### 予測用のTの作成
# 0~2.5の区間を0.05刻みでデータ作成
T_pred = np.linspace(0, 2.5, 51)
# データの表示
print('T_pred.shape: ', T_pred.shape)
print(T_pred)

【実行結果】

4. データの外観の確認

seaborn の lmplot() で散布図+回帰直線を描画します。

### データの可視化
sns.lmplot(data=data1, x='T', y='Y', height=3, aspect=1.5,
           line_kws={'color': 'tomato'}, scatter_kws={'alpha': 0.5})
plt.title(f'相関係数: {data1.corr().iloc[0, 1]:.3f}')
plt.grid(lw=0.5);

【実行結果】
$${T}$$と$${Y}$$の間には強い正の相関があるようです。

では、PyMCモデリングに進みます。

ベイズモデル


1. モデルの定義

尤度関数に線形回帰モデル$${Y \sim \text{Normal}(a+b\ T,\ \sigma)}$$を定義します。
また、最後の計算値 Y_pred で目的変数 Y の事後予測を試みます。

### モデルの定義
with pm.Model() as model1:
    
    ### データ関連定義
    ## coordの定義
    # 観測データのインデックス
    model1.add_coord('data', values=data1.index, mutable=True)
    # 予測データのインデックス
    model1.add_coord('dataPred', values=np.arange(len(T_pred)), mutable=True)
    
    ## dataの定義
    # 観測値
    y = pm.ConstantData('y', value=data1['Y'].values, dims='data')
    # 説明変数
    T = pm.ConstantData('T', value=data1['T'].values, dims='data')
    # 説明変数(予測用)
    Tpred = pm.ConstantData('Tpred', value=T_pred, dims='dataPred')

    ### 事前分布
    a = pm.Normal('a', mu=0, sigma=100)
    b = pm.Normal('b', mu=0, sigma=100)
    sigma = pm.Uniform('sigma', lower=0, upper=1000)
    
    ### 尤度
    Y = pm.Normal('Y', mu=a+b*T, sigma=sigma, observed=y, dims='data')
    
    ### 計算値
    # Yの予測値
    Ypred = pm.Normal('Ypred', mu=a+b*Tpred, sigma=sigma, dims='dataPred')

モデルの内容を表示・可視化してみましょう。

### モデルの表示
model1

【実行結果】
線形回帰モデルの切片$${a}$$と傾き$${b}$$は正規分布(平均0、標準偏差100)を仮定しています。2つのパラメータに正規分布の縛りを入れているようです。

続いてモデルを可視化します。

### モデルの可視化
pm.model_to_graphviz(model1)

【実行結果】
左の縦系列は観測値にかかる部分、右の縦系列は事後予測にかかる部分です。
両方でパラメータ$${a,\ b,\ \sigma}$$と関連しています。

3. MCMCの実行と収束の確認

■ MCMC
マルコフ連鎖=chainsを4本、バーンイン期間=tuneを1000、利用するサンプル=drawを1chainあたり1000に設定して、合計4000個の事後分布からのサンプリングデータを生成します。テキストの指定数とは異なっています。
サンプル方法に numpyro を指定しています。処理速度が早くなります。
numpyro を使わない場合は「nuts_sampler='numpyro',」を削除します。

### 事後分布からのサンプリング 20秒
# テキスト:iter=1000, warmup=200, chains=3, thin=2
with model1:
    idata1 = pm.sample(draws=1000, tune=1000, chains=4, random_seed=1234,
                       nuts_sampler='numpyro')

【実行結果】
処理時間は20秒です。

■ $${\hat{R}}$$で収束確認
収束の確認をします。
指標$${\hat{R}}$$(あーるはっと)の値が$${1.1}$$以下のとき、収束したこととします。
次のコードでは$${\hat{R}>1.01}$$のパラメータの個数をカウントします。
個数が0ならば、$${\hat{R} \leq1.1}$$なので収束したとみなします。

### r_hat>1.1の確認
# 設定
idata_in = idata1        # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
$${\hat{R}>1.01}$$のパラメータは0個でした。
$${\hat{R} \leq1.1}$$なので収束したとみなします。

■ $${\hat{R}}$$の値把握と事後統計量の表示

### 事後統計量の表示
pm.summary(idata1, hdi_prob=0.95)

【実行結果】

■ トレースプロットで収束確認
トレースプロットを描画します。
こちらの図でも収束の確認を行えます。

### トレースプロットの描画
pm.plot_trace(idata1, var_names=var_names)
plt.tight_layout();

【実行結果】
左プロットの線の重なり具合、右プロットの偏りのない乱雑さより、収束している感じがします。

4. 分析

パラメータ$${a,\ b,\ \sigma}$$の事後分布を確認します。

### パラメータa,b,sigmaの事後統計量の表示
var_names=['a', 'b', 'sigma']
pm.summary(idata1, hdi_prob=0.95, var_names=var_names, kind='stats')

【実行結果】
事後平均値meanや95%HDI区間値を見ると、分析データ作成時に用いたパラメータ値$${a=0.5,\ b=3,\ \sigma=1}$$をよく推定できているように感じます。

3つのパラメータの事後分布を可視化しましょう。

### パラメータa,b,sigmaの事後分布の可視化
pm.plot_posterior(idata1, hdi_prob=0.95, var_names=var_names)
plt.tight_layout();

【実行結果】

$${Y}$$の事後予測 $${Y_{pred}}$$をプロットしましょう。
事後予測には上の3つのパラメータによる正規分布を用いています。
どんな予測値になっているか楽しみです。
テキストの図1に相当します。ただし、区間推定にはテキストの信頼区間(分位数に基づく)を使わず、HDIを用いました(plot_hdi()を利用)。

### 推定されたベイズ予測区間の描画 ※図1に相当

## 描画用データの取得
# 推論データからyの予測値のサンプリングデータを取り出し
y_pred_samples = idata1.posterior['Ypred'].stack(sample=('chain', 'draw')).data
# yの予測値の中央値を算出
y_pred_means = np.median(y_pred_samples, axis=1)

## 描画
# 描画領域の指定
plt.figure(figsize=(6, 3))
ax = plt.subplot()
# y予測値の中央値の折れ線グラフの描画
ax.plot(T_pred, y_pred_means, color='tab:red', label='中央値')
# y予測値の50%HDIの描画
az.plot_hdi(ax=ax, x=T_pred, y=y_pred_samples.T, hdi_prob=0.50, color='tomato',
            fill_kwargs={'alpha': 0.5, 'label': '50%HDI'})
# y予測値の95%HDIの描画
az.plot_hdi(ax=ax, x=T_pred, y=y_pred_samples.T, hdi_prob=0.95, color='tomato',
            fill_kwargs={'alpha': 0.2, 'label': '95%HDI'})
# y観測値の散布図の描画
ax.scatter(data1['T'], data1['Y'], color='tab:blue', label='観測値')
# 修飾
ax.set(xlabel='T', ylabel='Y', title='Yの事後予測プロット')
ax.grid(lw=0.5)
ax.legend();

【実行結果】
95%HDI区間に観測値(青点)が含まれる結果になりました。

(参考:単回帰直線)

5. 推論データの保存

pickle で MCMCサンプリングデータ idata1 を保存できます。

### 推論データの保存
file = r'idata1_ch4.pkl'
with open(file, 'wb') as f:
    pickle.dump(idata1, f)

次のコードで保存したファイルを読み込みできます。

### 推論データの読み込み
file = r'idata1_ch4.pkl'
with open(file, 'rb') as f:
    idata1_ch4_load = pickle.load(f)

ベイズ記事は以上です。

次回予告

「Stan入門」章
 例題2「状態空間モデルとベイズ決定」

結び


実はStanに興味がありまして・・・。
書籍やWebサイトで公開されている豊富な事例、モデリングの柔軟の高さ(反面、難しいところもありそう)に惹かれています。
将来の宿題ネタとして寝かしておこうと思います。

おわり


シリーズの記事

次の記事

前の記事

目次

ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

いいなと思ったら応援しよう!

この記事が参加している募集