library(tidyverse)
library(tidymodels)
library(AmesHousing)
# Ames housing data with an added `Years_Old` column: the age of each house
# at sale time (Year_Sold - Year_Built). A sale recorded before the build
# year would give a negative age, so ages are floored at 0.
# `pmax(x, 0)` is the vectorised clamp. Unlike `ifelse()`, its result type
# does not depend on whether any negative value happens to be present
# (NOTE(review): assumes the year columns are integers, where `ifelse()`
# would only switch to double when a replacement occurs — confirm). NA
# ages stay NA.
ames <- make_ames() %>%
  mutate(Years_Old = pmax(Year_Sold - Year_Built, 0))
# In a real analysis, the data would first be split into training and test
# sets; this example builds the validation split from all of `ames`.
# A single random validation split of `ames` (rsample's default 3/4
# proportion, spelled out). The seed makes the split reproducible.
set.seed(123)
train_valid <- rsample::validation_split(data = ames, prop = 3 / 4)
# Preprocessing recipe for a random forest on a handful of predictors:
#   * Sale_Price is modelled on the log10 scale;
#   * infrequent levels of Neighborhood and Overall_Qual are pooled by
#     step_other() (threshold = 50; NOTE(review): recent recipes versions
#     read a threshold >= 1 as a minimum count — confirm for the version
#     in use);
#   * step_novel() guards both factors against levels not seen at prep time;
#   * both factors are finally expanded into dummy variables.
rf_recipe <- recipe(
  Sale_Price ~ Lot_Area + Neighborhood + Years_Old + Gr_Liv_Area +
    Overall_Qual + Total_Bsmt_SF + Garage_Area,
  data = ames
)
rf_recipe <- step_log(rf_recipe, Sale_Price, base = 10)
rf_recipe <- step_other(rf_recipe, Neighborhood, Overall_Qual, threshold = 50)
rf_recipe <- step_novel(rf_recipe, Neighborhood, Overall_Qual)
rf_recipe <- step_dummy(rf_recipe, Neighborhood, Overall_Qual)
# Random forest regression specification with `min_n` marked for tuning.
# The engine arguments are forwarded to ranger: impurity-based variable
# importance, a fixed ranger seed, and `quantreg = TRUE`.
# NOTE(review): nothing in this script requests quantile predictions, so
# `quantreg` may be unnecessary — kept to preserve the original model.
rf_mod <- rand_forest(min_n = tune()) %>%
  set_engine(
    "ranger",
    importance = "impurity",
    seed = 63233,
    quantreg = TRUE
  ) %>%
  set_mode("regression")
# Candidate `min_n` values: the powers of two from 2 up to 256.
grid_n <- tibble(min_n = 2^seq_len(8))
# Extraction hook for `control_resamples(extract = ...)`.
#
# x: a fitted workflow.
# Returns a tibble pairing the processed outcome of the data the workflow
# was fitted on (renamed `.outcome`) with the model's predictions on those
# same processed predictors (`.pred`).
analysis_preds <- function(x){
  mold <- pull_workflow_mold(x)
  fitted_model <- pull_workflow_fit(x)
  outcome <- set_names(mold$outcomes, ".outcome")
  predictions <- predict(fitted_model, mold$predictors)
  bind_cols(outcome, predictions)
}
# Regression metrics used for the assessment sets (by tune_grid) and for
# the analysis sets (by extract_extracts()).
reg_mets <- metric_set(rmse, mae, mape, rsq)

# Store the output of analysis_preds() for every fitted workflow.
ctrl <- control_resamples(
  extract = analysis_preds
)
# Tune `min_n` over `grid_n` on the validation split, reproducibly.
# Note: despite its name, `rf_wf` holds the tuning results (with `.metrics`
# and `.extracts` list columns), not the workflow itself.
set.seed(63233)
rf_wf <- tune_grid(
  workflows::workflow() %>%
    add_model(rf_mod) %>%
    add_recipe(rf_recipe),
  resamples = train_valid,
  grid = grid_n,
  metrics = reg_mets,
  control = ctrl
)
# # To inspect the outcomes and predictions on the analysis data:
# rf_wf$.extracts %>%
# map(".extracts")
# Score the analysis-set predictions stored in one resample's `.extracts`
# tibble.
#
# x_extracts: the `.extracts` tibble of a single resample. Each row holds a
#   `min_n` value and, in its `.extracts` column, a tibble with `.outcome`
#   and `.pred` columns (as produced by analysis_preds()).
# suffix: string appended to EVERY output column name, including `min_n`.
# metrics: a yardstick metric set; defaults to the global `reg_mets`.
# Returns one row per min_n x metric, with min_n and the metric-set output
# columns (e.g. .metric, .estimate), all carrying `suffix`.
extract_extracts <- function(x_extracts, suffix = "", metrics = reg_mets){
  # A closure (rather than a formula lambda built inside mutate()) so that
  # `metrics` cannot be shadowed by a column of the data mask.
  score_preds <- function(preds) metrics(preds, .outcome, .pred)
  x_extracts %>%
    mutate(metrics_analysis = map(.extracts, score_preds)) %>%
    select(min_n, metrics_analysis) %>%
    unnest(metrics_analysis) %>%
    rename_with(function(nm) paste0(nm, suffix))
}
# Performance on the assessment and analysis sets: one row per resample,
# min_n value and metric. The assessment estimate is in `.estimate` and the
# analysis-set estimate is in `.estimate_analysis`.
#
# The two sources are matched on (id, min_n, .metric) with a join.
# Unnesting `.metrics` and `.extracts` side by side would pair rows purely
# by position, silently mismatching metrics if the nested tibbles were ever
# ordered differently. The multi-argument `unnest(a, b)` form is also
# deprecated in tidyr in favour of `unnest(cols = ...)`.
perf <- local({
  nested <- rf_wf %>%
    mutate(
      .extracts = map(.extracts, extract_extracts, suffix = "_analysis")
    ) %>%
    select(id, .metrics, .extracts)

  assessment <- nested %>%
    select(id, .metrics) %>%
    unnest(cols = .metrics)

  # extract_extracts() suffixes every column, so restore the key names
  # before joining and keep only the analysis estimate.
  analysis <- nested %>%
    select(id, .extracts) %>%
    unnest(cols = .extracts) %>%
    select(
      id,
      min_n = min_n_analysis,
      .metric = .metric_analysis,
      .estimate_analysis
    )

  left_join(assessment, analysis, by = c("id", "min_n", ".metric"))
})
# Plot RMSE against min_n, with one line for the assessment estimates and
# one for the analysis-set estimates.
perf %>%
  filter(.metric == "rmse") %>%
  pivot_longer(cols = c(.estimate, .estimate_analysis)) %>%
  mutate(
    dataset = forcats::fct_recode(
      name,
      assessment = ".estimate",
      analysis = ".estimate_analysis"
    )
  ) %>%
  ggplot(aes(x = min_n, y = value, colour = dataset)) +
  geom_line() +
  theme_bw() +
  labs(
    title = "Performance on Analysis / Assessment Sets",
    subtitle = "While tuning `min_n`",
    y = "RMSE"
  )
# Created on 2021-03-07 by the reprex package (v0.3.0)
# The code above shows one way to extract metrics on the analysis set(s) of
# a resampling object (@kbzsl, tidymodels/tune#215). A potential use case,
# shown in this example, is to review how much better the model fits its
# training data than a hold-out set while tuning a parameter.