# fifth tutorial from tidymodels
#################################################################
## tidymodels ##
## 5 A predictive modeling case study ##
## url: https://www.tidymodels.org/start/case-study/ ##
#################################################################
# 1.0 INTRODUCTION ----
## 1.1 General ----
library(tidymodels)
## 1.2 Helper packages ----
library(readr) # for importing data
library(vip) # for variable importance plots
# 2.0 HOTEL BOOKINGS DATA -- STAYS ONLY -- PREDICT CHILDREN ----
## 2.1 Read-in ----
hotels <-
  read_csv('https://tidymodels.org/start/case-study/hotels.csv') %>%
  mutate_if(is.character, as.factor)
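# mutate_if() still works but is superseded in current dplyr; a hedged
# equivalent using the across() idiom (not run, shown for reference):
# hotels <- hotels %>% mutate(across(where(is.character), as.factor))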
## 2.2 View ----
glimpse(hotels)
## 2.3 Outcome variable ----
hotels %>%
  count(children) %>%
  mutate(prop = n/sum(n))
# children appear in only ~8.3% of reservations, so the outcome is imbalanced
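# The tutorial models the data as-is, but a hedged sketch of one common
# remedy, downsampling the majority class with the themis package, is
# shown here in case you want to experiment (not run):
# library(themis)
# recipe(children ~ ., data = hotel_other) %>%
#   step_downsample(children)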
# 3.0 DATA SPLITTING & RESAMPLING ----
## 3.1 Split into stratified random sample ----
set.seed(123)
splits <- initial_split(hotels, strata = children)
hotel_other <- training(splits)
hotel_test <- testing(splits)
## 3.2 Training set proportions by children ----
hotel_other %>%
  count(children) %>%
  mutate(prop = n/sum(n))
## 3.3 Test set proportions by children ----
hotel_test %>%
  count(children) %>%
  mutate(prop = n/sum(n))
## 3.4 validation_split() ----
set.seed(234)
val_set <- validation_split(hotel_other,
                            strata = children,
                            prop = 0.80)
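# val_set is an rset with a single split: ~80% of hotel_other for model
# fitting and ~20% held out to compare models during tuning, a cheaper
# alternative to full cross-validation on a data set this large.
val_set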
# 4.0 FIRST MODEL: PENALIZED LOGISTIC REGRESSION ----
## 4.1 Build the model ----
# tune() is a placeholder for a hyperparameter value to be chosen later;
# mixture = 1 specifies a pure lasso penalty, which can remove
# irrelevant predictors and yield a simpler model
lr_mod <-
  logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")
## 4.2 Create recipe ----
holidays <- c("AllSouls", "AshWednesday", "ChristmasEve", "Easter",
              "ChristmasDay", "GoodFriday", "NewYearsDay", "PalmSunday")
lr_recipe <-
  recipe(children ~ ., data = hotel_other) %>%
  step_date(arrival_date) %>%                         # year/month/dow features
  step_holiday(arrival_date, holidays = holidays) %>% # holiday indicators
  step_rm(arrival_date) %>%                           # drop the raw date
  step_dummy(all_nominal(), -all_outcomes()) %>%      # one-hot encode factors
  step_zv(all_predictors()) %>%                       # drop zero-variance columns
  step_normalize(all_predictors())                    # center and scale
## 4.3 Create workflow ----
lr_workflow <-
  workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(lr_recipe)
## 4.4 Create grid ----
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))
lr_reg_grid %>% top_n(-5) # lowest penalty values
lr_reg_grid %>% top_n(5)  # highest penalty values
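# top_n() is superseded in current dplyr; hedged equivalents:
# lr_reg_grid %>% slice_min(penalty, n = 5)
# lr_reg_grid %>% slice_max(penalty, n = 5)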
## 4.5 Train & Tune ----
lr_res <-
  lr_workflow %>%
  tune_grid(val_set,
            grid = lr_reg_grid,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))
### 4.5.1 Plot ----
lr_plot <-
  lr_res %>%
  collect_metrics() %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number())
lr_plot
# 5.0 SECOND MODEL: TREE-BASED ENSEMBLE ----
# An effective and low-maintenance modeling technique is a random forest.
# Tree-based models require very little preprocessing and can handle many types
# of predictors (sparse, skewed, continuous, categorical, etc.).
## 5.1 Build the model and reduce training time ----
# The tune package can do parallel processing for you, and allows users
# to use multiple cores or separate machines to fit models.
### 5.1.1 Detect cores ----
cores <- parallel::detectCores()
cores
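# A hedged sketch of explicit backend registration for tune (not run;
# assumes the doParallel package is installed). At the time of this
# tutorial, tune parallelized over resamples and grid candidates via
# foreach backends, so registering one lets tune_grid() use all cores:
# library(doParallel)
# cl <- parallel::makePSOCKcluster(cores - 1)
# registerDoParallel(cl)
# # ... run tune_grid() here ...
# stopCluster(cl)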
### 5.1.2 Build model ----
rf_mod <-
  rand_forest(mtry = tune(), min_n = tune(), trees = 1000) %>%
  # tune() is a placeholder; values are chosen during tuning
  set_engine("ranger", num.threads = cores) %>%
  set_mode("classification")
### 5.1.3 CAUTION: Don't set cores except for random forest ----
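# ranger can safely multithread internally; for most other engines,
# leave engine threads alone and let tune handle parallelism, since
# the two levels of parallelism can otherwise compete for the same cores.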
## 5.2 Create Recipe ----
# Unlike penalized logistic regression, random forest models do not
# require dummy variables or normalized predictors.
rf_recipe <-
  recipe(children ~ ., data = hotel_other) %>%
  step_date(arrival_date) %>%
  step_holiday(arrival_date) %>%
  step_rm(arrival_date)
## 5.3 Create Workflow ----
rf_workflow <-
  workflow() %>%
  add_model(rf_mod) %>%
  add_recipe(rf_recipe)
## 5.4 Train and Tune Model ----
### 5.4.1 Show what will be tuned ----
rf_mod %>%
  parameters()
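# mtry has a data-dependent upper bound (the number of predictors), so
# its range is finalized automatically when tune_grid() sees the data.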
### 5.4.2 space-filling grid ----
set.seed(345)
rf_res <-
  rf_workflow %>%
  tune_grid(val_set,
            grid = 25,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))
### 5.4.3 Show the best ----
rf_res %>%
  show_best(metric = "roc_auc")
### 5.4.4 Plot ----
# The y-axis range shows the model is robust to the choice of mtry and
# min_n: all but one of the ROC AUC values are greater than 0.90.
autoplot(rf_res)
### 5.4.5 Select best ----
rf_best <-
  rf_res %>%
  select_best(metric = "roc_auc")
rf_best
### 5.4.6 Filter model to best prediction ----
rf_auc <-
  rf_res %>%
  collect_predictions(parameters = rf_best) %>%
  roc_curve(children, .pred_children) %>%
  mutate(model = "Random Forest")
### 5.4.7 Plot best model vs. penalized logistic regression ----
# lr_auc comes from section 4 of the tutorial but was omitted from this
# gist; rebuild it here so the comparison plot runs. (select_best() is a
# simplification: the tutorial hand-picks a slightly larger penalty with
# near-identical performance.)
lr_auc <-
  lr_res %>%
  collect_predictions(parameters = select_best(lr_res, metric = "roc_auc")) %>%
  roc_curve(children, .pred_children) %>%
  mutate(model = "Logistic Regression")
bind_rows(rf_auc, lr_auc) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity, col = model)) +
  geom_path(lwd = 1.5, alpha = 0.8) +
  geom_abline(lty = 3) +
  coord_equal() +
  scale_color_viridis_d(option = "plasma", end = .6)
# 6.0 THE LAST FIT ----
# Build the parsnip model object again from scratch, plugging in the
# best hyperparameter values from the random forest tuning, and set one
# new argument: importance = "impurity", so variable importance scores
# are recorded for the vip plot below.
## 6.1 last model ----
last_rf_mod <-
  rand_forest(mtry = 8, min_n = 7, trees = 1000) %>%
  set_engine("ranger", num.threads = cores, importance = "impurity") %>%
  set_mode("classification")
## 6.2 last workflow ----
last_rf_workflow <-
  rf_workflow %>%
  update_model(last_rf_mod)
## 6.3 last fit ----
set.seed(345)
last_rf_fit <-
  last_rf_workflow %>%
  last_fit(splits)
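# last_fit() fits the finalized workflow on the full training set and
# evaluates it once on the test set held out by initial_split().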
## 6.4 evaluate model ----
last_rf_fit %>%
  collect_metrics()
## 6.5 review variable importance ----
last_rf_fit %>%
  pluck(".workflow", 1) %>%
  pull_workflow_fit() %>%
  vip(num_features = 20)
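# pull_workflow_fit() still works here but was later superseded; in
# current workflows releases the equivalent is extract_fit_parsnip().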
## 6.6 last roc ----
# The test-set ROC curve is close to the validation-set results,
# suggesting the model should predict well on new data.
last_rf_fit %>%
  collect_predictions() %>%
  roc_curve(children, .pred_children) %>%
  autoplot()
# 7.0 RESOURCES ----
# Kuhn & Silge, Tidy Modeling with R: https://www.tmwr.org