@brshallo
Last active May 8, 2021 20:24
How to extract performance on both the *analysis* and *assessment* sets (e.g., here, to compare in-sample fit against a hold-out set while tuning a hyperparameter).
library(tidyverse)
library(tidymodels)
library(AmesHousing)

# derive age of the house at sale (floored at zero)
ames <- make_ames() %>% 
  mutate(Years_Old = Year_Sold - Year_Built,
         Years_Old = ifelse(Years_Old < 0, 0, Years_Old))

# In a real example, you would first split the data into training / test sets
set.seed(123)
train_valid <- rsample::validation_split(ames)

# random forest model recipe for a few inputs
rf_recipe <- 
  recipe(
    Sale_Price ~ Lot_Area + Neighborhood  + Years_Old + Gr_Liv_Area + Overall_Qual + Total_Bsmt_SF + Garage_Area, 
    data = ames
  ) %>%
  step_log(Sale_Price, base = 10) %>%
  step_other(Neighborhood, Overall_Qual, threshold = 50) %>% 
  step_novel(Neighborhood, Overall_Qual) %>% 
  step_dummy(Neighborhood, Overall_Qual) 

rf_mod <- rand_forest(min_n = tune()) %>% 
  set_engine("ranger", importance = "impurity", seed = 63233, quantreg = TRUE) %>% 
  set_mode("regression")

# candidate values for min_n: 2, 4, 8, ..., 256
grid_n <- tibble(min_n = 2^(1:8))

# For a fitted workflow, return the preprocessed outcome alongside the
# model's predictions on the data it was trained on (the analysis set)
analysis_preds <- function(x){
  mold <- pull_workflow_mold(x)
  
  bind_cols(
    mold$outcomes %>% set_names(".outcome"),
    predict(pull_workflow_fit(x), mold$predictors)
  )
}
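
# Optional sanity check (a sketch -- `min_n = 10` is an arbitrary value,
# not part of the tuning grid): fit the workflow once and inspect the
# in-sample outcomes / predictions analysis_preds() captures
wf_once <- workflow() %>% 
  add_model(rand_forest(min_n = 10) %>% 
              set_engine("ranger") %>% 
              set_mode("regression")) %>% 
  add_recipe(rf_recipe) %>% 
  fit(data = ames)

analysis_preds(wf_once)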

# `extract` is applied to the workflow fitted on each resample's analysis set
ctrl <- control_resamples(extract = analysis_preds)
reg_mets <- metric_set(rmse, mae, mape, rsq)

set.seed(63233)
rf_wf <- workflows::workflow() %>% 
  add_model(rf_mod) %>% 
  add_recipe(rf_recipe) %>% 
  tune_grid(grid = grid_n, 
            resamples = train_valid,
            control = ctrl,
            metrics = reg_mets)

# see outcomes and predictions on analysis data:
# rf_wf$.extracts %>%
#   map(".extracts")

# Compute the same metric set on the extracted analysis-set predictions
# and tag the resulting columns with a suffix
extract_extracts <- function(x_extracts, suffix = ""){
  x_extracts %>% 
    mutate(metrics_analysis = map(.extracts, ~reg_mets(.x, .outcome, .pred))) %>%
    select(min_n, metrics_analysis) %>% 
    unnest(metrics_analysis) %>% 
    rename_with(~paste0(., suffix))
}

# performance on assessment and analysis sets
perf <- rf_wf %>% 
  mutate(.extracts = map(.extracts, extract_extracts, suffix = "_analysis")) %>% 
  select(id, .metrics, .extracts) %>% 
  unnest(c(.metrics, .extracts)) %>% 
  select(everything(), -contains("_analysis"), .estimate_analysis)
 
perf %>% 
  filter(.metric == "rmse") %>% 
  pivot_longer(cols = contains("estimate")) %>% 
  mutate(dataset = forcats::fct_recode(name, 
                          assessment = ".estimate",
                          analysis = ".estimate_analysis")) %>% 
  ggplot(aes(x = min_n)) +
  geom_line(aes(y = value, colour = dataset)) +
  theme_bw() +
  labs(title = "Performance on Analysis / Assessment Sets",
       subtitle = "While tuning `min_n`",
       y = "RMSE")

Created on 2021-03-07 by the reprex package (v0.3.0)

brshallo commented Mar 8, 2021

The code above shows a way to extract metrics on the analysis set(s) of a resampling object (@kbzsl tidymodels/tune#215). A potential use case (the one in this example) is reviewing how much better the model fits its training data than a hold-out set while tuning a hyperparameter.
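
For reference, a minimal sketch of the hook this relies on (the name `ctrl_oob` is just illustrative): the function passed to `extract` in `control_resamples()` is called with the workflow fitted on each resample's analysis set, so any summary computable from that fit can be stored, not just predictions.

# store ranger's out-of-bag prediction error (MSE) from each analysis-set fit
ctrl_oob <- control_resamples(
  extract = function(x) pull_workflow_fit(x)$fit$prediction.error
)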

brshallo commented May 8, 2021

This example is used in the Cautions with Overfitting section of my post on Understanding Prediction Intervals.
