How to extract performance metrics on both the *analysis* and *assessment* sets, e.g. (as here) to compare performance on the training data with performance on a hold-out set while tuning a hyperparameter.
library(tidyverse)
library(tidymodels)
library(AmesHousing)

ames <- make_ames() %>% 
  mutate(Years_Old = Year_Sold - Year_Built,
         Years_Old = ifelse(Years_Old < 0, 0, Years_Old))
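The `ifelse()` in this `mutate()` floors `Years_Old` at zero for houses whose recorded build year comes after the sale year. A minimal sketch on made-up rows (the `toy` tibble is hypothetical) showing that `pmax()` gives the same result in a single call:

```r
library(dplyr)

# Hypothetical toy rows: the last house is recorded as built after it was sold
toy <- tibble(
  Year_Sold  = c(2010, 2008, 2007),
  Year_Built = c(1990, 2008, 2008)
)

toy <- toy %>%
  mutate(
    via_ifelse = ifelse(Year_Sold - Year_Built < 0, 0, Year_Sold - Year_Built),
    via_pmax   = pmax(Year_Sold - Year_Built, 0)  # element-wise max with 0
  )
toy
```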

# In a real example, you would first split the data into training and test sets
set.seed(123)
train_valid <- rsample::validation_split(ames)
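`validation_split()` holds out one random subset of rows: the model is fit on the *analysis* rows, and `tune_grid()` computes its metrics on the *assessment* rows, which is why the gist needs an extractor for the analysis side. A small sketch of the two sides on `mtcars`, assuming a one-split `mc_cv()` with the same default `prop = 3/4` as a stand-in for `validation_split()` (recent rsample releases deprecate `validation_split()` in favor of `initial_validation_split()`):

```r
library(rsample)

set.seed(123)
# Stand-in for validation_split(): one random split, 3/4 of rows for analysis
one_split <- mc_cv(mtcars, prop = 3/4, times = 1)
split_obj <- one_split$splits[[1]]

fit_rows <- analysis(split_obj)    # rows the model is fit on
held_out <- assessment(split_obj)  # rows tune_grid() scores

c(analysis = nrow(fit_rows), assessment = nrow(held_out))
```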

# Random forest recipe using a few predictors (Sale_Price is modeled on the log10 scale)
rf_recipe <- 
  recipe(
    Sale_Price ~ Lot_Area + Neighborhood + Years_Old + Gr_Liv_Area + Overall_Qual + Total_Bsmt_SF + Garage_Area,
    data = ames
  ) %>%
  step_log(Sale_Price, base = 10) %>%
  step_other(Neighborhood, Overall_Qual, threshold = 50) %>% 
  step_novel(Neighborhood, Overall_Qual) %>% 
  step_dummy(Neighborhood, Overall_Qual) 
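Roughly what the factor steps do, shown on a hypothetical toy factor: `step_other()` with a `threshold` of 1 or more treats it as a count and pools levels seen fewer times than that into `"other"`, `step_novel()` reserves a `"new"` level for factor levels that first appear at prediction time, and `step_dummy()` then turns the factor into indicator columns. A sketch with `threshold = 2`, so the two singleton levels get pooled:

```r
library(recipes)
library(tibble)

# Hypothetical toy data: levels "c" and "d" each appear once (< threshold of 2)
toy <- tibble(
  y = c(10, 100, 1000, 10, 100, 1000, 10, 100, 1000),
  x = factor(c("a", "a", "a", "a", "b", "b", "b", "c", "d"))
)

rec <- recipe(y ~ x, data = toy) %>%
  step_log(y, base = 10) %>%       # outcome on the log10 scale
  step_other(x, threshold = 2) %>% # threshold >= 1 is read as a count
  step_novel(x) %>%                # adds a "new" level for unseen levels
  step_dummy(x)                    # indicator columns, "a" is the reference

baked <- bake(prep(rec, training = toy), new_data = NULL)
names(baked)
```

Both `c` and `d` end up in `x_other`; no `x_c` or `x_d` column is created.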

rf_mod <- rand_forest(min_n = tune()) %>% 
  set_engine("ranger", importance = "impurity", seed = 63233, quantreg = TRUE) %>% 
  set_mode("regression")

grid_n <- tibble(min_n = 2^(1:8))

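# Extractor run on each fitted workflow: returns the (preprocessed) outcome and
# the model's predictions on the analysis set the model was fit on.
# (In workflows >= 0.2.3, extract_mold() and extract_fit_parsnip() supersede
# pull_workflow_mold() and pull_workflow_fit().)
analysis_preds <- function(x){
  mold <- pull_workflow_mold(x)  # preprocessed analysis-set data
  bind_cols(
    mold$outcomes %>% set_names(".outcome"),
    predict(pull_workflow_fit(x), mold$predictors)
  )
}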
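`analysis_preds()` does not re-read the analysis data; it reuses the preprocessed data (the "mold") stored inside the fitted workflow and predicts on it with the underlying parsnip fit. A minimal sketch of the same idea outside of tuning, assuming workflows 0.2.3 or later for `extract_mold()` and `extract_fit_parsnip()`; the linear model on `mtcars` is a hypothetical stand-in. Because the recipe here only selects columns, these in-sample predictions should match predicting with the whole workflow on the training rows:

```r
library(tidymodels)

set.seed(1)
split <- initial_split(mtcars, prop = 3/4)

# Hypothetical small workflow: linear regression of mpg on two predictors
wf_fit <- workflow() %>%
  add_recipe(recipe(mpg ~ wt + hp, data = mtcars)) %>%
  add_model(linear_reg() %>% set_engine("lm")) %>%
  fit(data = training(split))

# Same logic as analysis_preds(), using the newer extract_*() accessors
analysis_preds <- function(x) {
  mold <- extract_mold(x)  # preprocessed training data kept in the workflow
  bind_cols(
    mold$outcomes %>% set_names(".outcome"),
    predict(extract_fit_parsnip(x), mold$predictors)
  )
}

in_sample <- analysis_preds(wf_fit)

# Predicting with the whole workflow on the raw training rows, for comparison
via_workflow <- predict(wf_fit, new_data = training(split))
```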

# Run analysis_preds() on each fitted workflow; results land in the .extracts column
ctrl <- control_resamples(extract = analysis_preds)
reg_mets <- metric_set(rmse, mae, mape, rsq)
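A metric set is a single function that computes every listed metric on a data frame of truth and estimate columns, which is how `reg_mets` can be reused on the extracted analysis-set predictions. A worked example on three made-up points:

```r
library(yardstick)
library(tibble)

reg_mets <- metric_set(rmse, mae, mape, rsq)

# Made-up toy values: only the last prediction is off (by 2)
toy <- tibble(truth = c(1, 2, 3), estimate = c(1, 2, 5))

res <- reg_mets(toy, truth = truth, estimate = estimate)
res
```

Only the third prediction is wrong, so RMSE = sqrt(4/3) ≈ 1.155, MAE = 2/3, MAPE = 100 * (2/3) / 3 ≈ 22.2, and `rsq()` (the squared correlation) = 12/13 ≈ 0.923.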

set.seed(63233)
rf_wf <- workflows::workflow() %>% 
  add_model(rf_mod) %>% 
  add_recipe(rf_recipe) %>% 
  tune_grid(grid = grid_n, 
            resamples = train_valid,
            control = ctrl,
            metrics = reg_mets)

# To inspect the outcomes and predictions on the analysis data:
# rf_wf$.extracts %>%
#   map(".extracts")

# Compute reg_mets on the analysis-set predictions for each min_n and
# append `suffix` to every resulting column name
extract_extracts <- function(x_extracts, suffix = ""){
  x_extracts %>% 
    mutate(metrics_analysis = map(.extracts, ~reg_mets(.x, .outcome, .pred))) %>%
    select(min_n, metrics_analysis) %>% 
    unnest(metrics_analysis) %>% 
    rename_with(~paste0(., suffix))
}
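To see the shape `extract_extracts()` returns without running the tuning step, here it is applied to a hypothetical hand-built stand-in for one element of `rf_wf$.extracts` (only `rmse` and `mae` are used to keep the sketch short):

```r
library(tidymodels)

reg_mets <- metric_set(rmse, mae)

extract_extracts <- function(x_extracts, suffix = "") {
  x_extracts %>%
    mutate(metrics_analysis = map(.extracts, ~ reg_mets(.x, .outcome, .pred))) %>%
    select(min_n, metrics_analysis) %>%
    unnest(metrics_analysis) %>%
    rename_with(~ paste0(., suffix))
}

# Hypothetical stand-in for one element of rf_wf$.extracts
fake_extracts <- tibble(
  min_n = c(2, 4),
  .extracts = list(
    tibble(.outcome = c(2, 4), .pred = c(2, 4)),  # perfect fit
    tibble(.outcome = c(2, 4), .pred = c(3, 5))   # every prediction off by 1
  )
)

out <- extract_extracts(fake_extracts, suffix = "_analysis")
out
```

Every column, including `min_n`, picks up the suffix, which is why the later `select()` drops the `_analysis` columns except `.estimate_analysis`.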

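# Performance on assessment (.metrics) and analysis (.extracts) sets;
# unnest() pairs the rows of the two columns by position
perf <- rf_wf %>% 
  mutate(.extracts = map(.extracts, extract_extracts, suffix = "_analysis")) %>% 
  select(id, .metrics, .extracts) %>% 
  unnest(cols = c(.metrics, .extracts)) %>% 
  # keep the assessment columns plus the analysis-set estimate
  select(everything(), -contains("_analysis"), .estimate_analysis)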
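The `unnest()` of `.metrics` and `.extracts` above pairs assessment and analysis rows purely by position, and the final `select()` drops the `min_n_analysis` and `.metric_analysis` columns that could confirm the pairing. A minimal sketch of a join-based alternative, assuming current tune/workflows versions (`control_grid()`, `extract_mold()`, `extract_fit_parsnip()`); the rpart decision tree on `mtcars`, the two Monte Carlo splits, and the names `tree_wf`, `analysis_mets`, and `perf_joined` are hypothetical stand-ins for the random forest on `ames`:

```r
library(tidymodels)

set.seed(2021)
# Hypothetical stand-ins: two Monte Carlo splits of mtcars and an rpart tree
splits <- mc_cv(mtcars, prop = 3/4, times = 2)

tree_wf <- workflow() %>%
  add_recipe(recipe(mpg ~ ., data = mtcars)) %>%
  add_model(
    decision_tree(min_n = tune()) %>%
      set_engine("rpart") %>%
      set_mode("regression")
  )

# Same idea as analysis_preds(), written with the newer extract_*() accessors
analysis_preds <- function(x) {
  mold <- extract_mold(x)
  bind_cols(
    mold$outcomes %>% set_names(".outcome"),
    predict(extract_fit_parsnip(x), mold$predictors)
  )
}

mets <- metric_set(rmse, mae)

res <- tune_grid(
  tree_wf,
  resamples = splits,
  grid = tibble(min_n = c(2, 5, 10)),
  metrics = mets,
  control = control_grid(extract = analysis_preds)
)

# Analysis-set metrics, keyed by resample id, min_n and metric
analysis_mets <- res %>%
  select(id, .extracts) %>%
  unnest(cols = .extracts) %>%
  mutate(m = map(.extracts, ~ mets(.x, truth = .outcome, estimate = .pred))) %>%
  select(id, min_n, m) %>%
  unnest(cols = m) %>%
  select(id, min_n, .metric, .estimate_analysis = .estimate)

# Assessment-set metrics from tune, matched on keys rather than row position
perf_joined <- res %>%
  select(id, .metrics) %>%
  unnest(cols = .metrics) %>%
  select(id, min_n, .metric, .estimate) %>%
  left_join(analysis_mets, by = c("id", "min_n", ".metric"))
```

Joining on `id`, `min_n`, and `.metric` makes a misalignment show up as `NA` instead of silently pairing the wrong rows.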
 
perf %>% 
  filter(.metric == "rmse") %>% 
  pivot_longer(cols = contains("estimate")) %>% 
  mutate(dataset = forcats::fct_recode(name, 
                          assessment = ".estimate",
                          analysis = ".estimate_analysis")) %>% 
  ggplot(aes(x = min_n))+
  geom_line(aes(y = value, colour = dataset))+
  theme_bw()+
  labs(title = "Performance on Analysis / Assessment Sets",
       subtitle = "While tuning `min_n`",
       y = "RMSE")
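The plot relies on `pivot_longer()` turning the two estimate columns into one `value` column, with `name` recording which set each value came from before `fct_recode()` relabels it. A small sketch on made-up numbers (the `perf_toy` tibble is hypothetical, not output from the gist) showing the long format the `ggplot()` call receives:

```r
library(dplyr)
library(tidyr)

# Hypothetical, made-up values in the shape of `perf`
perf_toy <- tibble(
  min_n = c(2, 4),
  .metric = "rmse",
  .estimate = c(0.5, 0.4),
  .estimate_analysis = c(0.1, 0.2)
)

long <- perf_toy %>%
  pivot_longer(cols = contains("estimate")) %>%  # one row per min_n x set
  mutate(dataset = forcats::fct_recode(name,
                                       assessment = ".estimate",
                                       analysis = ".estimate_analysis"))
long
```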

Created on 2021-03-07 by the reprex package (v0.3.0)

@brshallo commented on Mar 8, 2021:

The code above shows one way to extract metrics on the analysis set(s) of a resampling object (@kbzsl, tidymodels/tune#215). A potential use case, used in this example, is to review how much better the model fits its training data than a hold-out set while tuning a parameter.
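One way to put a number on that gap, sketched under the assumption that `perf` has the columns produced above (`min_n`, `.metric`, `.estimate`, `.estimate_analysis`); the `gap_toy` rows are made up and are not results from this gist:

```r
library(dplyr)

# Hypothetical, made-up rows in the shape of `perf`
gap_toy <- tibble(
  min_n = c(128, 2, 16),
  .metric = "rmse",
  .estimate = c(0.28, 0.30, 0.25),          # assessment RMSE
  .estimate_analysis = c(0.26, 0.05, 0.15)  # analysis RMSE
)

gaps <- gap_toy %>%
  filter(.metric == "rmse") %>%
  mutate(gap = .estimate - .estimate_analysis) %>%  # positive: lower error in-sample
  arrange(min_n)
gaps
```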
