How to extract performance on both *analysis* and *assessment* sets (e.g., here, to compare performance against a hold-out set while tuning a hyperparameter).
library(tidyverse)
library(tidymodels)
library(AmesHousing)

ames <- make_ames() %>% 
  mutate(Years_Old = Year_Sold - Year_Built,
         Years_Old = ifelse(Years_Old < 0, 0, Years_Old))

# If this were a real example, we would first split the data into train / test sets
set.seed(123)
train_valid <- rsample::validation_split(ames)

# random forest model recipe for a few inputs
rf_recipe <- 
  recipe(
    Sale_Price ~ Lot_Area + Neighborhood  + Years_Old + Gr_Liv_Area + Overall_Qual + Total_Bsmt_SF + Garage_Area, 
    data = ames
  ) %>%
  step_log(Sale_Price, base = 10) %>%
  step_other(Neighborhood, Overall_Qual, threshold = 50) %>% 
  step_novel(Neighborhood, Overall_Qual) %>% 
  step_dummy(Neighborhood, Overall_Qual) 
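In the recipe above, `step_other(threshold = 50)` pools any `Neighborhood` / `Overall_Qual` level with fewer than 50 training rows into an `"other"` level, and `step_novel()` guards against levels unseen at training time. A base-R sketch of the pooling idea (with hypothetical data, not the Ames columns):

```r
# hypothetical factor: "A" is common (60 rows), "B" is rare (3 rows)
x <- factor(c(rep("A", 60), rep("B", 3)))
tab <- table(x)

# collapse levels below the count threshold into "other"
levels(x)[levels(x) %in% names(tab[tab < 50])] <- "other"
table(x)
#> x
#>     A other
#>    60     3
```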

rf_mod <- rand_forest(min_n = tune()) %>% 
  set_engine("ranger", importance = "impurity", seed = 63233, quantreg = TRUE) %>% 
  set_mode("regression")

grid_n <- tibble(min_n = 2^(1:8))
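For reference, the grid above evaluates eight `min_n` values spaced on a log2 scale, so each resample gets one random forest fit per candidate value:

```r
# candidate values for min_n (minimum node size)
2^(1:8)
#> [1]   2   4   8  16  32  64 128 256
```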

analysis_preds <- function(x){
  bind_cols(
    pull_workflow_mold(x)$outcomes %>% set_names(".outcome"),
    predict(pull_workflow_fit(x), pull_workflow_mold(x)$predictors)
  )
}
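`analysis_preds()` pairs the (preprocessed) outcome from the workflow's mold with the fitted model's predictions on the same analysis data. The renaming it relies on can be sketched in base R with hypothetical values (log10 sale prices here are made up for illustration):

```r
# set_names(".outcome") just relabels the single outcome column so it
# can sit next to the ".pred" column that predict() returns
outcomes <- data.frame(Sale_Price = c(5.1, 5.3))  # hypothetical log10 prices
names(outcomes) <- ".outcome"
preds <- data.frame(.pred = c(5.0, 5.4))          # hypothetical predictions

cbind(outcomes, preds)
# a two-row data frame with columns .outcome and .pred
```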

ctrl <- control_resamples(extract = analysis_preds)
reg_mets <- metric_set(rmse, mae, mape, rsq)

set.seed(63233)
rf_wf <- workflows::workflow() %>% 
  add_model(rf_mod) %>% 
  add_recipe(rf_recipe) %>% 
  tune_grid(grid = grid_n, 
            resamples = train_valid,
            control = ctrl,
            metrics = reg_mets)

# See outcomes and predictions on the analysis data:
# rf_wf$.extracts %>%
#   map(".extracts")

extract_extracts <- function(x_extracts, suffix = ""){
  x_extracts %>% 
    mutate(metrics_analysis = map(.extracts, ~reg_mets(.x, .outcome, .pred))) %>%
    select(min_n, metrics_analysis) %>% 
    unnest(metrics_analysis) %>% 
    rename_with(~paste0(., suffix))
}
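The `rename_with(~paste0(., suffix))` step appends the suffix to every column name, which is what lets the analysis-set metrics sit alongside the assessment-set `.metrics` columns later (the suffixed duplicates other than `.estimate_analysis` are dropped when building `perf`). For example:

```r
# column names after unnesting the analysis-set metrics
nms <- c("min_n", ".metric", ".estimate")
paste0(nms, "_analysis")
#> [1] "min_n_analysis"     ".metric_analysis"   ".estimate_analysis"
```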

# performance on assessment and analysis sets
perf <- rf_wf %>% 
  mutate(.extracts = map(.extracts, extract_extracts, suffix = "_analysis")) %>% 
  select(id, .metrics, .extracts) %>% 
  unnest(c(.metrics, .extracts)) %>% 
  select(everything(), -contains("_analysis"), .estimate_analysis)
 
perf %>% 
  filter(.metric == "rmse") %>% 
  pivot_longer(cols = contains("estimate")) %>% 
  mutate(dataset = forcats::fct_recode(name, 
                          assessment = ".estimate",
                          analysis = ".estimate_analysis")) %>% 
  ggplot(aes(x = min_n))+
  geom_line(aes(y = value, colour = dataset))+
  theme_bw()+
  labs(title = "Performance on Analysis / Assessment Sets",
       subtitle = "While tuning `min_n`",
       y = "RMSE")

Created on 2021-03-07 by the reprex package (v0.3.0)

@brshallo brshallo commented Mar 8, 2021

The code above shows a way to extract metrics on the analysis set(s) of a resampling object (@kbzsl, tidymodels/tune#215). A potential use case (shown in this example) is reviewing the extent of overfitting between the model's training data and a hold-out set while tuning a parameter.

@brshallo brshallo commented May 8, 2021

This example is used in the "Cautions with Overfitting" section of my post on Understanding Prediction Intervals.
