Skip to content

Instantly share code, notes, and snippets.

@martinmodrak
Created May 27, 2021 12:53
Show Gist options
  • Save martinmodrak/7ec873aa20aa6b08f4d828572241e10e to your computer and use it in GitHub Desktop.
Soft sum-to-zero constraint breaks ADVI
library(cmdstanr)
# Stan program: hierarchical normal model with varying group means in the
# non-centred form group_mus = z * tau, z ~ normal(0, 1). The group means are
# pulled towards summing to zero by a tight prior on their sum (a *soft*
# sum-to-zero constraint) rather than by an exact, hard constraint.
# The rest of the script fits this model with both NUTS and ADVI.
#
# Uses the `array[N] int` declaration syntax (Stan >= 2.26); the older
# `int groups[N]` form was removed in Stan 2.33 and no longer compiles.
model_code <- "
data {
  int<lower=0> N;                        // number of observations
  vector[N] y;                           // observed outcomes
  int<lower=1> K;                        // number of groups (> 0: K scales a prior sd)
  array[N] int<lower=1, upper=K> groups; // group index of each observation
}
parameters {
  vector[K] z;         // standardised group effects (non-centred)
  real<lower=0> tau;   // between-group scale
  real<lower=0> sigma; // observation noise scale
}
transformed parameters {
  vector[K] group_mus = z * tau;
}
model {
  y ~ normal(group_mus[groups], sigma);
  sigma ~ normal(0, 1);
  tau ~ normal(0, 1);
  z ~ normal(0, 1);
  // Soft sum-to-zero constraint: the sum of the group means gets a tight
  // prior around 0 whose sd grows linearly with K, instead of being fixed
  // to exactly 0.
  sum(group_mus) ~ normal(0, 0.01 * K);
}
"
# Write the Stan program above to a .stan file and compile it with CmdStan;
# `m` is the compiled model object that both fits below are run from.
m <- cmdstan_model(stan_file = write_stan_file(code = model_code))
# Simulate one data set from the model's generative process ----
# A fixed seed makes the simulated data (and hence the fits below)
# reproducible.
set.seed(8963275)
N <- 40
K <- 10
# Balanced design: observations cycle through groups 1, ..., K.
groups <- rep(seq_len(K), length.out = N)
# NOTE: `mu` is never used below. The draw is kept only so that every later
# random draw -- and therefore the simulated data for this seed -- stays the
# same as in the original run. Removing it changes the data set.
mu <- sort(rnorm(2, 0, 5))
sigma <- abs(rnorm(1, 0, 1)) # half-normal(0, 1), as in the Stan prior
tau <- abs(rnorm(1, 0, 1))   # half-normal(0, 1), as in the Stan prior
group_mus <- rnorm(K, sd = tau)
# Centre the true group means so they satisfy the sum-to-zero constraint
# exactly.
group_mus <- group_mus - mean(group_mus)
cat(
  "True group_mus: ", paste0(group_mus, collapse = "; "),
  ", tau: ", tau, ", sigma: ", sigma, "\n"
)
y <- rnorm(N, mean = group_mus[groups], sd = sigma)
dd <- list(N = N, y = y, K = K, groups = groups)
# Fit the same model and data with NUTS and with ADVI, then show the 5% / 95%
# posterior quantiles of every variable from each fit for comparison.
# `adapt_delta = 0.95` raises NUTS's target acceptance rate above CmdStan's
# default, trading speed for fewer divergent transitions.
res_sampling <- m$sample(
  data = dd, refresh = 0, parallel_chains = 4, adapt_delta = 0.95
)
res_advi <- m$variational(data = dd)
# Explicit print(): top-level values are not auto-printed when the script is
# run with source() (without echo = TRUE), which would hide the comparison.
print(res_sampling$summary()[, c("variable", "q5", "q95", "ess_bulk")])
print(res_advi$summary()[, c("variable", "q5", "q95")])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment