Skip to content

Instantly share code, notes, and snippets.

@statwonk
Created March 25, 2026 16:00
Show Gist options
  • Select an option

  • Save statwonk/f4214a31caed49d3c1360899fec61a3f to your computer and use it in GitHub Desktop.

Select an option

Save statwonk/f4214a31caed49d3c1360899fec61a3f to your computer and use it in GitHub Desktop.
BHM vs BNN vs Ensemble: Out-of-Sample CRPS Comparison (R reprex with cmdstanr)
# ============================================================================
# BHM vs BNN vs Ensemble: Out-of-Sample CRPS Comparison
# ============================================================================
# Three cases demonstrating the "no free lunch" principle:
# Case 1: BHM wins when group structure is describable
# Case 2: BNN wins when structure is complex/unstructured
# Case 3: Their ensemble minimizes out-of-sample error across both DGPs
#
# Requirements: cmdstanr, posterior, scoringRules, dplyr, tibble, ggplot2
# ============================================================================
library(cmdstanr)
library(posterior)
library(scoringRules)
library(dplyr)
library(tibble)
library(ggplot2)
set.seed(42)
# ── Stan model code ──────────────────────────────────────────────────────────
# Bayesian hierarchical model (BHM): group-varying intercept + slope.
#   Data       : N training rows with group index g (1..J), covariate x and
#                response y; N_test held-out rows with g_test and x_test.
#   Likelihood : y[n] ~ normal(alpha[g[n]] + beta[g[n]] * x[n], sigma)
#   Group terms: alpha = alpha_0 + sigma_alpha * alpha_raw and
#                beta  = beta_0  + sigma_beta  * beta_raw, where the *_raw
#                vectors get std_normal() priors (non-centered form).
#   Priors     : normal(0, 2) on alpha_0 and beta_0; normal(0, 1) on
#                sigma_alpha, sigma_beta and sigma, all declared <lower=0>.
#   Output     : yrep_test[n] is one normal_rng() draw per test row, using
#                that row's group coefficients and the residual scale sigma.
# The text below is a runtime string handed to the Stan compiler, so any edit
# inside the quotes changes the model itself.
hier_stan_code <- "
data {
int<lower=1> N;
int<lower=1> J;
array[N] int<lower=1, upper=J> g;
vector[N] x;
vector[N] y;
int<lower=1> N_test;
array[N_test] int<lower=1, upper=J> g_test;
vector[N_test] x_test;
}
parameters {
real alpha_0;
real beta_0;
real<lower=0> sigma_alpha;
real<lower=0> sigma_beta;
vector[J] alpha_raw;
vector[J] beta_raw;
real<lower=0> sigma;
}
transformed parameters {
vector[J] alpha = alpha_0 + sigma_alpha * alpha_raw;
vector[J] beta = beta_0 + sigma_beta * beta_raw;
}
model {
alpha_0 ~ normal(0, 2);
beta_0 ~ normal(0, 2);
sigma_alpha ~ normal(0, 1);
sigma_beta ~ normal(0, 1);
alpha_raw ~ std_normal();
beta_raw ~ std_normal();
sigma ~ normal(0, 1);
vector[N] mu;
for (n in 1:N)
mu[n] = alpha[g[n]] + beta[g[n]] * x[n];
y ~ normal(mu, sigma);
}
generated quantities {
vector[N_test] yrep_test;
for (n in 1:N_test) {
real mu_n = alpha[g_test[n]] + beta[g_test[n]] * x_test[n];
yrep_test[n] = normal_rng(mu_n, sigma);
}
}
"
# Bayesian neural network (BNN): 1 hidden layer, no group intercepts.
#   Data       : N training rows (scalar covariate x, response y), N_test
#                held-out covariates x_test, and H = number of hidden units.
#                No group index is passed to this model.
#   Network    : h = tanh(w1 * x[n] + b1)   (length-H hidden activations)
#                mu[n] = b2 + dot_product(w2, h)
#   Likelihood : y ~ normal(mu, sigma)
#   Priors     : normal(0, 1) on w1, b1, w2, b2 and on sigma (<lower=0>).
#   Output     : yrep_test[n] is one normal_rng() draw per test row, computed
#                with the same forward pass applied to x_test[n].
# The text below is a runtime string handed to the Stan compiler, so any edit
# inside the quotes changes the model itself.
bnn_stan_code <- "
data {
int<lower=1> N;
vector[N] x;
vector[N] y;
int<lower=1> N_test;
vector[N_test] x_test;
int<lower=1> H;
}
parameters {
vector[H] w1;
vector[H] b1;
vector[H] w2;
real b2;
real<lower=0> sigma;
}
model {
w1 ~ normal(0, 1);
b1 ~ normal(0, 1);
w2 ~ normal(0, 1);
b2 ~ normal(0, 1);
sigma ~ normal(0, 1);
vector[N] mu;
for (n in 1:N) {
vector[H] h = tanh(w1 * x[n] + b1);
mu[n] = b2 + dot_product(w2, h);
}
y ~ normal(mu, sigma);
}
generated quantities {
vector[N_test] yrep_test;
for (n in 1:N_test) {
vector[H] h = tanh(w1 * x_test[n] + b1);
yrep_test[n] = normal_rng(b2 + dot_product(w2, h), sigma);
}
}
"
# ── Compile Stan models once ────────────────────────────────────────────────
# Each program string is written to a .stan file and compiled a single time;
# the resulting model objects are reused for every fit further down.
cat("Compiling Stan models...\n")

# Hierarchical model
hier_file <- write_stan_file(hier_stan_code)
mod_hier <- cmdstan_model(hier_file)

# Neural-network model
bnn_file <- write_stan_file(bnn_stan_code)
mod_bnn <- cmdstan_model(bnn_file)
# ── Helper: extract yrep matrix (obs × draws) ───────────────────────────────
# Requests the "yrep_test" draws from `fit` in matrix format, keeps only the
# columns whose names start with "yrep_test[", and transposes the result so
# that rows index test observations and columns index posterior draws.
extract_yrep <- function(fit) {
  draws <- fit$draws("yrep_test", format = "matrix")
  is_yrep <- grepl("^yrep_test\\[", colnames(draws))
  yrep_cols <- draws[, is_yrep, drop = FALSE]
  t(yrep_cols)
}
# ── Sampling helper ─────────────────────────────────────────────────────────
# Run `mod$sample()` with this script's default sampler settings.
#
# Arguments:
#   mod   Object exposing a `$sample()` method (here: a compiled Stan model).
#   data  Data list forwarded as `data =`.
#   ...   Extra arguments forwarded to `mod$sample()`. A named argument that
#         matches one of the defaults below (e.g. `chains = 4`) replaces that
#         default instead of being passed twice, which would otherwise fail
#         with "formal argument matched by multiple actual arguments".
#
# Returns whatever `mod$sample()` returns.
fit_model <- function(mod, data, ...) {
  defaults <- list(
    chains = 2, parallel_chains = 2,
    iter_warmup = 400, iter_sampling = 400,
    refresh = 0, show_messages = FALSE
  )
  overrides <- list(...)

  # Drop every default the caller supplied explicitly; the rest keep their
  # original order so a call without overrides builds the same argument list.
  kept <- defaults[setdiff(names(defaults), names(overrides))]

  do.call(mod$sample, c(list(data = data), kept, overrides))
}
# ============================================================================
# CASE 1: Grouped linear DGP — BHM should win
# ============================================================================
cat("\n══════════════════════════════════════════════════\n")
cat("CASE 1: Grouped linear DGP (BHM advantage)\n")
cat("══════════════════════════════════════════════════\n")

# Design: J groups, n_per observations per group.
J <- 10
n_per <- 30
N1 <- J * n_per

# True per-group intercepts and slopes. The order of the random draws in this
# section matters for reproducibility under set.seed().
alpha_true <- rnorm(J, 0, 1.5)
beta_true <- rnorm(J, 2, 0.8)

# Simulate covariate, group-specific linear mean and noisy response.
d1 <- tibble(
  g = rep(seq_len(J), each = n_per),
  x = runif(N1, -2, 2)
)
d1 <- mutate(
  d1,
  mu = alpha_true[g] + beta_true[g] * x,
  y = rnorm(n(), mu, 0.5)
)

# Within every group the first 20 rows are training data, the rest test data.
d1 <- d1 %>%
  group_by(g) %>%
  mutate(is_train = row_number() <= 20) %>%
  ungroup()
train1 <- d1 %>% filter(is_train)
test1 <- d1 %>% filter(!is_train)

# Fit BHM (its linear-in-x, per-group structure matches the DGP).
hier_data1 <- list(
  N = nrow(train1), J = J,
  g = train1$g, x = train1$x, y = train1$y,
  N_test = nrow(test1), g_test = test1$g, x_test = test1$x
)
fit1_hier <- fit_model(mod_hier, hier_data1)

# Fit BNN (sees only x, never the group index — disadvantaged here).
bnn_data1 <- list(
  N = nrow(train1), x = train1$x, y = train1$y,
  N_test = nrow(test1), x_test = test1$x,
  H = 8
)
fit1_bnn <- fit_model(mod_bnn, bnn_data1)

# Posterior predictive draws for the test rows (obs × draws).
yrep1_hier <- extract_yrep(fit1_hier)
yrep1_bnn <- extract_yrep(fit1_bnn)

# Ensemble: every draw column is taken from the BHM or the BNN with
# probability 1/2 each (TRUE keeps the BHM column, FALSE swaps in the BNN's).
n_draws <- ncol(yrep1_hier)
ens_idx <- sample(c(TRUE, FALSE), size = n_draws, replace = TRUE)
yrep1_ens <- yrep1_hier
yrep1_ens[, !ens_idx] <- yrep1_bnn[, !ens_idx]

# Mean sample-based CRPS over the test rows for each predictive distribution.
crps1_hier <- mean(crps_sample(y = test1$y, dat = yrep1_hier))
crps1_bnn <- mean(crps_sample(y = test1$y, dat = yrep1_bnn))
crps1_ens <- mean(crps_sample(y = test1$y, dat = yrep1_ens))

cat(sprintf(" BHM mean CRPS: %.4f\n", crps1_hier))
cat(sprintf(" BNN mean CRPS: %.4f\n", crps1_bnn))
cat(sprintf(" ENS mean CRPS: %.4f\n", crps1_ens))
# ============================================================================
# CASE 2: Complex nonlinear DGP — BNN should win
# ============================================================================
cat("\n══════════════════════════════════════════════════\n")
cat("CASE 2: Complex nonlinear DGP (BNN advantage)\n")
cat("══════════════════════════════════════════════════\n")

N2 <- 300

# Group labels cycle 1..J but play no role in the mean: the signal is a
# smooth nonlinear function of x alone.
d2 <- tibble(
  g = rep(seq_len(J), length.out = N2),
  x = runif(N2, -3, 3)
)
d2 <- mutate(
  d2,
  mu = 2 * sin(1.5 * x) + 0.5 * cos(3 * x),
  y = rnorm(n(), mu, 0.4)
)

# Within every group the first 70% of rows (rounded down) are training data.
d2 <- d2 %>%
  group_by(g) %>%
  mutate(is_train = row_number() <= floor(0.7 * n())) %>%
  ungroup()
train2 <- d2 %>% filter(is_train)
test2 <- d2 %>% filter(!is_train)

# Fit BHM (linear in x — misspecified for this nonlinear signal).
hier_data2 <- list(
  N = nrow(train2), J = J,
  g = train2$g, x = train2$x, y = train2$y,
  N_test = nrow(test2), g_test = test2$g, x_test = test2$x
)
fit2_hier <- fit_model(mod_hier, hier_data2)

# Fit BNN (flexible nonlinear function of x).
bnn_data2 <- list(
  N = nrow(train2), x = train2$x, y = train2$y,
  N_test = nrow(test2), x_test = test2$x,
  H = 8
)
fit2_bnn <- fit_model(mod_bnn, bnn_data2)

# Posterior predictive draws for the test rows (obs × draws).
yrep2_hier <- extract_yrep(fit2_hier)
yrep2_bnn <- extract_yrep(fit2_bnn)

# Ensemble: every draw column comes from the BHM or the BNN with
# probability 1/2 each (TRUE keeps the BHM column, FALSE swaps in the BNN's).
n_draws2 <- ncol(yrep2_hier)
ens_idx2 <- sample(c(TRUE, FALSE), size = n_draws2, replace = TRUE)
yrep2_ens <- yrep2_hier
yrep2_ens[, !ens_idx2] <- yrep2_bnn[, !ens_idx2]

# Mean sample-based CRPS over the test rows for each predictive distribution.
crps2_hier <- mean(crps_sample(y = test2$y, dat = yrep2_hier))
crps2_bnn <- mean(crps_sample(y = test2$y, dat = yrep2_bnn))
crps2_ens <- mean(crps_sample(y = test2$y, dat = yrep2_ens))

cat(sprintf(" BHM mean CRPS: %.4f\n", crps2_hier))
cat(sprintf(" BNN mean CRPS: %.4f\n", crps2_bnn))
cat(sprintf(" ENS mean CRPS: %.4f\n", crps2_ens))
# ============================================================================
# CASE 3: Summary — ensemble diversification
# ============================================================================
cat("\n══════════════════════════════════════════════════\n")
cat("SUMMARY: Ensemble diversification across DGPs\n")
cat("══════════════════════════════════════════════════\n")

# One row per (DGP, model) pair; the CRPS values are listed in the same
# DGP-major order that the DGP and Model columns are generated in.
results <- tibble(
  DGP = rep(c("Grouped linear", "Complex nonlinear"), each = 3),
  Model = rep(c("BHM", "BNN", "Ensemble"), times = 2),
  CRPS = c(
    crps1_hier, crps1_bnn, crps1_ens,
    crps2_hier, crps2_bnn, crps2_ens
  )
)
print(results)
# ── Visualization ───────────────────────────────────────────────────────────
# Bar chart of mean test CRPS per model, one facet per DGP. The y-axes are
# free so each facet is scaled to its own CRPS range.
p <- ggplot(results, aes(x = Model, y = CRPS, fill = Model)) +
  geom_col(alpha = 0.85) +
  facet_wrap(~ DGP, scales = "free_y") +
  scale_fill_manual(values = c(
    "BHM" = "#0f3460",
    "BNN" = "#e94560",
    "Ensemble" = "#16213e"
  )) +
  labs(
    title = "Out-of-Sample CRPS: BHM vs BNN vs Ensemble",
    subtitle = "Lower CRPS = better probabilistic predictions",
    y = "Mean Test CRPS",
    x = NULL
  ) +
  theme_minimal(base_size = 14) +
  theme(
    legend.position = "none",
    strip.text = element_text(face = "bold")
  )

# Create the output directory first: it is not created anywhere else in this
# script, and saving into a missing directory would abort the run after all
# models have already been fitted.
plot_path <- "analysis/bhm_bnn_ensemble_crps.png"
dir.create(dirname(plot_path), recursive = TRUE, showWarnings = FALSE)

ggsave(plot_path, p,
  width = 9, height = 5, dpi = 150, bg = "white"
)
cat("\nPlot saved to ", plot_path, "\n", sep = "")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment