Created
May 30, 2021 19:00
-
-
Save RobWiederstein/a85731e45d6a88cf3496c25b55cf97e4 to your computer and use it in GitHub Desktop.
fifth tutorial from tidymodels
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
################################################################# | |
## tidymodels ## | |
## 5 A predictive modeling case study ## | |
## url: https://www.tidymodels.org/start/case-study/ ## | |
################################################################# | |
# 1.0 INTRODUCTION ---- | |
## 1.1 General ---- | |
library(tidymodels)
## 1.2 Helper packages ----
library(readr) # for importing data
library(vip) # for variable importance plots
# 2.0 HOTEL BOOKINGS DATA -- STAYS ONLY -- PREDICT CHILDREN ----
## 2.1 Read-in ----
# Download the hotel stays CSV and convert every character column to a factor.
# mutate(across(where(...))) replaces the superseded mutate_if().
hotels <-
  read_csv("https://tidymodels.org/start/case-study/hotels.csv") %>%
  mutate(across(where(is.character), as.factor))
## 2.2 View ----
glimpse(hotels)
## 2.3 Outcome variable ----
# Row count for each level of `children` and its share of all rows.
hotels %>%
  count(children) %>%
  mutate(prop = n / sum(n))
# 8.3% of reservations
# 3.0 DATA SPLITTING & RESAMPLING ----
## 3.1 Split into stratified random sample ----
# The seed is fixed so the split is reproducible; rows are stratified on
# `children`.
set.seed(123)
splits <- initial_split(data = hotels, strata = children)
hotel_other <- training(splits)
hotel_test <- testing(splits)
## 3.2 Training set proportions by children ----
count(hotel_other, children) %>%
  mutate(prop = n / sum(n))
## 3.3 Test set proportions by children ----
count(hotel_test, children) %>%
  mutate(prop = n / sum(n))
## 3.4 validation_split() ----
# One resample of hotel_other, stratified on `children`, with prop = 0.80.
# NOTE(review): newer rsample releases reportedly deprecate
# validation_split() -- confirm against the installed version.
set.seed(234)
val_set <- validation_split(
  hotel_other,
  prop = 0.80,
  strata = children
)
# 4.0 FIRST MODEL: PENALIZED LOGISTIC REGRESSION ----
## 4.1 Build the model ----
# `penalty` is a tune() placeholder that is filled in during tuning.
# mixture = 1 removes irrelevant predictors.
lr_mod <- set_engine(
  logistic_reg(penalty = tune(), mixture = 1),
  "glmnet"
)
## 4.2 Create recipe ----
holidays <- c(
  "AllSouls", "AshWednesday", "ChristmasEve", "Easter",
  "ChristmasDay", "GoodFriday", "NewYearsDay", "PalmSunday"
)
# Steps are applied in the order written: date features and holiday
# indicators are derived from arrival_date, the raw date is dropped, nominal
# predictors become dummy variables, zero-variance columns are removed, and
# the remaining predictors are normalized.
lr_recipe <- recipe(children ~ ., data = hotel_other) %>%
  step_date(arrival_date) %>%
  step_holiday(arrival_date, holidays = holidays) %>%
  step_rm(arrival_date) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
## 4.3 Create workflow ----
# Bundle the model specification and the recipe into one object.
lr_workflow <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(lr_recipe)
## 4.4 Create grid ----
# 30 candidate penalties, evenly spaced on the log10 scale from 1e-4 to 1e-1.
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))
# slice_min()/slice_max() replace the superseded top_n() and name the
# ordering column explicitly instead of relying on "last column" selection.
lr_reg_grid %>% slice_min(penalty, n = 5) # lowest penalty values
lr_reg_grid %>% slice_max(penalty, n = 5) # highest penalty values
## 4.5 Train & Tune ----
# Evaluate every penalty in lr_reg_grid on the validation resample, scoring
# by ROC AUC and keeping the predictions for later use.
lr_res <- tune_grid(
  lr_workflow,
  resamples = val_set,
  grid = lr_reg_grid,
  metrics = metric_set(roc_auc),
  control = control_grid(save_pred = TRUE)
)
### 4.5.1 Plot ----
# ROC AUC against penalty, with the penalty axis on a log10 scale.
lr_plot <-
  ggplot(collect_metrics(lr_res), aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number())
lr_plot
# 5.0 SECOND MODEL: TREE-BASED ENSEMBLE ----
# An effective and low-maintenance modeling technique is a random forest.
# Tree-based models require very little preprocessing and can handle many
# types of predictors (sparse, skewed, continuous, categorical, etc.).
## 5.1 Build model reduce training time ----
# The tune package can do parallel processing for you, and allows users
# to use multiple cores or separate machines to fit models.
### 5.1.1 Detect cores ----
cores <- parallel::detectCores()
cores
### 5.1.2 Build model ----
# mtry and min_n are tune() placeholders to be filled in later; the number
# of trees is fixed at 1000. The detected core count is passed to the ranger
# engine as num.threads.
rf_mod <- rand_forest(trees = 1000, mtry = tune(), min_n = tune())
rf_mod <- set_engine(rf_mod, "ranger", num.threads = cores)
rf_mod <- set_mode(rf_mod, "classification")
### 5.1.3 CAUTION: Don't set cores except for random forest ----
## 5.2 Create Recipe ----
# Unlike penalized logistic regression models, random forest models do
# not require dummy or normalized predictor variables.
# NOTE(review): unlike lr_recipe, no `holidays =` argument is passed to
# step_holiday() here, so the step's own defaults apply -- confirm intended.
rf_recipe <- recipe(children ~ ., data = hotel_other) %>%
  step_date(arrival_date) %>%
  step_holiday(arrival_date) %>%
  step_rm(arrival_date)
## 5.3 Create Workflow ----
# Bundle the random forest specification and its recipe.
rf_workflow <- workflow() %>%
  add_model(rf_mod) %>%
  add_recipe(rf_recipe)
## 5.4 Train and Tune Model ----
### 5.4.1 Show what will be tuned ----
# extract_parameter_set_dials() replaces parameters(), which is deprecated
# for model specifications; it lists the arguments marked with tune().
rf_mod %>%
  extract_parameter_set_dials()
### 5.4.2 space-filling grid ----
# grid = 25 asks tune_grid() to generate 25 candidate combinations itself;
# the seed is fixed first so that grid is reproducible.
set.seed(345)
rf_res <- tune_grid(
  rf_workflow,
  resamples = val_set,
  grid = 25,
  metrics = metric_set(roc_auc),
  control = control_grid(save_pred = TRUE)
)
### 5.4.3 Show the best ----
show_best(rf_res, metric = "roc_auc")
### 5.4.4 Plot ----
# However, the range of the y-axis indicates that the model is
# very robust to the choice of these parameter values -- all but
# one of the ROC AUC values are greater than 0.90.
autoplot(rf_res)
### 5.4.5 Select best ----
rf_best <- select_best(rf_res, metric = "roc_auc")
rf_best
### 5.4.6 Filter model to best prediction ----
# Keep only the saved predictions for the rf_best candidate, compute its ROC
# curve, and label the rows so they can be combined with other models' curves.
rf_auc <-
  collect_predictions(rf_res, parameters = rf_best) %>%
  roc_curve(truth = children, .pred_children) %>%
  mutate(model = "Random Forest")
### 5.4.7 Plot best model ----
# lr_auc was never created earlier in this script, so the original
# bind_rows(rf_auc, lr_auc) stopped with "object 'lr_auc' not found".
# Build it here from the logistic regression tuning results (lr_res was
# tuned with save_pred = TRUE, so its predictions are available).
lr_best <-
  lr_res %>%
  select_best(metric = "roc_auc")
lr_auc <-
  lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  roc_curve(children, .pred_children) %>%
  mutate(model = "Logistic Regression")
# Overlay both ROC curves; the dotted diagonal is the reference line drawn
# by geom_abline().
bind_rows(rf_auc, lr_auc) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity, col = model)) +
  geom_path(lwd = 1.5, alpha = 0.8) +
  geom_abline(lty = 3) +
  coord_equal() +
  scale_color_viridis_d(option = "plasma", end = .6)
# 6.0 THE LAST FIT ----
# Build the parsnip model object again from scratch, using the best
# hyperparameter values from the random forest tuning, and set one new
# engine argument: importance = "impurity".
## 6.1 last model ----
# NOTE(review): mtry = 8 and min_n = 7 are hard-coded -- confirm they match
# the values reported in rf_best for this run.
last_rf_mod <-
  rand_forest(trees = 1000, mtry = 8, min_n = 7) %>%
  set_engine("ranger", num.threads = cores, importance = "impurity") %>%
  set_mode("classification")
## 6.2 last workflow ----
# Reuse the tuning workflow, swapping in the finalized model.
last_rf_workflow <- update_model(rf_workflow, last_rf_mod)
## 6.3 last fit ----
set.seed(345)
last_rf_fit <- last_fit(last_rf_workflow, splits)
## 6.4 evaluate model ----
collect_metrics(last_rf_fit)
## 6.5 review variable importance ----
# extract_fit_parsnip() replaces pluck(".workflow", 1) followed by the
# deprecated pull_workflow_fit(); it returns the fitted parsnip model
# without depending on the internal column layout of the last_fit() result.
last_rf_fit %>%
  extract_fit_parsnip() %>%
  vip(num_features = 20)
## 6.6 last roc ----
# Similar to the validation set: a good predictor on new data.
autoplot(
  roc_curve(collect_predictions(last_rf_fit), children, .pred_children)
)
# 7.0 RESOURCES ----
# Kuhn & Silge: https://www.tmwr.org
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment