Skip to content

Instantly share code, notes, and snippets.

@bbolker
Forked from derekpowell/cbrm.R
Created May 22, 2023 20:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bbolker/eb7e0e1a5c6e21b3c3f24d536e804b8e to your computer and use it in GitHub Desktop.
Save bbolker/eb7e0e1a5c6e21b3c3f24d536e804b8e to your computer and use it in GitHub Desktop.
Wrapper for brms::brm() that caches fitted brms models on disk and reuses them when the model specification is unchanged
# Returns TRUE when the sampler settings recorded in a stanfit's `sim` slot
# equal the requested ones. all.equal() is used so integer vs double storage
# (e.g. 2000L vs 2000) does not cause a spurious mismatch; a missing entry in
# `sim` counts as a mismatch.
.cbrm_sampler_matches <- function(sim, iter, warmup, chains, thin) {
  requested <- list(iter = iter, warmup = warmup, chains = chains, thin = thin)
  isTRUE(all.equal(sim[names(requested)], requested))
}

# Returns TRUE when `cached` is a brmsfit whose Stan code and Stan data are
# identical to the requested ones and whose `control` values and sampler
# settings are equal to the requested ones. Checks short-circuit, so a
# non-brmsfit object is rejected without calling any brms function.
.cbrm_cache_matches <- function(cached, code, data, control,
                                iter, warmup, chains, thin) {
  inherits(cached, "brmsfit") &&
    identical(code, brms::stancode(cached)) &&
    identical(data, brms::standata(cached)) &&
    isTRUE(all.equal(
      brms::control_params(cached, pars = names(control)),
      control
    )) &&
    .cbrm_sampler_matches(cached$fit@sim, iter, warmup, chains, thin)
}

#' Fit a brms model, reusing a fit cached on disk when it still applies
#'
#' Wrapper around `brms::brm()`. When `file` names an existing file, the
#' object stored there (via `readRDS()`) is returned instead of refitting,
#' provided it is a `brmsfit` whose Stan code and Stan data are identical to
#' those generated for this call, and whose `control` values and sampler
#' settings (`iter`, `warmup`, `chains`, `thin`) equal the requested ones.
#' Otherwise the model is fitted with `brms::brm()`; when `file` is non-NULL
#' the new fit is written to `file` with `saveRDS()`.
#'
#' Arguments other than `file` are forwarded to `brms::brm()` (see its
#' documentation). Settings such as `seed`, `inits`, `algorithm` and `cores`
#' are not part of the cache check.
#'
#' @param control Sampler control list; `NULL` means
#'   `list(adapt_delta = .80)`.
#' @param file Path of the cache file, or `NULL` (default) to always fit and
#'   never write a cache.
#' @return A `brmsfit` object (cached or freshly fitted).
cbrm <- function(formula,
                 data,
                 family = gaussian(),
                 prior = NULL,
                 autocor = NULL,
                 cov_ranef = NULL,
                 sample_prior = c("no", "yes", "only"),
                 sparse = FALSE,
                 knots = NULL,
                 stan_funs = NULL,
                 fit = NA,
                 save_ranef = TRUE,
                 save_mevars = FALSE,
                 save_all_pars = FALSE,
                 inits = "random",
                 chains = 4,
                 iter = 2000,
                 warmup = floor(iter/2),
                 thin = 1,
                 cores = getOption("mc.cores", 1L),
                 control = NULL,
                 algorithm = c("sampling", "meanfield", "fullrank"),
                 future = getOption("future", FALSE),
                 silent = TRUE,
                 seed = NA,
                 save_model = NULL,
                 save_dso = TRUE,
                 file = NULL) {
  if (is.null(control)) {
    control <- list(adapt_delta = .80)
  }

  # Only build the Stan code/data for comparison when there is a cache to
  # compare against.
  if (!is.null(file) && file.exists(file)) {
    cached <- readRDS(file)
    # Pass every argument that shapes the generated Stan program/data so the
    # comparison with the cached fit is like-for-like.
    new_code <- brms::make_stancode(
      formula = formula,
      data = data,
      family = family,
      prior = prior,
      autocor = autocor,
      cov_ranef = cov_ranef,
      sparse = sparse,
      sample_prior = sample_prior,
      stan_funs = stan_funs,
      knots = knots
    )
    new_data <- brms::make_standata(
      formula = formula,
      data = data,
      family = family,
      prior = prior,
      autocor = autocor,
      cov_ranef = cov_ranef,
      sample_prior = sample_prior,
      knots = knots
    )
    if (.cbrm_cache_matches(cached, new_code, new_data, control,
                            iter, warmup, chains, thin)) {
      return(cached)
    }
  }

  output <- brms::brm(
    formula = formula,
    data = data,
    family = family,
    prior = prior,
    autocor = autocor,
    cov_ranef = cov_ranef,
    sample_prior = sample_prior,
    sparse = sparse,
    knots = knots,
    stan_funs = stan_funs,
    fit = fit,
    save_ranef = save_ranef,
    save_mevars = save_mevars,
    save_all_pars = save_all_pars,
    inits = inits,
    chains = chains,
    iter = iter,
    warmup = warmup,
    thin = thin,
    cores = cores,
    control = control,
    algorithm = algorithm,
    future = future,
    silent = silent,
    seed = seed,
    save_model = save_model,
    save_dso = save_dso
  )
  # Write the cache only when a cache path was supplied.
  if (!is.null(file)) {
    saveRDS(output, file = file)
  }
  output
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment