Last active April 3, 2024 07:14
Prep interval and then produce prediction interval on a new data set. Not confident these are set-up correctly... see thread:
library(tidyverse)
library(tidymodels)

# Control function used as part of `prep_interval()`
ctrl_fit_recipe <- function(x){
  list(fit = workflows::pull_workflow_fit(x),
       recipe = workflows::pull_workflow_prepped_recipe(x))
}

#' Extract parts of control function and output as dataframe with list-cols
#'
#' @param wf_extracts A resample_results object with a column for .extracts
#'   that was created using `ctrl_fit_recipe()`
extract_extracts <- function(wf_extracts){
  wf_extracts %>%
    select(.extracts) %>%
    transmute(.extracts = map(.extracts, ".extracts") %>% map(1)) %>%
    unnest_wider(col = .extracts, strict = TRUE)
}

#' Prep Interval
#'
#' This function takes in a workflow and outputs a named list meant to be passed
#' into `predict_interval()`.
#'
#' @param wf Workflow containing a recipe and a model.
#' @param train Training data.
#' @param n_boot Number of bootstrap samples used to simulate building the model
#'   (used for uncertainty due to model).
#' @param n_cv Number of folds for k-fold cross-validation (used for uncertainty
#'   due to sample). If NULL, the out-of-bag predictions from the bootstrap
#'   samples used to build the models are used instead.
#' @return A named list of two tibbles named `model_uncertainty` and
#'   `sample_uncertainty`.
prep_interval <- function(wf, train, n_boot = sqrt(nrow(train)), n_cv = 10){
  ##### Uncertainty due to model specification #####
  ctrl <- control_resamples(extract = ctrl_fit_recipe, save_pred = TRUE)
  resamples_boot <- rsample::bootstraps(train, n_boot)
  wf_model_uncertainty <- wf %>%
    fit_resamples(resamples_boot, control = ctrl)
  model_uncertainty <- extract_extracts(wf_model_uncertainty)

  ##### Uncertainty due to sample ######
  if (is.null(n_cv)){
    # reuse the out-of-bag predictions from the bootstrap fits
    wf_sample_uncertainty <- wf_model_uncertainty
  } else{
    resamples_cv <- rsample::vfold_cv(train, v = n_cv)
    ctrl_cv <- control_resamples(save_pred = TRUE)
    wf_sample_uncertainty <- wf %>%
      fit_resamples(resamples_cv, control = ctrl_cv)
  }

  # residual = prediction minus observed outcome; the outcome is assumed to be
  # the third column of the unnested `.predictions`
  sample_uncertainty <- wf_sample_uncertainty %>%
    select(.predictions) %>%
    unnest(.predictions) %>%
    mutate(.resid = .pred - cur_data()[[3]]) %>%
    select(.resid)

  list(
    model_uncertainty = model_uncertainty,
    sample_uncertainty = sample_uncertainty
  )
}

#' Predict Interval
#'
#' This function takes in the output from `prep_interval()` along with an
#' unprepped hold-out dataset and outputs a prediction interval.
#'
#' @param prepped_interval Object outputted by `prep_interval()`.
#' @param new_data Data to generate predictions on.
#' @param probs Quantiles of predictions to output.
#' @param cross If TRUE, the distribution that quantiles are taken from is made
#'   up of all possible combinations of samples of {model fitting uncertainty}
#'   and {sample uncertainty}. If FALSE, create `n_sims` simulations for each
#'   sample.
#' @param n_sims Number of simulations for each observation (ignored if
#'   `cross` is TRUE).
#' @return A tibble containing columns `probs_*` at each quantile specified by `probs`.
predict_interval <- function(prepped_interval, new_data, probs = c(0.025, 0.50, 0.975), cross = TRUE, n_sims = 10000){
  # predict `new_data` with every bootstrap fit; `m` is each prediction's
  # deviation from the mean prediction for that observation
  model_uncertainty <- prepped_interval$model_uncertainty %>%
    mutate(assessment = map2(fit, recipe,
                             ~predict(.x, bake(.y, new_data = new_data)) %>%
                               mutate(.id = row_number()))
    ) %>%
    select(assessment) %>%
    unnest(assessment) %>%
    group_by(.id) %>%
    mutate(m = .pred - mean(.pred)) %>%
    ungroup()

  # combine model and sample uncertainty: all combinations if `cross` is TRUE,
  # otherwise randomly paired draws
  cross %>%
    when(
      . ~ crossing(model_uncertainty, prepped_interval$sample_uncertainty),
      ~ bind_cols(
        slice_sample(model_uncertainty, n = n_sims * nrow(new_data), replace = TRUE),
        slice_sample(prepped_interval$sample_uncertainty, n = n_sims * nrow(new_data), replace = TRUE)
      )
    ) %>%
    mutate(c = m + .resid + .pred) %>%
    group_by(.id) %>%
    summarise(qs = quantile(c, probs),
              probs = format(probs, nsmall = 2)) %>%
    ungroup() %>%
    pivot_wider(names_from = probs,
                values_from = qs,
                names_prefix = "probs_") %>%
    # assumed final step (the source is cut off here): keep only the `probs_*` columns
    select(-.id)
}

A prior version of this gist had a bug in the `purrr::when()` step such that the simulation-based approach was used even when `cross = TRUE`.
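
To make the difference between the two modes concrete, here is a minimal base-R sketch with made-up numbers. The vectors `boot_pred` and `resid` are hypothetical stand-ins for the bootstrap predictions and hold-out residuals of a single new observation, not output from the gist. It builds the same `m + .resid + .pred` quantity, once over all combinations and once over random pairs.

```r
# Hypothetical inputs for a single new observation (made-up numbers).
boot_pred <- c(9.5, 10, 10.5)    # predictions from three bootstrap fits
resid     <- c(-2, -1, 0, 1, 2)  # hold-out residuals
probs     <- c(0.025, 0.5, 0.975)

# Deviation of each bootstrap prediction from the mean prediction.
m <- boot_pred - mean(boot_pred)

# cross = TRUE: pair every bootstrap prediction with every residual.
crossed <- as.vector(outer(m + boot_pred, resid, "+"))
q_cross <- quantile(crossed, probs)

# cross = FALSE: draw n_sims random (prediction, residual) pairs with replacement.
set.seed(123)
n_sims <- 10000
i <- sample(seq_along(boot_pred), n_sims, replace = TRUE)
j <- sample(seq_along(resid), n_sims, replace = TRUE)
simulated <- m[i] + resid[j] + boot_pred[i]
q_sim <- quantile(simulated, probs)

print(rbind(cross = q_cross, sim = q_sim))
```

With inputs this small the crossed version is exact and cheap (3 x 5 = 15 combinations). The simulated version only approximates it, but its cost is set by `n_sims` rather than by the product of the two sample sizes.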

brshallo commented Apr 3, 2024

Updated this today so it should now work (had been a small change to unnest_wider() that caused it to error).
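
For reference, a small sketch of the reshaping that `extract_extracts()` performs. The `toy` tibble is hypothetical: it assumes each element of the `.extracts` column is a one-row tibble holding the list returned by `ctrl_fit_recipe()`, and it uses placeholder strings where the real fits and recipes would be. The `strict` argument assumes tidyr 1.3.0 or later.

```r
library(dplyr)
library(purrr)
library(tibble)
library(tidyr)

# Toy stand-in for resample results (assumed layout, placeholder values).
toy <- tibble(
  id = c("Bootstrap1", "Bootstrap2"),
  .extracts = list(
    tibble(.extracts = list(list(fit = "fit_1", recipe = "recipe_1")), .config = "cfg"),
    tibble(.extracts = list(list(fit = "fit_2", recipe = "recipe_2")), .config = "cfg")
  )
)

# Same steps as extract_extracts(): pluck the inner list, then widen it
# into one column per named element.
out <- toy %>%
  select(.extracts) %>%
  transmute(.extracts = map(.extracts, ".extracts") %>% map(1)) %>%
  unnest_wider(col = .extracts, strict = TRUE)

print(out)
```

The result has one row per resample and one column per name in the extracted list (`fit` and `recipe` here), which is the shape `predict_interval()` iterates over with `map2()`.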
