Created
March 25, 2026 16:01
-
-
Save statwonk/ead845b7811dd5347b80d0fb68c7757c to your computer and use it in GitHub Desktop.
Hierarchical BNN Comparison: BHM vs BNN vs HBNN vs Ensemble on CRPS (R reprex with cmdstanr)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # ============================================================================ | |
| # Hierarchical BNN: Combining Structure + Flexibility | |
| # ============================================================================ | |
| # Four models compared across three DGPs: | |
| # BHM — group-varying intercept + slope (linear) | |
| # BNN — one hidden layer, no group structure | |
| # HBNN — group-varying intercepts + one hidden layer (the hybrid) | |
| # ENS — 50/50 mixture of BHM + BNN posterior predictive draws | |
| # | |
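| # Three DGPs, each designed to favour one model: | |
| # 1. Grouped linear → BHM expected to win | |
| # 2. Complex nonlinear → BNN expected to win | |
| # 3. Grouped + nonlinear → HBNN expected to win | |
| # | |
| # Requirements: cmdstanr (with a working CmdStan installation), posterior, | |
| # scoringRules, dplyr, tibble, ggplot2 | |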
| # ============================================================================ | |
| library(cmdstanr) | |
| library(posterior) | |
| library(scoringRules) | |
| library(dplyr) | |
| library(tibble) | |
| library(ggplot2) | |
# Seed R's RNG once so the simulated datasets (rnorm/runif) and the random
# draw selection in the ensemble are reproducible from run to run.
set.seed(seed = 42)
| # ── Stan model code ────────────────────────────────────────────────────────── | |
| # 1. Bayesian hierarchical model: group-varying intercept + slope | |
| hier_stan_code <- " | |
| data { | |
| int<lower=1> N; | |
| int<lower=1> J; | |
| array[N] int<lower=1, upper=J> g; | |
| vector[N] x; | |
| vector[N] y; | |
| int<lower=1> N_test; | |
| array[N_test] int<lower=1, upper=J> g_test; | |
| vector[N_test] x_test; | |
| } | |
| parameters { | |
| real alpha_0; | |
| real beta_0; | |
| real<lower=0> sigma_alpha; | |
| real<lower=0> sigma_beta; | |
| vector[J] alpha_raw; | |
| vector[J] beta_raw; | |
| real<lower=0> sigma; | |
| } | |
| transformed parameters { | |
| vector[J] alpha = alpha_0 + sigma_alpha * alpha_raw; | |
| vector[J] beta = beta_0 + sigma_beta * beta_raw; | |
| } | |
| model { | |
| alpha_0 ~ normal(0, 2); | |
| beta_0 ~ normal(0, 2); | |
| sigma_alpha ~ normal(0, 1); | |
| sigma_beta ~ normal(0, 1); | |
| alpha_raw ~ std_normal(); | |
| beta_raw ~ std_normal(); | |
| sigma ~ normal(0, 1); | |
| vector[N] mu; | |
| for (n in 1:N) | |
| mu[n] = alpha[g[n]] + beta[g[n]] * x[n]; | |
| y ~ normal(mu, sigma); | |
| } | |
| generated quantities { | |
| vector[N_test] yrep_test; | |
| for (n in 1:N_test) { | |
| real mu_n = alpha[g_test[n]] + beta[g_test[n]] * x_test[n]; | |
| yrep_test[n] = normal_rng(mu_n, sigma); | |
| } | |
| } | |
| " | |
# 2. Bayesian neural network: 1 hidden layer, NO group structure
#
# mu = b2 + tanh(x * w1' + b1') * w2: H tanh units on the scalar input x,
# combined linearly, with standard-normal priors on all weights/biases.
# The hidden activations for all observations are computed at once as an
# N x H matrix (x * w1' is the outer product; rep_matrix(b1', N) repeats the
# bias row N times). In generated quantities the test-set activation matrix
# lives in a local block so it is not written to the output CSV.
bnn_stan_code <- "
data {
  int<lower=1> N;
  vector[N] x;
  vector[N] y;
  int<lower=1> N_test;
  vector[N_test] x_test;
  int<lower=1> H;
}
parameters {
  vector[H] w1;
  vector[H] b1;
  vector[H] w2;
  real b2;
  real<lower=0> sigma;
}
model {
  w1 ~ normal(0, 1);
  b1 ~ normal(0, 1);
  w2 ~ normal(0, 1);
  b2 ~ normal(0, 1);
  sigma ~ normal(0, 1);
  matrix[N, H] hidden = tanh(x * w1' + rep_matrix(b1', N));
  y ~ normal(b2 + hidden * w2, sigma);
}
generated quantities {
  vector[N_test] yrep_test;
  {
    matrix[N_test, H] hidden_test = tanh(x_test * w1' + rep_matrix(b1', N_test));
    yrep_test = to_vector(normal_rng(b2 + hidden_test * w2, sigma));
  }
}
"
# 3. Hierarchical BNN: group intercepts + hidden layer
#
# Same one-hidden-layer network as the BNN, plus a non-centred group-varying
# intercept alpha[g] = alpha_0 + sigma_alpha * alpha_raw[g]. Group effects enter
# only through the intercept; the network (the shape of x -> y) is shared by
# all groups. Hidden activations are computed as an N x H matrix in one
# vectorised expression; the test-set matrix is kept in a local block so it is
# not written to the output CSV.
# NOTE(review): alpha_0 and b2 are both global intercepts, so the likelihood
# only identifies their sum (the priors keep the posterior proper). If
# sampling is slow, consider dropping one of them.
hbnn_stan_code <- "
data {
  int<lower=1> N;
  int<lower=1> J;
  array[N] int<lower=1, upper=J> g;
  vector[N] x;
  vector[N] y;
  int<lower=1> N_test;
  array[N_test] int<lower=1, upper=J> g_test;
  vector[N_test] x_test;
  int<lower=1> H;
}
parameters {
  real alpha_0;
  real<lower=0> sigma_alpha;
  vector[J] alpha_raw;
  vector[H] w1;
  vector[H] b1;
  vector[H] w2;
  real b2;
  real<lower=0> sigma;
}
transformed parameters {
  vector[J] alpha = alpha_0 + sigma_alpha * alpha_raw;
}
model {
  alpha_0 ~ normal(0, 2);
  sigma_alpha ~ normal(0, 1);
  alpha_raw ~ std_normal();
  w1 ~ normal(0, 1);
  b1 ~ normal(0, 1);
  w2 ~ normal(0, 1);
  b2 ~ normal(0, 1);
  sigma ~ normal(0, 1);
  matrix[N, H] hidden = tanh(x * w1' + rep_matrix(b1', N));
  y ~ normal(alpha[g] + b2 + hidden * w2, sigma);
}
generated quantities {
  vector[N_test] yrep_test;
  {
    matrix[N_test, H] hidden_test = tanh(x_test * w1' + rep_matrix(b1', N_test));
    yrep_test = to_vector(normal_rng(alpha[g_test] + b2 + hidden_test * w2, sigma));
  }
}
"
# ── Compile Stan models ─────────────────────────────────────────────────────
# Each Stan program string is written to a .stan file with write_stan_file()
# and then compiled into a CmdStan model object.
cat("Compiling Stan models...\n")
mod_hier <- hier_stan_code %>% write_stan_file() %>% cmdstan_model()
mod_bnn <- bnn_stan_code %>% write_stan_file() %>% cmdstan_model()
mod_hbnn <- hbnn_stan_code %>% write_stan_file() %>% cmdstan_model()
# ── Helpers ──────────────────────────────────────────────────────────────────
# Pull the test-set posterior predictive draws out of a fitted model.
#
# Args:
#   fit: a fitted model whose $draws("yrep_test", format = "matrix") returns a
#     draws-by-variables matrix with columns named "yrep_test[1]", ...
# Returns: a plain numeric matrix with one row per test observation and one
#   column per posterior draw (the layout scoringRules::crps_sample() expects
#   for its `dat` argument).
# Errors if the fit contains no "yrep_test[...]" columns.
extract_yrep <- function(fit) {
  draws <- fit$draws("yrep_test", format = "matrix")
  keep <- grepl("^yrep_test\\[", colnames(draws))
  if (!any(keep)) {
    stop("fit has no 'yrep_test[...]' draws", call. = FALSE)
  }
  # t() copies every non-dim attribute (including the draws_matrix class)
  # onto its result, which would label a variables-by-draws matrix as draws.
  # unclass() followed by base `[` leaves only dim/dimnames before transposing.
  t(unclass(draws)[, keep, drop = FALSE])
}
# Draw posterior samples from a compiled Stan model via its $sample() method.
#
# Args:
#   mod: compiled model object (e.g. from cmdstan_model()).
#   data: named list passed as `data` to $sample().
#   ...: further arguments forwarded to $sample() (e.g. seed, adapt_delta).
#   chains, parallel_chains, iter_warmup, iter_sampling, refresh,
#   show_messages: sampler settings. They are formal arguments, not fixed
#     values inside the call, so a caller can override any of them by name
#     without a "matched by multiple actual arguments" error. The defaults
#     are 2 chains run in parallel, 400 warmup + 400 sampling iterations, and
#     no progress output.
# Returns: the fit object returned by mod$sample().
fit_model <- function(mod, data, ...,
                      chains = 2, parallel_chains = chains,
                      iter_warmup = 400, iter_sampling = 400,
                      refresh = 0, show_messages = FALSE) {
  mod$sample(
    data = data,
    chains = chains, parallel_chains = parallel_chains,
    iter_warmup = iter_warmup, iter_sampling = iter_sampling,
    refresh = refresh, show_messages = show_messages,
    ...
  )
}
# Build a mixture of two posterior predictive samples by taking each draw
# (column) from either yrep_a or yrep_b.
#
# Args:
#   yrep_a, yrep_b: matrices of identical dimensions, one row per test
#     observation and one column per posterior draw.
#   weight_a: share of draw columns taken from yrep_a (default 0.5, i.e. an
#     equal-weight mixture).
# Returns: a matrix shaped like yrep_a in which exactly
#   round(weight_a * ncol(yrep_a)) randomly chosen columns come from yrep_a and
#   the remaining columns come from yrep_b. Choosing an exact count, rather
#   than flipping a coin per column, makes the mixture weight exact instead of
#   only right on average.
make_ensemble <- function(yrep_a, yrep_b, weight_a = 0.5) {
  stopifnot(
    identical(dim(yrep_a), dim(yrep_b)),
    is.numeric(weight_a), length(weight_a) == 1L,
    weight_a >= 0, weight_a <= 1
  )
  n_draws <- ncol(yrep_a)
  n_from_a <- round(weight_a * n_draws)
  from_a <- seq_len(n_draws) %in% sample.int(n_draws, n_from_a)
  out <- yrep_a
  out[, !from_a] <- yrep_b[, !from_a]
  out
}
# Fit BHM, BNN and HBNN to one train/test split, form the BHM + BNN ensemble,
# print each model's mean out-of-sample CRPS, and return the scores.
#
# Args:
#   dgp_name: label for the printed banner and the DGP column of the result.
#   train, test: data frames with columns g (group index in 1..J), x and y.
#   J: number of groups.
#   H: hidden-layer width passed to the BNN and HBNN (default 8).
# Returns: a tibble with columns DGP, Model ("BHM", "BNN", "HBNN",
#   "Ensemble") and CRPS (mean test-set CRPS; lower is better).
# Uses the compiled models mod_hier, mod_bnn and mod_hbnn from the enclosing
# environment.
run_dgp <- function(dgp_name, train, test, J, H = 8) {
  banner <- "══════════════════════════════════════════════════"
  cat("\n", banner, "\n", sep = "")
  cat(sprintf(" %s\n", dgp_name))
  cat(banner, "\n", sep = "")

  hier_data <- list(
    N = nrow(train), J = J, g = train$g, x = train$x, y = train$y,
    N_test = nrow(test), g_test = test$g, x_test = test$x
  )
  bnn_data <- list(
    N = nrow(train), x = train$x, y = train$y,
    N_test = nrow(test), x_test = test$x, H = H
  )
  hbnn_data <- c(hier_data, list(H = H))

  # Test-set predictive draws per model (rows = test obs, cols = draws).
  yrep <- list(
    BHM = extract_yrep(fit_model(mod_hier, hier_data)),
    BNN = extract_yrep(fit_model(mod_bnn, bnn_data)),
    HBNN = extract_yrep(fit_model(mod_hbnn, hbnn_data))
  )
  yrep$Ensemble <- make_ensemble(yrep$BHM, yrep$BNN)

  crps <- vapply(
    yrep,
    function(draws) mean(crps_sample(test$y, draws)),
    numeric(1)
  )

  # Short labels for the console report ("ENS" for the ensemble).
  print_labels <- c(BHM = "BHM", BNN = "BNN", HBNN = "HBNN", Ensemble = "ENS")
  for (model in names(crps)) {
    cat(sprintf(" %s CRPS: %.4f\n", print_labels[[model]], crps[[model]]))
  }

  tibble(DGP = dgp_name, Model = names(crps), CRPS = unname(crps))
}
# ============================================================================
# DGP 1: Grouped linear — BHM advantage
# ============================================================================
# J groups of n_per points, each with its own intercept and slope, noise sd
# 0.5. Within each group the first 20 rows train the models and the remaining
# rows are held out; since x is drawn at random, this is a random split.
J <- 10
n_per <- 30
N1 <- J * n_per

alpha_true <- rnorm(J, mean = 0, sd = 1.5)
beta_true <- rnorm(J, mean = 2, sd = 0.8)

d1 <- tibble(
  g = rep(seq_len(J), each = n_per),
  x = runif(N1, min = -2, max = 2)
) %>%
  mutate(
    mu = alpha_true[g] + beta_true[g] * x,
    y = rnorm(n(), mean = mu, sd = 0.5)
  ) %>%
  group_by(g) %>%
  mutate(is_train = row_number() <= 20) %>%
  ungroup()

res1 <- run_dgp(
  "Grouped linear",
  train = filter(d1, is_train),
  test = filter(d1, !is_train),
  J = J
)
# ============================================================================
# DGP 2: Complex nonlinear, no group structure — BNN advantage
# ============================================================================
# Group labels are assigned round-robin but do not enter mu, so the
# group-aware models have nothing to exploit. floor(70%) of each group's rows
# are used for training, and the rest are held out.
N2 <- 300

d2 <- tibble(
  g = rep(seq_len(J), length.out = N2),
  x = runif(N2, min = -3, max = 3)
) %>%
  mutate(
    mu = 2 * sin(1.5 * x) + 0.5 * cos(3 * x),
    y = rnorm(n(), mean = mu, sd = 0.4)
  ) %>%
  group_by(g) %>%
  mutate(is_train = row_number() <= floor(0.7 * n())) %>%
  ungroup()

res2 <- run_dgp(
  "Complex nonlinear (no groups)",
  train = filter(d2, is_train),
  test = filter(d2, !is_train),
  J = J
)
# ============================================================================
# DGP 3: Grouped + complex nonlinear — HBNN advantage
# ============================================================================
# Same nonlinear curve as DGP 2, shifted by a per-group intercept; uses the
# J x n_per layout and the 20-row training split of DGP 1.
alpha_true3 <- rnorm(J, mean = 0, sd = 1.5)

d3 <- tibble(
  g = rep(seq_len(J), each = n_per),
  x = runif(N1, min = -3, max = 3)
) %>%
  mutate(
    mu = alpha_true3[g] + 2 * sin(1.5 * x) + 0.5 * cos(3 * x),
    y = rnorm(n(), mean = mu, sd = 0.4)
  ) %>%
  group_by(g) %>%
  mutate(is_train = row_number() <= 20) %>%
  ungroup()

res3 <- run_dgp(
  "Grouped + complex nonlinear",
  train = filter(d3, is_train),
  test = filter(d3, !is_train),
  J = J
)
# ============================================================================
# Summary
# ============================================================================
cat("\n══════════════════════════════════════════════════\n")
cat(" FULL SUMMARY\n")
cat("══════════════════════════════════════════════════\n")
results <- bind_rows(res1, res2, res3)
# Print every row rather than a hard-coded 12 (= 3 DGPs x 4 models), so the
# table stays complete if DGPs or models are added (tibbles otherwise
# truncate long output).
print(results, n = nrow(results))
# ── Visualization ───────────────────────────────────────────────────────────
# Fix the factor level order so bars and facets appear in the order the
# models and DGPs were introduced, not alphabetically.
results$Model <- factor(results$Model,
                        levels = c("BHM", "BNN", "HBNN", "Ensemble"))
results$DGP <- factor(results$DGP,
                      levels = c("Grouped linear",
                                 "Complex nonlinear (no groups)",
                                 "Grouped + complex nonlinear"))

p <- ggplot(results, aes(x = Model, y = CRPS, fill = Model)) +
  geom_col(alpha = 0.85) +
  facet_wrap(~ DGP, scales = "free_y") +
  scale_fill_manual(values = c(
    "BHM" = "#0f3460",
    "BNN" = "#e94560",
    "HBNN" = "#533483",
    "Ensemble" = "#16213e"
  )) +
  labs(
    title = "Out-of-Sample CRPS: BHM vs BNN vs HBNN vs Ensemble",
    subtitle = "Lower CRPS = better probabilistic predictions",
    y = "Mean Test CRPS", x = NULL
  ) +
  theme_minimal(base_size = 13) +
  theme(legend.position = "none",
        strip.text = element_text(face = "bold", size = 10))

plot_path <- file.path("analysis", "hbnn_comparison_crps.png")
# ggsave() does not create a missing output directory by default (it errors
# in non-interactive runs), so make sure analysis/ exists first.
dir.create(dirname(plot_path), showWarnings = FALSE, recursive = TRUE)
ggsave(plot_path, p,
       width = 11, height = 5, dpi = 150, bg = "white")
cat(sprintf("\nPlot saved to %s\n", plot_path))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment