Skip to content

Instantly share code, notes, and snippets.

@bnicenboim
Last active October 28, 2020 14:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bnicenboim/5796795acf164f0a0f163401e5975bfd to your computer and use it in GitHub Desktop.
Save bnicenboim/5796795acf164f0a0f163401e5975bfd to your computer and use it in GitHub Desktop.
SBC (simulation-based calibration) for cmdstanr
library(dplyr)
#' Simulation-based calibration (SBC) runs for a cmdstanr model
#'
#' Fits the model `M` times (one chain per fit, each with its own seed) and
#' collects the generated quantities `y_`, `pars_` and `ranks_` from every
#' fit, plus `log_lik` when the Stan program declares it.
#' NOTE(review): the variable conventions look like those of `rstan::sbc()`;
#' confirm against the Stan programs this is used with.
#'
#' @param cmdstanrmodel Model object providing `$code()` (the Stan program)
#'   and `$sample()`.
#' @param data Data passed as the first argument of `$sample()`.
#' @param M Number of fits to run. Seeds are spread evenly between 1 and
#'   `.Machine$integer.max`.
#' @param ... Further arguments forwarded to `cmdstanrmodel$sample()`.
#'
#' @return An object of class `"sbc"`: a list with
#'   * `ranks`: list of `M` matrices (draws x parameters) taken from `ranks_`,
#'   * `Y`: per-fit column means of `y_`,
#'   * `pars`: per-fit column means of `pars_`,
#'   * `sampler_params`: per-fit sampler diagnostics combined with
#'     `simplify2array()`,
#'   * `pareto_k`: Pareto k diagnostics from `loo::loo()`, or `NULL` when the
#'     Stan program has no `log_lik`.
#'   Returns `NULL` (with a warning) when a parameter stored in `pars_` does
#'   not carry the expected underscore postfix.
sbc_cmd <- function(cmdstanrmodel, data, M, ...) {
  stan_code <- cmdstanrmodel$code()
  # Normalise the program to one element per non-empty line.
  stan_code <- scan(
    what = character(), sep = "\n", quiet = TRUE,
    text = stan_code
  )

  # Recover the parameter names from the lines that mention `pars_`.
  pars_lines <- grep(
    "[[:space:]]*(pars_)|(pars_\\[.*\\])[[:space:]]*=",
    stan_code,
    value = TRUE
  )
  if (length(pars_lines) == 1) {
    # Declared and assigned on a single line: the names are the
    # comma-separated entries inside the braces/brackets after `=`.
    pars_lines <- gsub(".*?=.*?(\\{|\\[)(.*?)(\\]|\\}).*", "\\2", pars_lines)
    pars_names <- trimws(strsplit(pars_lines, split = ",")[[1]])
    pars_names <- unique(sub("^([a-z,A-Z,0-9,_]*)_.*", "\\1", pars_names))
  } else {
    # One assignment per line: drop the declaration lines and keep the
    # right-hand side of each remaining `=`.
    is_declaration <- grepl("^[[:space:]]*vector", pars_lines) |
      grepl("^[[:space:]]*real", pars_lines)
    pars_lines <- pars_lines[!is_declaration]
    rhs <- vapply(
      strsplit(pars_lines, split = "=", fixed = TRUE),
      function(parts) utils::tail(parts, n = 1),
      character(1)
    )
    pars_names <- unique(
      sub("^([a-z,A-Z,0-9,_]*)_.*;", "\\1", trimws(rhs))
    )
  }

  # A name that still contains ";" was not rewritten above, i.e. it had no
  # underscore postfix.
  noUnderscore <- grepl(";", pars_names, fixed = TRUE)
  if (any(noUnderscore)) {
    warning(paste("The following parameters were added to pars_ but did not",
      "have the expected underscore postfix:", paste(pars_names[noUnderscore],
        collapse = " ")))
    return(NULL)
  }

  has_log_lik <- any(grepl("log_lik[[:space:]]*;[[:space:]]*", stan_code))

  todo <- as.integer(
    seq(from = 1, to = .Machine$integer.max, length.out = M)
  )

  # Convert one variable of a fit to a data frame of draws.
  draws_of <- function(fit, variable) {
    posterior::as_draws_df(fit$draws(variable))
  }
  # Columns such as .chain/.iteration/.draw are bookkeeping, not variables.
  is_variable <- function(nms) !grepl("\\.", nms)

  # Forking is unavailable on Windows, where mclapply() rejects more than one
  # core; fall back to serial execution there.
  n_cores <- if (.Platform$OS.type == "windows") {
    1L
  } else {
    getOption("mc.cores", 2L)
  }

  post <- parallel::mclapply(todo, FUN = function(S) {
    message("\nRun ", which(todo ==S)," out of ", M," ...\n")
    out <- cmdstanrmodel$sample(data,
      chains = 1L,
      parallel_chains = 1L,
      seed = S,
      chain_ids = S,
      refresh = 0,
      thin = 1L, ...
    )

    Y <- colMeans(draws_of(out, "y_"))
    Y <- Y[is_variable(names(Y))]

    pars <- colMeans(draws_of(out, "pars_"))
    pars <- pars[is_variable(names(pars))]
    names(pars) <- pars_names

    # drop = FALSE keeps a matrix even when there is a single parameter.
    ranks <- as.matrix(draws_of(out, "ranks_"))
    ranks <- ranks[, is_variable(colnames(ranks)), drop = FALSE]
    colnames(ranks) <- pars_names

    # Only request log_lik when the Stan program declares it.
    log_lik <- NULL
    if (has_log_lik) {
      log_lik <- as.matrix(draws_of(out, "log_lik"))
      log_lik <- log_lik[, is_variable(colnames(log_lik)), drop = FALSE]
    }

    sampler_params <- as.matrix(
      posterior::as_draws_df(out$sampler_diagnostics())
    )

    list(
      Y = Y, pars = pars, ranks = ranks, log_lik = log_lik,
      sampler_params = sampler_params
    )
  }, mc.cores = n_cores)

  # mclapply() hands back failed workers as "try-error" objects (or NULL);
  # report them here instead of failing obscurely further down.
  failed <- vapply(
    post,
    function(p) is.null(p) || inherits(p, "try-error"),
    logical(1)
  )
  if (any(failed)) {
    first_failure <- post[[which(failed)[1]]]
    stop(
      "Sampling failed in ", sum(failed), " of ", M, " runs",
      if (!is.null(first_failure)) {
        paste0("; first error: ", as.character(first_failure))
      },
      call. = FALSE
    )
  }

  Y <- sapply(post, function(x) x[["Y"]])
  pars <- sapply(post, function(x) x[["pars"]])
  ranks <- lapply(post, function(x) x[["ranks"]])
  sampler_params <- simplify2array(
    lapply(post, function(x) x[["sampler_params"]])
  )
  if (has_log_lik) {
    pareto_k <- sapply(post, FUN = function(x) {
      suppressWarnings(loo::loo(x$log_lik))$diagnostics$pareto_k
    })
  }

  out <- list(
    ranks = ranks, Y = Y, pars = pars, sampler_params = sampler_params,
    pareto_k = if (has_log_lik) pareto_k
  )
  class(out) <- "sbc"
  return(out)
}
#' Histogram of SBC rank statistics, one facet per parameter
#'
#' Thins the rank indicators of every replication, sums them into a rank
#' statistic per parameter and plots the statistics as histograms with
#' dashed reference lines at the 0.5%, 50% and 99.5% binomial quantiles of
#' the per-bin count.
#'
#' @param x A list with element `ranks`: a list of matrices (draws x
#'   parameters) with parameter names as column names.
#' @param thin Keep every `thin`-th row of each rank matrix.
#' @param binwidth Width of the histogram bins. Leading thinned rows are
#'   dropped so that the number of possible ranks is a multiple of `binwidth`.
#' @param ... Further arguments passed to `ggplot2::geom_histogram()`.
#'
#' @return The ggplot object.
plot_sbc <- function(x, thin = 3, binwidth = 50, ...) {
  thinner <- seq(from = 1, to = nrow(x$ranks[[1]]), by = thin)

  # Number of possible rank values before trimming it to a multiple of the
  # bin width.
  max_rank <- length(thinner) + 1
  excess <- max_rank %% binwidth
  max_rank <- max_rank - excess

  rank_stats <- lapply(x$ranks, function(r) {
    thinned_ranks <- r[thinner, , drop = FALSE]
    if (excess != 0) {
      thinned_ranks <- thinned_ranks[-seq_len(excess), , drop = FALSE]
    }
    1L + colSums(thinned_ranks)
  })
  # Row-bind instead of t(sapply(...)): sapply() simplifies a single
  # parameter to a plain vector, which would turn into a 1 x replications
  # matrix and give the reference bands below the wrong size.
  u <- do.call(rbind, rank_stats)

  parameter <- as.factor(rep(colnames(u), each = nrow(u)))
  d <- data.frame(u = c(u), parameter)

  # Each replication lands in a given bin with probability
  # binwidth / max_rank, so the bin counts are binomial over nrow(u)
  # replications.
  bands <- qbinom(c(0.005, .5, .995), nrow(u), binwidth / max_rank)

  ggplot2::ggplot(d) +
    ggplot2::geom_histogram(ggplot2::aes(x = u),
      ...,
      binwidth = binwidth, color = "black",
      fill = "#ffffe8", boundary = 0
    ) +
    ggplot2::facet_wrap("parameter") +
    ggplot2::geom_hline(
      yintercept = bands, color = "black", linetype = "dashed"
    )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment