Created
March 25, 2026 16:00
-
-
Save statwonk/f4214a31caed49d3c1360899fec61a3f to your computer and use it in GitHub Desktop.
BHM vs BNN vs Ensemble: Out-of-Sample CRPS Comparison (R reprex with cmdstanr)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # ============================================================================ | |
| # BHM vs BNN vs Ensemble: Out-of-Sample CRPS Comparison | |
| # ============================================================================ | |
| # Three cases demonstrating the "no free lunch" principle: | |
| # Case 1: BHM wins when group structure is describable | |
| # Case 2: BNN wins when structure is complex/unstructured | |
| # Case 3: Their ensemble minimizes out-of-sample error across both DGPs | |
| # | |
| # Requirements: cmdstanr, posterior, scoringRules, dplyr, tibble, ggplot2 | |
| # ============================================================================ | |
| library(cmdstanr) | |
| library(posterior) | |
| library(scoringRules) | |
| library(dplyr) | |
| library(tibble) | |
| library(ggplot2) | |
# Fix R's RNG state so the simulated data (rnorm/runif) and the random
# ensemble masks (sample) drawn later in this script are repeatable.
# NOTE(review): no explicit `seed` is passed to the Stan sampler in this file;
# whether the MCMC draws are repeatable too depends on how cmdstanr chooses
# its seed -- confirm.
| set.seed(42) | |
| # ── Stan model code ────────────────────────────────────────────────────────── | |
| # Bayesian hierarchical model: group-varying intercept + slope | |
# Stan program (held in an R string) for the hierarchical linear model.
#   data:        N training rows with group index g (1..J), predictor x and
#                response y; N_test held-out rows with g_test and x_test.
#   parameters:  population-level alpha_0 / beta_0, between-group scales
#                sigma_alpha / sigma_beta, standardised group offsets
#                alpha_raw / beta_raw, and the residual scale sigma.
#   transformed: alpha = alpha_0 + sigma_alpha * alpha_raw (likewise beta),
#                i.e. a non-centred parameterisation of the group effects.
#   model:       normal(0, 2) priors on alpha_0 / beta_0, normal(0, 1) priors
#                on the lower-bounded scales, std_normal() on the raw offsets,
#                and y ~ normal(alpha[g] + beta[g] * x, sigma).
#   generated quantities: yrep_test[n] is one normal_rng draw per test row,
#                centred on that row's group line with noise sigma.
| hier_stan_code <- " | |
| data { | |
| int<lower=1> N; | |
| int<lower=1> J; | |
| array[N] int<lower=1, upper=J> g; | |
| vector[N] x; | |
| vector[N] y; | |
| int<lower=1> N_test; | |
| array[N_test] int<lower=1, upper=J> g_test; | |
| vector[N_test] x_test; | |
| } | |
| parameters { | |
| real alpha_0; | |
| real beta_0; | |
| real<lower=0> sigma_alpha; | |
| real<lower=0> sigma_beta; | |
| vector[J] alpha_raw; | |
| vector[J] beta_raw; | |
| real<lower=0> sigma; | |
| } | |
| transformed parameters { | |
| vector[J] alpha = alpha_0 + sigma_alpha * alpha_raw; | |
| vector[J] beta = beta_0 + sigma_beta * beta_raw; | |
| } | |
| model { | |
| alpha_0 ~ normal(0, 2); | |
| beta_0 ~ normal(0, 2); | |
| sigma_alpha ~ normal(0, 1); | |
| sigma_beta ~ normal(0, 1); | |
| alpha_raw ~ std_normal(); | |
| beta_raw ~ std_normal(); | |
| sigma ~ normal(0, 1); | |
| vector[N] mu; | |
| for (n in 1:N) | |
| mu[n] = alpha[g[n]] + beta[g[n]] * x[n]; | |
| y ~ normal(mu, sigma); | |
| } | |
| generated quantities { | |
| vector[N_test] yrep_test; | |
| for (n in 1:N_test) { | |
| real mu_n = alpha[g_test[n]] + beta[g_test[n]] * x_test[n]; | |
| yrep_test[n] = normal_rng(mu_n, sigma); | |
| } | |
| } | |
| " | |
| # Bayesian neural network: 1 hidden layer, no group intercepts | |
# Stan program (held in an R string) for the one-hidden-layer network.
#   data:        N training pairs (x, y), N_test test inputs x_test, and H,
#                the number of hidden units.
#   parameters:  input weights w1 and biases b1 (length H), output weights w2
#                (length H), output bias b2, and the residual scale sigma.
#   model:       normal(0, 1) priors on w1, b1, w2, b2 and on the
#                lower-bounded sigma; the mean for row n is
#                b2 + dot_product(w2, tanh(w1 * x[n] + b1)).
#   generated quantities: yrep_test[n] is one normal_rng draw per test input,
#                using the same network mean and noise sigma.
# The data block has no group index, so this model only ever sees x.
| bnn_stan_code <- " | |
| data { | |
| int<lower=1> N; | |
| vector[N] x; | |
| vector[N] y; | |
| int<lower=1> N_test; | |
| vector[N_test] x_test; | |
| int<lower=1> H; | |
| } | |
| parameters { | |
| vector[H] w1; | |
| vector[H] b1; | |
| vector[H] w2; | |
| real b2; | |
| real<lower=0> sigma; | |
| } | |
| model { | |
| w1 ~ normal(0, 1); | |
| b1 ~ normal(0, 1); | |
| w2 ~ normal(0, 1); | |
| b2 ~ normal(0, 1); | |
| sigma ~ normal(0, 1); | |
| vector[N] mu; | |
| for (n in 1:N) { | |
| vector[H] h = tanh(w1 * x[n] + b1); | |
| mu[n] = b2 + dot_product(w2, h); | |
| } | |
| y ~ normal(mu, sigma); | |
| } | |
| generated quantities { | |
| vector[N_test] yrep_test; | |
| for (n in 1:N_test) { | |
| vector[H] h = tanh(w1 * x_test[n] + b1); | |
| yrep_test[n] = normal_rng(b2 + dot_product(w2, h), sigma); | |
| } | |
| } | |
| " | |
| # ── Compile Stan models once ──────────────────────────────────────────────── | |
# Write each Stan program to a file and compile it. The resulting model
# objects (mod_hier, mod_bnn) are what the fits further down sample from.
cat("Compiling Stan models...\n")

# Hierarchical model
hier_file <- write_stan_file(hier_stan_code)
mod_hier <- cmdstan_model(hier_file)

# Neural-network model
bnn_file <- write_stan_file(bnn_stan_code)
mod_bnn <- cmdstan_model(bnn_file)
| # ── Helper: extract yrep matrix (obs × draws) ─────────────────────────────── | |
# Pull the `yrep_test` draws out of a fit as a plain numeric matrix with one
# row per test observation and one column per draw.
#
# fit: object exposing `$draws(variables, format = "matrix")`, whose result
#      has draws in rows and columns named `yrep_test[i]` (the cmdstanr fits
#      produced in this script).
# Returns: a base matrix (observations x draws); row names are the
#      `yrep_test[i]` column names of the draws matrix.
extract_yrep <- function(fit) {
  draws <- fit$draws("yrep_test", format = "matrix")
  keep <- grepl("^yrep_test\\[", colnames(draws))

  # Drop the draws class and its chain bookkeeping before transposing: once
  # rows are observations rather than draws, keeping that class would make
  # class-specific methods (e.g. `[`) misread the layout.
  yrep <- unclass(draws[, keep, drop = FALSE])
  attr(yrep, "nchains") <- NULL

  t(yrep)
}
| # ── Sampling helper ───────────────────────────────────────────────────────── | |
# Sample from `mod` on `data` using this script's fixed sampler settings.
#
# mod:  object with a `$sample()` method (the compiled models in this script).
# data: list passed through as the `data` argument.
# ...:  further arguments forwarded unchanged to `mod$sample()`.
# Returns whatever `mod$sample()` returns.
fit_model <- function(mod, data, ...) {
  fit <- mod$sample(
    data = data,
    chains = 2,
    parallel_chains = 2,
    iter_warmup = 400,
    iter_sampling = 400,
    # Keep console output quiet during sampling.
    refresh = 0,
    show_messages = FALSE,
    ...
  )
  fit
}
| # ============================================================================ | |
| # CASE 1: Grouped linear DGP — BHM should win | |
| # ============================================================================ | |
cat(
  "\n══════════════════════════════════════════════════\n",
  "CASE 1: Grouped linear DGP (BHM advantage)\n",
  "══════════════════════════════════════════════════\n",
  sep = ""
)

# Simulation size: J groups with n_per observations each.
J <- 10
n_per <- 30
N1 <- J * n_per

# True per-group intercepts and slopes.
alpha_true <- rnorm(J, mean = 0, sd = 1.5)
beta_true <- rnorm(J, mean = 2, sd = 0.8)

# Each group follows its own line plus Gaussian noise (sd 0.5).
d1 <- tibble(
  g = rep(seq_len(J), each = n_per),
  x = runif(N1, min = -2, max = 2)
)
d1 <- mutate(
  d1,
  mu = alpha_true[g] + beta_true[g] * x,
  y = rnorm(n(), mean = mu, sd = 0.5)
)

# Within every group the first 20 rows are training data, the rest are test.
d1 <- d1 %>%
  group_by(g) %>%
  mutate(is_train = seq_len(n()) <= 20) %>%
  ungroup()

train1 <- d1 %>% filter(is_train)
test1 <- d1 %>% filter(!is_train)

# Hierarchical model: gets the group index, so it can match the DGP.
data1_hier <- list(
  N = nrow(train1),
  J = J,
  g = train1$g,
  x = train1$x,
  y = train1$y,
  N_test = nrow(test1),
  g_test = test1$g,
  x_test = test1$x
)
fit1_hier <- fit_model(mod_hier, data1_hier)

# Neural network: only sees x (no group index), with H = 8 hidden units.
data1_bnn <- list(
  N = nrow(train1),
  x = train1$x,
  y = train1$y,
  N_test = nrow(test1),
  x_test = test1$x,
  H = 8
)
fit1_bnn <- fit_model(mod_bnn, data1_bnn)

yrep1_hier <- extract_yrep(fit1_hier)
yrep1_bnn <- extract_yrep(fit1_bnn)

# Ensemble: start from the BHM draws, then overwrite every draw column whose
# mask entry is FALSE with the BNN column at the same position. Each mask
# entry is TRUE or FALSE with equal probability.
n_draws <- ncol(yrep1_hier)
ens_idx <- sample(c(TRUE, FALSE), n_draws, replace = TRUE)
from_bnn1 <- !ens_idx
yrep1_ens <- yrep1_hier
yrep1_ens[, from_bnn1] <- yrep1_bnn[, from_bnn1]

# Mean CRPS over the test observations for each predictive sample.
crps1_hier <- crps_sample(test1$y, yrep1_hier) %>% mean()
crps1_bnn <- crps_sample(test1$y, yrep1_bnn) %>% mean()
crps1_ens <- crps_sample(test1$y, yrep1_ens) %>% mean()

cat(
  sprintf(
    " %s mean CRPS: %.4f\n",
    c("BHM", "BNN", "ENS"),
    c(crps1_hier, crps1_bnn, crps1_ens)
  ),
  sep = ""
)
| # ============================================================================ | |
| # CASE 2: Complex nonlinear DGP — BNN should win | |
| # ============================================================================ | |
cat(
  "\n══════════════════════════════════════════════════\n",
  "CASE 2: Complex nonlinear DGP (BNN advantage)\n",
  "══════════════════════════════════════════════════\n",
  sep = ""
)

N2 <- 300

# The signal depends on x only; group labels are cycled 1..J across rows and
# play no part in generating y.
d2 <- tibble(
  g = rep_len(seq_len(J), N2),
  x = runif(N2, min = -3, max = 3)
)
d2 <- mutate(
  d2,
  mu = 2 * sin(1.5 * x) + 0.5 * cos(3 * x),
  y = rnorm(n(), mean = mu, sd = 0.4)
)

# Within every group the first floor(70%) of rows are training data.
d2 <- d2 %>%
  group_by(g) %>%
  mutate(is_train = seq_len(n()) <= floor(0.7 * n())) %>%
  ungroup()

train2 <- d2 %>% filter(is_train)
test2 <- d2 %>% filter(!is_train)

# Hierarchical model: linear in x, so it cannot represent this signal.
data2_hier <- list(
  N = nrow(train2),
  J = J,
  g = train2$g,
  x = train2$x,
  y = train2$y,
  N_test = nrow(test2),
  g_test = test2$g,
  x_test = test2$x
)
fit2_hier <- fit_model(mod_hier, data2_hier)

# Neural network: nonlinear in x, with H = 8 hidden units.
data2_bnn <- list(
  N = nrow(train2),
  x = train2$x,
  y = train2$y,
  N_test = nrow(test2),
  x_test = test2$x,
  H = 8
)
fit2_bnn <- fit_model(mod_bnn, data2_bnn)

yrep2_hier <- extract_yrep(fit2_hier)
yrep2_bnn <- extract_yrep(fit2_bnn)

# Ensemble: start from the BHM draws, then overwrite every draw column whose
# mask entry is FALSE with the BNN column at the same position.
n_draws2 <- ncol(yrep2_hier)
ens_idx2 <- sample(c(TRUE, FALSE), n_draws2, replace = TRUE)
from_bnn2 <- !ens_idx2
yrep2_ens <- yrep2_hier
yrep2_ens[, from_bnn2] <- yrep2_bnn[, from_bnn2]

# Mean CRPS over the test observations for each predictive sample.
crps2_hier <- crps_sample(test2$y, yrep2_hier) %>% mean()
crps2_bnn <- crps_sample(test2$y, yrep2_bnn) %>% mean()
crps2_ens <- crps_sample(test2$y, yrep2_ens) %>% mean()

cat(
  sprintf(
    " %s mean CRPS: %.4f\n",
    c("BHM", "BNN", "ENS"),
    c(crps2_hier, crps2_bnn, crps2_ens)
  ),
  sep = ""
)
| # ============================================================================ | |
| # CASE 3: Summary — ensemble diversification | |
| # ============================================================================ | |
cat("\n══════════════════════════════════════════════════\n")
cat("SUMMARY: Ensemble diversification across DGPs\n")
cat("══════════════════════════════════════════════════\n")

# One row per (DGP, model) pair holding the mean test CRPS computed above.
results <- tibble(
  DGP = c(rep("Grouped linear", 3), rep("Complex nonlinear", 3)),
  Model = rep(c("BHM", "BNN", "Ensemble"), 2),
  CRPS = c(crps1_hier, crps1_bnn, crps1_ens,
           crps2_hier, crps2_bnn, crps2_ens)
)
print(results)

# ── Visualization ───────────────────────────────────────────────────────────
# One panel per DGP; each panel gets its own y scale.
p <- ggplot(results, aes(x = Model, y = CRPS, fill = Model)) +
  geom_col(alpha = 0.85) +
  facet_wrap(~ DGP, scales = "free_y") +
  scale_fill_manual(values = c(
    "BHM" = "#0f3460",
    "BNN" = "#e94560",
    "Ensemble" = "#16213e"
  )) +
  labs(
    title = "Out-of-Sample CRPS: BHM vs BNN vs Ensemble",
    subtitle = "Lower CRPS = better probabilistic predictions",
    y = "Mean Test CRPS",
    x = NULL
  ) +
  theme_minimal(base_size = 14) +
  theme(legend.position = "none",
        strip.text = element_text(face = "bold"))

# Create the output directory up front: nothing earlier in the script creates
# `analysis/`, and a missing directory would make the save fail only after
# all models have been fit. dir.create() is a no-op when it already exists.
plot_path <- "analysis/bhm_bnn_ensemble_crps.png"
dir.create(dirname(plot_path), recursive = TRUE, showWarnings = FALSE)
ggsave(plot_path, p,
       width = 9, height = 5, dpi = 150, bg = "white")
cat("\nPlot saved to ", plot_path, "\n", sep = "")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment