@brshallo
Last active July 26, 2021 17:33
Gist: brshallo/4053df78265ab9d77f753d95f5faaf5b
Prep an interval and then produce a prediction interval on a new data set. See thread: https://community.rstudio.com/t/prediction-intervals-with-tidymodels-best-practices/82594/15 . Also see the prior set-up: https://gist.github.com/brshallo/3db2cd25172899f91b196a90d5980690 . The approach at this gist is similar but uses the bootstrapped residuals to produ…
library(tidyverse)
library(tidymodels)
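# Extractor passed to `control_resamples(extract = )` in `prep_interval()`.
# For each resample it returns the fitted model, the prepped recipe, and the
# model's predictions on the analysis (training) set next to the outcome
# (assumes a single outcome column, renamed to `.outcome`).
ctrl_fit_recipe <- function(x){
  output <- list(fit = workflows::pull_workflow_fit(x),
                 recipe = workflows::pull_workflow_prepped_recipe(x))
  mold <- workflows::pull_workflow_mold(x)
  c(output, list(resids =
    bind_cols(
      mold$outcomes %>% set_names(".outcome"),
      predict(output$fit, mold$predictors)
    )
  ))
}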
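# Sketch (hypothetical helper in base R, not part of the original gist): the
# three kinds of residuals that `prep_interval()` builds, shown for a single
# bootstrap resample, with `lm(mpg ~ wt)` on `mtcars` standing in for the
# workflow (an assumption for illustration only).
#   analysis:   in-sample residuals (what `ctrl_fit_recipe()` collects)
#   assessment: out-of-bag residuals (what `save_pred = TRUE` collects)
#   no_info:    predictions shuffled against outcomes (the gist's `no_info`)
# Residuals follow the gist's sign convention: prediction minus outcome.
demo_boot_resids <- function(data = mtcars) {
  n <- nrow(data)
  boot_idx <- sample.int(n, n, replace = TRUE)      # bootstrap (analysis) rows
  oob_idx <- setdiff(seq_len(n), boot_idx)          # out-of-bag (assessment) rows
  fit <- lm(mpg ~ wt, data = data[boot_idx, ])
  analysis_pred <- unname(predict(fit, newdata = data[boot_idx, ]))
  analysis_y <- data$mpg[boot_idx]
  list(
    analysis = analysis_pred - analysis_y,
    assessment = unname(predict(fit, newdata = data[oob_idx, ])) - data$mpg[oob_idx],
    no_info = sample(analysis_pred) - analysis_y
  )
}
# With an intercept, OLS in-sample residuals average to zero, and shuffling the
# predictions does not change that mean; the spread of `no_info` is what grows.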
#' Extract parts of the control function output as a data frame with list-cols
#' @param wf_extracts A resample_results object with an `.extracts` column
#'   that was created using `ctrl_fit_recipe()`.
extract_extracts <- function(wf_extracts){
  wf_extracts %>%
    select(.extracts) %>%
    # each element is a one-row tibble; pull the list returned by
    # `ctrl_fit_recipe()` out of its `.extracts` column
    transmute(.extracts = map(.extracts, ".extracts") %>% map(1)) %>%
    unnest_wider(.extracts)
}
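# Sketch (hypothetical mock data, not part of the original gist): the nested
# shape that `extract_extracts()` unwraps. In a `fit_resamples()` result each
# element of `.extracts` is a one-row tibble whose `.extracts` list-column holds
# whatever `ctrl_fit_recipe()` returned. The strings "fit_1", "rec_1", ... are
# placeholders for the real model, recipe and residual objects.
library(dplyr)
library(purrr)
library(tidyr)
library(tibble)

demo_unwrap_extracts <- function() {
  # one mock `.extracts` element per resample
  mock_extract <- function(i) {
    tibble(
      .extracts = list(list(fit = paste0("fit_", i),
                            recipe = paste0("rec_", i),
                            resids = paste0("res_", i))),
      .config = "Preprocessor1_Model1"
    )
  }
  mock_resamples <- tibble(id = c("Bootstrap1", "Bootstrap2"),
                           .extracts = map(1:2, mock_extract))
  # same steps as `extract_extracts()`
  mock_resamples %>%
    select(.extracts) %>%
    transmute(.extracts = map(.extracts, ".extracts") %>% map(1)) %>%
    unnest_wider(.extracts)
}
# demo_unwrap_extracts() returns one row per resample with columns
# `fit`, `recipe` and `resids`.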
#' Prep Interval
#'
#' This function takes in a workflow and outputs a named list meant to be passed
#' into `predict_interval()`.
#'
#' @param wf Workflow containing a recipe and a model.
#' @param train Training data.
#' @param n_boot Number of bootstrap samples used to simulate building the model
#'   (used for uncertainty due to the model).
#' @return A named list of two tibbles named `model_uncertainty` and
#'   `sample_uncertainty`.
prep_interval <- function(wf, train, n_boot = sqrt(nrow(train))){
  ##### Uncertainty due to model specification #####
  ctrl <- control_resamples(extract = ctrl_fit_recipe, save_pred = TRUE)

  # Could instead let the user pass in a resampling specification...
  resamples_boot <- rsample::bootstraps(train, times = n_boot)

  wf_model_uncertainty <- wf %>%
    fit_resamples(resamples_boot, control = ctrl)

  model_uncertainty <- extract_extracts(wf_model_uncertainty)

  ##### Uncertainty due to sample #####
  # Set up a .632+ style weighting between the residuals on the analysis
  # versus the assessment sets
  analysis_preds <- model_uncertainty %>%
    select(resids) %>%
    unnest(resids)

  # "No information" residuals: predictions shuffled against the outcomes
  no_info <- select(analysis_preds, .outcome) %>%
    bind_cols(analysis_preds[sample(nrow(analysis_preds)), ".pred"]) %>%
    transmute(.resid = .pred - .outcome)

  analysis_resids <- analysis_preds %>%
    transmute(.resid = .pred - .outcome)

  # Out-of-bag residuals; the 3rd column of the unnested predictions is
  # assumed to be the outcome (columns `.pred`, `.row`, <outcome>, `.config`)
  assessment_resids <- wf_model_uncertainty %>%
    select(.predictions) %>%
    unnest(.predictions) %>%
    mutate(.resid = .pred - cur_data()[[3]]) %>%
    select(.resid)

  # Relative overfitting rate R, with mean absolute residual as the error
  err_analysis <- mean(abs(analysis_resids$.resid))
  R <- (mean(abs(assessment_resids$.resid)) - err_analysis) /
    (mean(abs(no_info$.resid)) - err_analysis)
  W <- 0.632 / (1 - 0.368 * R)

  # Blend analysis and assessment residual quantiles on a grid of
  # nrow(train) - 2 interior probabilities (endpoints 0 and 1 dropped)
  quantiles <- seq(from = 0, to = 1, length.out = nrow(train))
  quantiles <- quantiles[c(-1, -length(quantiles))]

  sample_uncertainty <- (1 - W) * quantile(analysis_resids$.resid, probs = quantiles) +
    W * quantile(assessment_resids$.resid, probs = quantiles)
  sample_uncertainty <- tibble(.resid = sample_uncertainty)

  # Outputs
  model_uncertainty <- model_uncertainty %>%
    select(-resids)

  list(
    model_uncertainty = model_uncertainty,
    sample_uncertainty = sample_uncertainty
  )
}
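# Sketch (hypothetical helpers in base R, not part of the original gist): the
# two calculations at the end of `prep_interval()`, pulled out so they can be
# checked by hand. Errors are mean absolute residuals, as in the gist.
# Unlike `prep_interval()`, `weight_632plus()` clamps R to [0, 1], close to the
# truncation used by the usual .632+ estimator, so W stays within [0.632, 1].
weight_632plus <- function(err_analysis, err_assessment, err_no_info) {
  # relative overfitting rate: 0 = no overfitting, 1 = as bad as no information
  R <- (err_assessment - err_analysis) / (err_no_info - err_analysis)
  R <- min(max(R, 0), 1)
  0.632 / (1 - 0.368 * R)
}

# Blend analysis and assessment residual quantiles on the same interior grid
# of probabilities that `prep_interval()` uses (endpoints 0 and 1 dropped).
blend_resid_quantiles <- function(resid_analysis, resid_assessment, W, n) {
  probs <- seq(from = 0, to = 1, length.out = n)
  probs <- probs[c(-1, -length(probs))]
  (1 - W) * quantile(resid_analysis, probs = probs, names = FALSE) +
    W * quantile(resid_assessment, probs = probs, names = FALSE)
}

# Worked example: err_analysis = 1, err_assessment = 2, err_no_info = 3 gives
# R = (2 - 1) / (3 - 1) = 0.5 and W = 0.632 / (1 - 0.184) = 0.632 / 0.816,
# about 0.7745, so the assessment (out-of-bag) quantiles get most of the weight.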
#' Predict Interval
#'
#' This function takes in the output from `prep_interval()` along with an
#' unprepped hold-out dataset and outputs a prediction interval.
#'
#' @param prepped_interval Object output by `prep_interval()`.
#' @param new_data Data to generate predictions on.
#' @param probs Quantiles of predictions to output.
#' @param cross If TRUE, the distribution that quantiles are taken from is made
#'   up of all possible combinations of {model fitting uncertainty} and
#'   {sample uncertainty} samples. If FALSE, draw `n_sims` simulations for
#'   each observation.
#' @param n_sims Number of simulations for each observation (ignored if
#'   `cross = TRUE`).
#' @return A tibble containing columns `probs_*` at each quantile specified by `probs`.
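predict_interval <- function(prepped_interval, new_data, probs = c(0.025, 0.50, 0.975), cross = TRUE, n_sims = 10000){
  # Predict `new_data` with each bootstrap fit (after baking with its recipe);
  # `m` is each prediction's deviation from the mean prediction for that row
  model_uncertainty <- prepped_interval$model_uncertainty %>%
    mutate(assessment = map2(fit, recipe,
                             ~predict(.x, bake(.y, new_data = new_data)) %>%
                               mutate(.id = row_number()))
    ) %>%
    select(assessment) %>%
    unnest(assessment) %>%
    group_by(.id) %>%
    mutate(m = .pred - mean(.pred)) %>%
    ungroup()

  # Pair model draws with residual draws: every combination if `cross`,
  # otherwise `n_sims * nrow(new_data)` rows sampled with replacement
  # (plain if/else; `purrr::when()` is deprecated as of purrr 1.0.0)
  combined <- if (cross) {
    crossing(model_uncertainty, prepped_interval$sample_uncertainty)
  } else {
    bind_cols(
      slice_sample(model_uncertainty, n = n_sims * nrow(new_data), replace = TRUE),
      slice_sample(prepped_interval$sample_uncertainty, n = n_sims * nrow(new_data), replace = TRUE)
    )
  }

  combined %>%
    # simulated value = bootstrap prediction + its deviation from the mean
    # prediction + a residual draw
    mutate(c = m + .resid + .pred) %>%
    group_by(.id) %>%
    # multi-row summarise (dplyr >= 1.1.0 prefers `reframe()` for this)
    summarise(qs = quantile(c, probs),
              probs = format(probs, nsmall = 2)) %>%
    ungroup() %>%
    pivot_wider(names_from = probs,
                values_from = qs,
                names_prefix = "probs_") %>%
    select(-.id)
}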
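# Sketch (hypothetical helper in base R, not part of the original gist): the
# two ways `predict_interval()` combines the uncertainty sources, shown for a
# single observation. `model_term` stands in for the per-bootstrap values of
# `m + .pred`, and `resid` for the `.resid` values in `sample_uncertainty`.
combine_uncertainty <- function(model_term, resid, probs = c(0.025, 0.50, 0.975),
                                cross = TRUE, n_sims = 10000) {
  sims <- if (cross) {
    # every pairing of a model value with a residual, like `crossing()`
    # (which additionally de-duplicates and sorts its inputs)
    as.vector(outer(model_term, resid, "+"))
  } else {
    # `n_sims` independent draws with replacement from each source
    model_term[sample.int(length(model_term), n_sims, replace = TRUE)] +
      resid[sample.int(length(resid), n_sims, replace = TRUE)]
  }
  quantile(sims, probs = probs, names = FALSE)
}
# In `predict_interval()` the non-cross branch samples rows across all
# observations at once, so each observation gets about `n_sims` draws rather
# than exactly `n_sims`.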
@brshallo
Author

A prior version of this gist had a bug in the `purrr::when()` step such that even if `cross = TRUE`, the simulation-based approach was used.