Skip to content

Instantly share code, notes, and snippets.

@statwonk
Created March 25, 2026 16:01
Show Gist options
  • Select an option

  • Save statwonk/ead845b7811dd5347b80d0fb68c7757c to your computer and use it in GitHub Desktop.

Select an option

Save statwonk/ead845b7811dd5347b80d0fb68c7757c to your computer and use it in GitHub Desktop.
Hierarchical BNN Comparison: BHM vs BNN vs HBNN vs Ensemble on CRPS (R reprex with cmdstanr)
# ============================================================================
# Hierarchical BNN: Combining Structure + Flexibility
# ============================================================================
# Four models compared across three DGPs:
# BHM — group-varying intercept + slope (linear)
# BNN — one hidden layer, no group structure
# HBNN — group-varying intercepts + one hidden layer (the hybrid)
# ENS — 50/50 mixture of BHM + BNN posterior predictive draws
#
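# Three DGPs (each designed to favor one model; the CRPS results below test this):
# 1. Grouped linear → BHM expected to win
# 2. Complex nonlinear → BNN expected to win
# 3. Grouped + nonlinear → HBNN expected to win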
#
# Requirements: cmdstanr (with a working CmdStan installation), posterior,
# scoringRules, dplyr, tibble, ggplot2
# ============================================================================
library(cmdstanr)
library(posterior)
library(scoringRules)
library(dplyr)
library(tibble)
library(ggplot2)
# Fix R's RNG state so the simulated data (rnorm / runif calls in the DGP
# sections) and the random draw selection in make_ensemble() (sample) are
# repeatable across runs.
# NOTE(review): fit_model() passes no explicit `seed` to `$sample()`; confirm
# whether the Stan draws are reproducible from this call alone.
set.seed(42)
# ── Stan model code ──────────────────────────────────────────────────────────
# 1. Bayesian hierarchical model (BHM): group-varying intercept + slope.
#    Stan source kept as an R string; it is written to a file and compiled in
#    the "Compile Stan models" section.
#    - Likelihood: y[n] ~ normal(alpha[g[n]] + beta[g[n]] * x[n], sigma).
#    - alpha and beta are built in `transformed parameters` as
#      population mean + scale * raw, with the raw vectors ~ std_normal()
#      (a non-centered parameterization).
#    - Scale parameters are declared with lower = 0 and given normal(0, 1)
#      priors, i.e. half-normal priors.
#    - `generated quantities` draws one predictive value per test row into
#      `yrep_test`, which extract_yrep() reads back on the R side.
hier_stan_code <- "
data {
int<lower=1> N;
int<lower=1> J;
array[N] int<lower=1, upper=J> g;
vector[N] x;
vector[N] y;
int<lower=1> N_test;
array[N_test] int<lower=1, upper=J> g_test;
vector[N_test] x_test;
}
parameters {
real alpha_0;
real beta_0;
real<lower=0> sigma_alpha;
real<lower=0> sigma_beta;
vector[J] alpha_raw;
vector[J] beta_raw;
real<lower=0> sigma;
}
transformed parameters {
vector[J] alpha = alpha_0 + sigma_alpha * alpha_raw;
vector[J] beta = beta_0 + sigma_beta * beta_raw;
}
model {
alpha_0 ~ normal(0, 2);
beta_0 ~ normal(0, 2);
sigma_alpha ~ normal(0, 1);
sigma_beta ~ normal(0, 1);
alpha_raw ~ std_normal();
beta_raw ~ std_normal();
sigma ~ normal(0, 1);
vector[N] mu;
for (n in 1:N)
mu[n] = alpha[g[n]] + beta[g[n]] * x[n];
y ~ normal(mu, sigma);
}
generated quantities {
vector[N_test] yrep_test;
for (n in 1:N_test) {
real mu_n = alpha[g_test[n]] + beta[g_test[n]] * x_test[n];
yrep_test[n] = normal_rng(mu_n, sigma);
}
}
"
# 2. Bayesian neural network (BNN): 1 hidden layer, NO group structure.
#    Stan source kept as an R string.
#    - Input is the scalar x[n]; the hidden layer has H tanh units, with H
#      supplied through the data block.
#    - Mean function: mu[n] = b2 + dot_product(w2, tanh(w1 * x[n] + b1)).
#    - All weights and biases get independent normal(0, 1) priors; sigma is
#      constrained to be non-negative with a normal(0, 1) prior.
#    - There is no group index in the data block, so every observation shares
#      the same mean function.
#    - `generated quantities` draws one predictive value per test row into
#      `yrep_test`.
bnn_stan_code <- "
data {
int<lower=1> N;
vector[N] x;
vector[N] y;
int<lower=1> N_test;
vector[N_test] x_test;
int<lower=1> H;
}
parameters {
vector[H] w1;
vector[H] b1;
vector[H] w2;
real b2;
real<lower=0> sigma;
}
model {
w1 ~ normal(0, 1);
b1 ~ normal(0, 1);
w2 ~ normal(0, 1);
b2 ~ normal(0, 1);
sigma ~ normal(0, 1);
vector[N] mu;
for (n in 1:N) {
vector[H] h = tanh(w1 * x[n] + b1);
mu[n] = b2 + dot_product(w2, h);
}
y ~ normal(mu, sigma);
}
generated quantities {
vector[N_test] yrep_test;
for (n in 1:N_test) {
vector[H] h = tanh(w1 * x_test[n] + b1);
yrep_test[n] = normal_rng(b2 + dot_product(w2, h), sigma);
}
}
"
# 3. Hierarchical BNN (HBNN): group intercepts + hidden layer.
#    Stan source kept as an R string.
#    - Mean function:
#        mu[n] = alpha[g[n]] + b2 + dot_product(w2, tanh(w1 * x[n] + b1)),
#      i.e. the BNN mean plus a group-specific intercept. The network weights
#      are shared by all groups; only the intercept varies by group.
#    - alpha = alpha_0 + sigma_alpha * alpha_raw with alpha_raw ~ std_normal()
#      (a non-centered parameterization).
#    - Data block is the BHM data block plus the hidden width H.
#    - NOTE(review): alpha_0 and b2 both enter mu additively, so the
#      likelihood only informs their sum; they are separated by their priors
#      alone. Predictions depend on the sum, but sampling efficiency may
#      suffer. Consider dropping one of them.
hbnn_stan_code <- "
data {
int<lower=1> N;
int<lower=1> J;
array[N] int<lower=1, upper=J> g;
vector[N] x;
vector[N] y;
int<lower=1> N_test;
array[N_test] int<lower=1, upper=J> g_test;
vector[N_test] x_test;
int<lower=1> H;
}
parameters {
real alpha_0;
real<lower=0> sigma_alpha;
vector[J] alpha_raw;
vector[H] w1;
vector[H] b1;
vector[H] w2;
real b2;
real<lower=0> sigma;
}
transformed parameters {
vector[J] alpha = alpha_0 + sigma_alpha * alpha_raw;
}
model {
alpha_0 ~ normal(0, 2);
sigma_alpha ~ normal(0, 1);
alpha_raw ~ std_normal();
w1 ~ normal(0, 1);
b1 ~ normal(0, 1);
w2 ~ normal(0, 1);
b2 ~ normal(0, 1);
sigma ~ normal(0, 1);
vector[N] mu;
for (n in 1:N) {
vector[H] h = tanh(w1 * x[n] + b1);
mu[n] = alpha[g[n]] + b2 + dot_product(w2, h);
}
y ~ normal(mu, sigma);
}
generated quantities {
vector[N_test] yrep_test;
for (n in 1:N_test) {
vector[H] h = tanh(w1 * x_test[n] + b1);
real mu_n = alpha[g_test[n]] + b2 + dot_product(w2, h);
yrep_test[n] = normal_rng(mu_n, sigma);
}
}
"
# ── Compile Stan models ─────────────────────────────────────────────────────
# Each Stan program string is written out with write_stan_file() and compiled
# once here; run_dgp() reuses the three compiled model objects for every DGP.
cat("Compiling Stan models...\n")
mod_hier <- hier_stan_code %>%
  write_stan_file() %>%
  cmdstan_model()
mod_bnn <- bnn_stan_code %>%
  write_stan_file() %>%
  cmdstan_model()
mod_hbnn <- hbnn_stan_code %>%
  write_stan_file() %>%
  cmdstan_model()
# ── Helpers ──────────────────────────────────────────────────────────────────
# Pull the test-set predictive draws out of a fitted model.
#
# fit: object exposing `$draws(variables, format = "matrix")` (the fits
#      returned by fit_model()).
# Returns the columns named `yrep_test[...]`, transposed: rows index the
# elements of `yrep_test` (test observations), columns index the rows of the
# draws matrix. This is the orientation handed to crps_sample() in run_dgp().
extract_yrep <- function(fit) {
  draws <- fit$draws("yrep_test", format = "matrix")
  is_yrep <- grepl("^yrep_test\\[", colnames(draws))
  # drop = FALSE keeps a matrix even when only one column matches.
  yrep_cols <- draws[, is_yrep, drop = FALSE]
  t(yrep_cols)
}
# Sample from a compiled model with the script's fixed, deliberately small
# MCMC settings.
#
# mod:  object exposing a `$sample()` method (the compiled models above).
# data: named list passed as the `data` argument.
# ...:  forwarded unchanged to `$sample()` after the fixed settings.
# Returns whatever `mod$sample()` returns.
fit_model <- function(mod, data, ...) {
  n_chains <- 2
  n_iter <- 400
  mod$sample(
    data = data,
    chains = n_chains,
    parallel_chains = n_chains,
    iter_warmup = n_iter,
    iter_sampling = n_iter,
    # Silence progress updates and sampler messages.
    refresh = 0,
    show_messages = FALSE,
    ...
  )
}
# Build a mixture of two sets of predictive draws.
#
# yrep_a, yrep_b: matrices of identical shape (rows = observations,
#                 columns = draws), as returned by extract_yrep().
# For every column, a fair coin (R's RNG via sample()) decides whether the
# whole column comes from yrep_a or from yrep_b, so each ensemble draw is an
# intact draw from one of the two models.
# Returns a matrix shaped like yrep_a.
# Stops with an error if the inputs are not two-dimensional or differ in
# shape; without this check a one-row yrep_b would be silently recycled.
make_ensemble <- function(yrep_a, yrep_b) {
  if (length(dim(yrep_a)) != 2L || length(dim(yrep_b)) != 2L) {
    stop("`yrep_a` and `yrep_b` must both be matrices.", call. = FALSE)
  }
  if (!identical(dim(yrep_a), dim(yrep_b))) {
    stop(
      "`yrep_a` (", paste(dim(yrep_a), collapse = " x "), ") and `yrep_b` (",
      paste(dim(yrep_b), collapse = " x "), ") must have the same dimensions.",
      call. = FALSE
    )
  }
  n_draws <- ncol(yrep_a)
  take_a <- sample(c(TRUE, FALSE), n_draws, replace = TRUE)
  out <- yrep_a
  out[, !take_a] <- yrep_b[, !take_a]
  out
}
# Fit all four models to one data-generating process and score them.
#
# dgp_name: label printed in the console banner and stored in the result.
# train:    data frame with columns g (group index), x and y.
# test:     data frame with the same columns; y is used only for scoring.
# J:        number of groups, passed to the models that use group indices.
# H:        hidden-layer width for the BNN and HBNN (default 8, the value
#           that was previously hard-coded in both data lists).
#
# Relies on the compiled models `mod_hier`, `mod_bnn` and `mod_hbnn` and on
# the helpers fit_model(), extract_yrep() and make_ensemble() defined in this
# file. Prints a banner and one CRPS line per model as a side effect.
#
# Returns a 4-row tibble with columns DGP, Model and CRPS (mean test CRPS).
run_dgp <- function(dgp_name, train, test, J, H = 8) {
  stopifnot(is.numeric(H), length(H) == 1L, H >= 1)

  rule <- "══════════════════════════════════════════════════\n"
  cat("\n", rule, sep = "")
  cat(sprintf(" %s\n", dgp_name))
  cat(rule)

  hier_data <- list(
    N = nrow(train), J = J, g = train$g, x = train$x, y = train$y,
    N_test = nrow(test), g_test = test$g, x_test = test$x
  )
  bnn_data <- list(
    N = nrow(train), x = train$x, y = train$y,
    N_test = nrow(test), x_test = test$x, H = H
  )
  hbnn_data <- c(hier_data, list(H = H))

  # Fit order is kept fixed (BHM, BNN, HBNN) so a seeded run stays comparable.
  fit_h <- fit_model(mod_hier, hier_data)
  fit_b <- fit_model(mod_bnn, bnn_data)
  fit_hb <- fit_model(mod_hbnn, hbnn_data)

  yrep_h <- extract_yrep(fit_h)
  yrep_b <- extract_yrep(fit_b)
  yrep <- list(
    BHM = yrep_h,
    BNN = yrep_b,
    HBNN = extract_yrep(fit_hb),
    Ensemble = make_ensemble(yrep_h, yrep_b)
  )

  # Mean CRPS over test observations, one value per model.
  crps <- vapply(
    yrep,
    function(draws) mean(crps_sample(test$y, draws)),
    numeric(1)
  )

  cat(sprintf(" BHM CRPS: %.4f\n", crps[["BHM"]]))
  cat(sprintf(" BNN CRPS: %.4f\n", crps[["BNN"]]))
  cat(sprintf(" HBNN CRPS: %.4f\n", crps[["HBNN"]]))
  cat(sprintf(" ENS CRPS: %.4f\n", crps[["Ensemble"]]))

  tibble(
    DGP = dgp_name,
    Model = names(crps),
    CRPS = unname(crps)
  )
}
# ============================================================================
# DGP 1: Grouped linear — BHM advantage
# ============================================================================
# J groups with n_per rows each. Every group gets its own intercept and slope,
# x is uniform on [-2, 2], and the noise sd is 0.5. Within each group the
# first 20 rows are training data and the rest are test data.
J <- 10
n_per <- 30
N1 <- J * n_per
alpha_true <- rnorm(J, 0, 1.5)
beta_true <- rnorm(J, 2, 0.8)
d1 <- tibble(
  g = rep(seq_len(J), each = n_per),
  x = runif(N1, -2, 2)
) %>%
  mutate(
    mu = alpha_true[g] + beta_true[g] * x,
    y = rnorm(n(), mu, 0.5)
  ) %>%
  group_by(g) %>%
  mutate(is_train = row_number() <= 20) %>%
  ungroup()
res1 <- run_dgp(
  "Grouped linear",
  d1 %>% filter(is_train),
  d1 %>% filter(!is_train),
  J
)
# ============================================================================
# DGP 2: Complex nonlinear, no group structure — BNN advantage
# ============================================================================
# The mean depends on x only. Group labels are still assigned (cycling
# through 1..J, with J as set in the DGP 1 section) so that the grouped
# models can be fit, but they carry no signal. Within each label the first
# 70% of rows (rounded down) are training data.
N2 <- 300
d2 <- tibble(
  g = rep(seq_len(J), length.out = N2),
  x = runif(N2, -3, 3)
) %>%
  mutate(
    mu = 2 * sin(1.5 * x) + 0.5 * cos(3 * x),
    y = rnorm(n(), mu, 0.4)
  ) %>%
  group_by(g) %>%
  mutate(is_train = row_number() <= floor(0.7 * n())) %>%
  ungroup()
res2 <- run_dgp(
  "Complex nonlinear (no groups)",
  d2 %>% filter(is_train),
  d2 %>% filter(!is_train),
  J
)
# ============================================================================
# DGP 3: Grouped + complex nonlinear — HBNN advantage
# ============================================================================
# Same nonlinear curve as DGP 2 plus a group-specific intercept. Reuses J,
# n_per and N1 from the DGP 1 section; within each group the first 20 rows
# are training data.
alpha_true3 <- rnorm(J, 0, 1.5)
d3 <- tibble(
  g = rep(seq_len(J), each = n_per),
  x = runif(N1, -3, 3)
) %>%
  mutate(
    mu = alpha_true3[g] + 2 * sin(1.5 * x) + 0.5 * cos(3 * x),
    y = rnorm(n(), mu, 0.4)
  ) %>%
  group_by(g) %>%
  mutate(is_train = row_number() <= 20) %>%
  ungroup()
res3 <- run_dgp(
  "Grouped + complex nonlinear",
  d3 %>% filter(is_train),
  d3 %>% filter(!is_train),
  J
)
# ============================================================================
# Summary
# ============================================================================
# Stack the per-DGP score tables (3 DGPs x 4 models = 12 rows) and print
# them all.
cat(
  "\n══════════════════════════════════════════════════\n",
  " FULL SUMMARY\n",
  "══════════════════════════════════════════════════\n",
  sep = ""
)
results <- bind_rows(list(res1, res2, res3))
print(results, n = 12)
# ── Visualization ───────────────────────────────────────────────────────────
# One bar per model, one facet per DGP. The y scale is free per facet so each
# DGP's CRPS values are shown on their own axis range.
# Factor levels fix the bar order and the facet order.
results$Model <- factor(
  results$Model,
  levels = c("BHM", "BNN", "HBNN", "Ensemble")
)
results$DGP <- factor(
  results$DGP,
  levels = c(
    "Grouped linear",
    "Complex nonlinear (no groups)",
    "Grouped + complex nonlinear"
  )
)
p <- ggplot(results, aes(x = Model, y = CRPS, fill = Model)) +
  geom_col(alpha = 0.85) +
  facet_wrap(~ DGP, scales = "free_y") +
  scale_fill_manual(values = c(
    "BHM" = "#0f3460",
    "BNN" = "#e94560",
    "HBNN" = "#533483",
    "Ensemble" = "#16213e"
  )) +
  labs(
    title = "Out-of-Sample CRPS: BHM vs BNN vs HBNN vs Ensemble",
    subtitle = "Lower CRPS = better probabilistic predictions",
    y = "Mean Test CRPS", x = NULL
  ) +
  theme_minimal(base_size = 13) +
  theme(
    legend.position = "none",
    strip.text = element_text(face = "bold", size = 10)
  )
# The output directory is not created anywhere else in this script, so create
# it here; otherwise ggsave() fails after all the sampling has finished.
out_path <- file.path("analysis", "hbnn_comparison_crps.png")
dir.create(dirname(out_path), showWarnings = FALSE, recursive = TRUE)
ggsave(out_path, p, width = 11, height = 5, dpi = 150, bg = "white")
cat(sprintf("\nPlot saved to %s\n", out_path))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment