機械学習:R言語(Tidymodels)チュートリアルの和訳② ブートストラッピング
第2弾!!
元の記事はこちら
https://www.tidymodels.org/learn/statistics/bootstrap/
Bootstrap resampling and tidy regression models
library(tidymodels)
ggplot(mtcars, aes(mpg, wt)) +
geom_point()
mtcarsデータセットに対して非線形モデルを適合させる方法を示します。
nlsfit <- nls(mpg ~ k / wt + b, mtcars, start = list(k = 1, b = 0))
summary(nlsfit)
#>
#> Formula: mpg ~ k/wt + b
#>
#> Parameters:
#> Estimate Std. Error t value Pr(>|t|)
#> k 45.829 4.249 10.786 7.64e-12 ***
#> b 4.386 1.536 2.855 0.00774 **
#> ---
#> Signif. codes: 0 '' 0.001 '' 0.01 '' 0.05 '.' 0.1 ' ' 1
#>
#> Residual standard error: 2.774 on 30 degrees of freedom
#>
#> Number of iterations to convergence: 1
#> Achieved convergence tolerance: 6.813e-09
ggplot(mtcars, aes(wt, mpg)) +
geom_point() +
geom_line(aes(y = predict(nlsfit)))
ブートストラッピングの導入
ブートストラッピングを使用することで、信頼区間と予測をより現実的に評価できます。以下に、ブートストラッピングを使用して非線形モデルを適合させる方法を示します。
set.seed(27)
boots <- bootstraps(mtcars, times = 2000, apparent = TRUE)
boots
#> # Bootstrap sampling with apparent sample
#> # A tibble: 2,001 × 2
#> splits id
#> <list> <chr>
#> 1 <split [32/13]> Bootstrap0001
#> 2 <split [32/10]> Bootstrap0002
#> 3 <split [32/13]> Bootstrap0003
#> 4 <split [32/11]> Bootstrap0004
#> 5 <split [32/9]> Bootstrap0005
#> 6 <split [32/10]> Bootstrap0006
#> 7 <split [32/11]> Bootstrap0007
#> 8 <split [32/13]> Bootstrap0008
#> 9 <split [32/11]> Bootstrap0009
#> 10 <split [32/11]> Bootstrap0010
#> # ℹ 1,991 more rows
各ブートストラップサンプルに対してnls()モデルを適合させるヘルパー関数を作成し、この関数をpurrr::map()を使って全てのブートストラップサンプルに一度に適用します。同様に、unnestを使って整然とした係数情報の列を作成します。
fit_nls_on_bootstrap <- function(split) {
nls(mpg ~ k / wt + b, analysis(split), start = list(k = 1, b = 0))
}
boot_models <-
boots %>%
mutate(model = map(splits, fit_nls_on_bootstrap),
coef_info = map(model, tidy))
boot_coefs <-
boot_models %>%
unnest(coef_info)
boot_coefs
#> # A tibble: 4,002 × 8
#> splits id model term estimate std.error statistic p.value
#> <list> <chr> <lis> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 <split [32/13]> Bootstrap0… <nls> k 42.1 4.05 10.4 1.91e-11
#> 2 <split [32/13]> Bootstrap0… <nls> b 5.39 1.43 3.78 6.93e- 4
#> 3 <split [32/10]> Bootstrap0… <nls> k 49.9 5.66 8.82 7.82e-10
#> 4 <split [32/10]> Bootstrap0… <nls> b 3.73 1.92 1.94 6.13e- 2
#> 5 <split [32/13]> Bootstrap0… <nls> k 37.8 2.68 14.1 9.01e-15
#> 6 <split [32/13]> Bootstrap0… <nls> b 6.73 1.17 5.75 2.78e- 6
#> 7 <split [32/11]> Bootstrap0… <nls> k 45.6 4.45 10.2 2.70e-11
#> 8 <split [32/11]> Bootstrap0… <nls> b 4.75 1.62 2.93 6.38e- 3
#> 9 <split [32/9]> Bootstrap0… <nls> k 43.6 4.63 9.41 1.85e-10
#> 10 <split [32/9]> Bootstrap0… <nls> b 5.89 1.68 3.51 1.44e- 3
#> # ℹ 3,992 more rows
信頼区間の計算
percentile_intervals <- int_pctl(boot_models, coef_info)
percentile_intervals
#> # A tibble: 2 × 6
#> term .lower .estimate .upper .alpha .method
#> <chr> <dbl> <dbl> <dbl> <dbl> <chr>
#> 1 b 0.0475 4.12 7.31 0.05 percentile
#> 2 k 37.6 46.7 59.8 0.05 percentile
ggplot(boot_coefs, aes(estimate)) +
geom_histogram(bins = 30) +
facet_wrap( ~ term, scales = "free") +
geom_vline(aes(xintercept = .lower), data = percentile_intervals, col = "blue") +
geom_vline(aes(xintercept = .upper), data = percentile_intervals, col = "blue")
モデルの適合例
augment()を使って、適合曲線の不確実性を視覚化できます。ブートストラップサンプルが非常に多いため、可視化にはモデル適合のサンプルのみを表示します。
boot_aug <-
boot_models %>%
sample_n(200) %>%
mutate(augmented = map(model, augment)) %>%
unnest(augmented)
boot_aug
#> # A tibble: 6,400 × 8
#> splits id model coef_info mpg wt .fitted .resid
#> <list> <chr> <list> <list> <dbl> <dbl> <dbl> <dbl>
#> 1 <split [32/11]> Bootstrap1644 <nls> <tibble> 16.4 4.07 15.6 0.829
#> 2 <split [32/11]> Bootstrap1644 <nls> <tibble> 19.7 2.77 21.9 -2.21
#> 3 <split [32/11]> Bootstrap1644 <nls> <tibble> 19.2 3.84 16.4 2.84
#> 4 <split [32/11]> Bootstrap1644 <nls> <tibble> 21.4 2.78 21.8 -0.437
#> 5 <split [32/11]> Bootstrap1644 <nls> <tibble> 26 2.14 27.8 -1.75
#> 6 <split [32/11]> Bootstrap1644 <nls> <tibble> 33.9 1.84 32.0 1.88
#> 7 <split [32/11]> Bootstrap1644 <nls> <tibble> 32.4 2.2 27.0 5.35
#> 8 <split [32/11]> Bootstrap1644 <nls> <tibble> 30.4 1.62 36.1 -5.70
#> 9 <split [32/11]> Bootstrap1644 <nls> <tibble> 21.5 2.46 24.4 -2.86
#> 10 <split [32/11]> Bootstrap1644 <nls> <tibble> 26 2.14 27.8 -1.75
#> # ℹ 6,390 more rows
ggplot(boot_aug, aes(wt, mpg)) +
geom_line(aes(y = .fitted, group = id), alpha = .2, col = "blue") +
geom_point()
わずかな変更を加えるだけで、他の種類の予測モデルや仮説検定モデルに対しても簡単にブートストラッピングを行うことができます。なぜなら、tidy()およびaugment()関数は多くの統計出力に対応しているからです。別の例として、smooth.spline()を使用することもできます。これはデータに対して三次平滑スプラインを適合させるものです。
fit_spline_on_bootstrap <- function(split) {
data <- analysis(split)
smooth.spline(data$wt, data$mpg, df = 4)
}
boot_splines <-
boots %>%
sample_n(200) %>%
mutate(spline = map(splits, fit_spline_on_bootstrap),
aug_train = map(spline, augment))
splines_aug <-
boot_splines %>%
unnest(aug_train)
ggplot(splines_aug, aes(x, y)) +
geom_line(aes(y = .fitted, group = id), alpha = 0.2, col = "blue") +
geom_point()
END
この記事が気に入ったらサポートをしてみませんか?