Skip to content

Instantly share code, notes, and snippets.

@StaffanBetner
Last active March 8, 2024 20:56
Show Gist options
  • Save StaffanBetner/7c9fef5ba146db9dd26393a409bdadb8 to your computer and use it in GitHub Desktop.
Save StaffanBetner/7c9fef5ba146db9dd26393a409bdadb8 to your computer and use it in GitHub Desktop.
Sample Bootstrap Weights within Stan (once for every iteration)
#ifndef DIRICHLET_RNG_WRAPPER_HPP
#define DIRICHLET_RNG_WRAPPER_HPP
#include <stan/math.hpp>
#include <boost/random/mersenne_twister.hpp>
#include <chrono>
#include <Eigen/Dense>
#include <iostream>
// Iteration counter shared by the helpers in this header; starts at zero.
static int itct = 0;

// Advance the counter by one. The stream argument is accepted but not used.
inline void add_iter(std::ostream* pstream__) {
  (void)pstream__;
  ++itct;
}

// Report the counter's current value. The stream argument is not used.
inline int get_iter(std::ostream* pstream__) {
  (void)pstream__;
  const int current = itct;
  return current;
}
// Return a Dirichlet(alpha) draw that is held fixed until the iteration
// counter changes.
//
// The draw is cached in function-local statics and regenerated only when the
// file-level counter `itct` differs from the value recorded at the previous
// refresh. This function never modifies `itct`; advancing it is the caller's
// job (see add_iter()).
//
// alpha      Dirichlet concentration parameters.
// pstream__  Output stream supplied by the caller; not used here.
// Returns    The cached draw.
//
// Declared `inline`, like add_iter()/get_iter(), so that the header can be
// included from more than one translation unit without duplicate definitions.
inline Eigen::VectorXd dirichlet_rng_wrapper(const Eigen::VectorXd& alpha, std::ostream* pstream__) {
  (void)pstream__;
  // One generator, seeded from the wall clock on first use and then kept
  // alive. Re-creating and re-seeding it on every refresh would make two
  // refreshes that fall into the same clock tick return identical weights.
  static boost::random::mt19937 rng(static_cast<unsigned>(
      std::chrono::system_clock::now().time_since_epoch().count()));
  static Eigen::VectorXd last_draw;
  static int last_itct = -1;  // differs from the initial itct (0), so the first call draws
  if (itct != last_itct) {
    // The counter moved on: draw fresh weights and remember for which count.
    last_draw = stan::math::dirichlet_rng(alpha, rng);
    last_itct = itct;
  }
  return last_draw;
}
#endif // DIRICHLET_RNG_WRAPPER_HPP
// generated with brms 2.20.6
// Manually modified: the fixed observation weights of the generated program
// are replaced by Dirichlet weights obtained from external C++ functions.
functions {
  // External functions (declared only; definitions come from a C++ header).
  void add_iter(); // ~*~THIS IS NEW~*~
  int get_iter(); // ~*~THIS IS NEW~*~
  vector dirichlet_rng_wrapper(vector alpha); // ~*~THIS IS NEW~*~
}
data {
  int<lower=1> N;  // total number of observations
  vector[N] Y;  // response variable
  array[N] int<lower=-1,upper=2> cens;  // indicates censoring
  // Read by the model block below; without this declaration the program
  // does not compile.
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  vector[N] alpha = rep_vector(1.0, N); // Dirichlet parameters, all ones for uniform distribution ~*~THIS IS NEW~*~
}
parameters {
  real Intercept;  // temporary intercept for centered predictors
  real<lower=0> shape;  // shape parameter
}
transformed parameters {
  real lprior = 0;  // prior contributions to the log posterior
  lprior += student_t_lpdf(Intercept | 3, 6.3, 2.5);
  lprior += gamma_lpdf(shape | 0.01, 0.01);
}
model {
  // likelihood including constants
  if (!prior_only) {
    // Dirichlet weights rescaled by N so that they sum to N.
    vector[N] weights = dirichlet_rng_wrapper(alpha) * N; // ~*~THIS IS NEW~*~
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    mu += Intercept;
    mu = exp(mu);
    // Weibull scale corresponding to mean mu; tgamma(1 + 1 / shape) does not
    // depend on the observation, so it is evaluated once outside the loop.
    vector[N] sigma = mu / tgamma(1 + 1 / shape);
    for (n in 1:N) {
      // special treatment of censored data
      if (cens[n] == 0) {
        target += weights[n] * weibull_lpdf(Y[n] | shape, sigma[n]);
      } else if (cens[n] == 1) {
        target += weights[n] * weibull_lccdf(Y[n] | shape, sigma[n]);
      } else if (cens[n] == -1) {
        target += weights[n] * weibull_lcdf(Y[n] | shape, sigma[n]);
      }
    }
  }
  // priors including constants
  target += lprior;
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept;
  add_iter(); // update the counter each iteration -- ~*~THIS IS NEW~*~
}
---
title: "Bootstrapping within Stan"
output: html_notebook
---
```{r}
# Function available here: https://gist.github.com/StaffanBetner/d632bd70686aebd8e488bdc1a454c8e2
source("https://gist.githubusercontent.com/StaffanBetner/d632bd70686aebd8e488bdc1a454c8e2/raw/")
pkg_load(tidyverse, rio, magrittr, janitor, qs, fwb, posterior, ggdist, brms, cmdstanr, rstan, here)
data("bearingcage", package = "fwb")
```
```{r}
# Show the bearing cage data loaded in the setup chunk.
bearingcage
```
Let's start with a classic example:
```{r}
# Intercept-only Weibull regression on the (right-censored) failure times.
fit <- survival::survreg(
  formula = survival::Surv(hours, failure) ~ 1,
  data = bearingcage,
  dist = "weibull"
)
fit
```
```{r}
# Detailed summary of the survreg fit from the previous chunk.
summary(fit)
```
```{r}
# Normal approximation to the MLE: 8000 multivariate-normal draws centred on
# the first column of the coefficient table, with the fit's covariance matrix.
# `empirical = TRUE` is spelled out because `T` is an ordinary variable that
# can be reassigned.
mle_normal_approx <- MASS::mvrnorm(
  n = 8000,
  mu = summary(fit)$table[, 1],
  Sigma = vcov(fit),
  empirical = TRUE
) %>%
  rvar()

# Second component is survreg's Log(scale); convert it to beta = 1 / scale.
mle_normal_approx[2] <- 1 / exp(mle_normal_approx[2])

# One-row tibble: a `method` label followed by the two rvar columns.
mle_normal_approx <- mle_normal_approx %>%
  setNames(c("log(eta)", "beta")) %>%
  t() %>%
  as_tibble() %>%
  mutate(method = "mle", .before = `log(eta)`)
```
```{r}
# Show the one-row summary built in the previous chunk.
mle_normal_approx
```
```{r}
# Statistic for the weighted bootstrap: fit an intercept-only Weibull model
# to `data` using case weights `w`, and return the intercept as "log(eta)"
# together with beta = 1 / scale.
weibull_est <- function(data, w) {
  weighted_fit <- survival::survreg(
    survival::Surv(hours, failure) ~ 1,
    data = data,
    weights = w,
    dist = "weibull"
  )
  log_eta <- unname(coef(weighted_fit))
  shape_beta <- 1 / weighted_fit$scale
  c("log(eta)" = log_eta, beta = shape_beta)
}
```
```{r}
# Fractional weighted bootstrap of the Weibull estimates (8000 replicates).
# NOTE(review): qcache() comes from the sourced helper gist; presumably it
# caches the piped result under the given name -- confirm before changing
# the expression.
fwb_est <- fwb(bearingcage, statistic = weibull_est,
R = 8000, verbose = TRUE) %>% qcache("fwb_est")
```
```{r}
# Turn the bootstrap replicates into a one-row rvar tibble labelled
# "bootstrap" and stack it on top of the MLE row.
estimates <- fwb_est$t %>%
  rvar() %>%
  t() %>%
  as_tibble() %>%
  mutate(method = "bootstrap", .before = `log(eta)`) %>%
  bind_rows(mle_normal_approx)
estimates
```
```{r}
# One slab per method, one panel per parameter.
ggplot(
  pivot_longer(estimates, cols = -1),
  aes(xdist = value, y = method, group = method)
) +
  facet_wrap(~name) +
  stat_slab()
```
Let's add the Bayesian case(s)
```{r}
# Bayesian Weibull fit with brms through the cmdstanr backend; the fitted
# object is stored under the file name "fit_brm".
fit_brm <- brm(
  formula = hours | cens(!failure) ~ 1,
  family = weibull,
  data = bearingcage,
  backend = "cmdstanr",
  cores = 4,
  iter = 3000,
  warmup = 1000,
  file = "fit_brm",
  control = list(adapt_delta = .95)
)
```
```{r}
# Posterior draws of the brms fit as rvars.
fit_brm_rvars <- as_draws_rvars(fit_brm)
```
```{r}
# Append the plain HMC posterior as a third method.
estimates <- add_row(
  estimates,
  method = "hmc",
  `log(eta)` = fit_brm_rvars$b_Intercept,
  beta = fit_brm_rvars$shape
)
estimates
```
```{r}
# Same comparison plot as before, now including the "hmc" row.
ggplot(
  pivot_longer(estimates, cols = -1),
  aes(xdist = value, y = method, group = method)
) +
  facet_wrap(~name) +
  stat_slab()
```
# HMC bootstrapping (here be dragons)
```{r}
# Generate the weighted Weibull Stan program and write it to
# "original_stan_code.stan".
make_stancode(
  formula = hours | cens(!failure) + weights(1) ~ 1,
  family = weibull,
  data = bearingcage,
  save_model = "original_stan_code.stan"
)
```
```{r}
# Compile the modified Stan program together with the external C++ header.
modified_model <- cmdstan_model(
  "modified_stan_code.stan",
  user_header = here("iterfuns.hpp")
)
```
```{r}
# RStan code:
# stan_model(stanc_ret = stanc("modified_stan_code.stan",
# allow_undefined = TRUE),
# includes = paste0('\n#include "', here('iterfuns.hpp'), '"\n')) ->
# outcome_model
```
<!-- When supplying weights externally: cmdstanr seems to reset memory when it goes from warmup to sampling, so only max(iter_sampling, iter_warmup) bootstrap draws are needed. rstan shares memory over all chains and between warmup and sampling, and also uses one iteration for initial something -->
```{r}
# Stan data for the bootstrap model: the data brms would build for the
# weighted model, minus the fixed weights (the modified program draws its
# own). The element is dropped by name rather than by position, so the result
# does not depend on the order in which brms assembles the list.
standata_bootstrap <- local({
  full_standata <- make_standata(
    formula = hours | cens(!failure) + weights(1) ~ 1,
    family = weibull,
    data = bearingcage
  )
  full_standata[names(full_standata) != "weights"]
})
```
```{r}
# Sample from the modified model: 4 chains, at most 3 running in parallel.
outcome_samples_cmdstan <- modified_model$sample(
  data = standata_bootstrap,
  chains = 4,
  iter_warmup = 1000,
  iter_sampling = 2000,
  adapt_delta = 0.9995,
  refresh = 50L,
  parallel_chains = 3
) %>%
  qcache("outcome_samples_cmdstan")
```
NO DIVERGENCES!! HOLY MACARONI!
```{r}
# Read the CmdStan CSV output back in as an rstan stanfit object.
rstan_fit <- outcome_samples_cmdstan$output_files() %>%
  rstan::read_stan_csv()
# NOTE(review): the attribute is called CmdStanModel but receives the fit
# object, not `modified_model` -- confirm this is what brms expects.
attributes(rstan_fit)$CmdStanModel <- outcome_samples_cmdstan

# Empty brmsfit skeleton for the same model specification.
modified_model_brms <- brm(
  formula = hours | cens(!failure) + weights(1) ~ 1,
  family = weibull,
  data = bearingcage,
  empty = TRUE
)

# Plug the external fit into the skeleton and apply brms' parameter renaming.
modified_model_brms$fit <- rstan_fit
modified_model_brms <- rename_pars(modified_model_brms)
modified_model_brms <- modified_model_brms %>% qcache("modified_model_brms")
```
```{r}
# Draws from the bootstrap-weighted model as rvars.
hmc_bootstrap_rvars <- as_draws_rvars(outcome_samples_cmdstan$draws())
```
```{r}
# Append the bootstrap-weighted HMC result as a fourth method.
estimates <- add_row(
  estimates,
  method = "hmc bootstrap",
  `log(eta)` = hmc_bootstrap_rvars$b_Intercept,
  beta = hmc_bootstrap_rvars$shape
)
estimates
```
```{r}
# Final comparison: reorder the rows, fix the method order on the y axis
# (reversed so the first row is drawn at the top), and plot with free x scales.
estimates %>%
  slice(2, 1, 3, 4) %>%
  mutate(method = fct_rev(factor(method, levels = method))) %>%
  pivot_longer(cols = -1) %>%
  ggplot(aes(xdist = value, y = method, group = method)) +
  facet_wrap(~name, scales = "free_x") +
  stat_slab() +
  theme_ggdist() +
  labs(
    x = NULL,
    title = "Bearing Cage data (Weibull model)",
    y = "Estimation Method"
  )
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment