# fourth tutorial from tidymodels
#################################################################
##                        tidymodels                           ##
##                  4 Tune model parameters                    ##
##       url: https://www.tidymodels.org/start/tuning/         ##
#################################################################
# 1.0 INTRODUCTION ----
## 1.1 hyperparameters ----
### 1.1.1 mtry()----
### 1.1.2 learn_rate()----
library(tidymodels) # for the tune package, along with the rest of tidymodels
# Helper packages
library(modeldata) # for the cells data
library(vip) # for variable importance plots
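# Quick illustration of the parameter objects named in section 1.1 (not
# part of the tutorial itself): dials represents each tunable
# hyperparameter as an object with a default range.
mtry()       # randomly sampled predictors per split; upper bound depends on the data
learn_rate() # boosting step size; searched on a log10 scale by default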
# 2.0 THE CELL IMAGE DATA, REVISITED ----
# each cell image is labeled by experts as well-segmented (WS) or poorly segmented (PS)
data(cells, package = "modeldata")
# 3.0 PREDICTING IMAGE SEGMENTATION, BUT BETTER ----
# Random forest models are a tree-based ensemble method and typically
# perform well with default hyperparameters. However, the accuracy of
# some other tree-based models, such as boosted tree models or decision
# tree models, can be sensitive to the values of hyperparameters. In
# this article, we will train a decision tree model.
## 3.1 cost_complexity() ----
# adds a cost, or penalty, to the error rates of more complex trees
## 3.2 tree_depth() ----
# helps by stopping the tree from growing after it reaches
# a certain depth
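# Sketch (dials defaults, worth verifying against your installed
# version): both tuning parameters ship with default ranges we can inspect.
cost_complexity() # searched on a log10 scale, default range 10^-10 to 10^-1
tree_depth()      # integer depth, default range 1 to 15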
set.seed(123)
cell_split <- initial_split(cells %>% select(-case),
                            strata = class)
cell_train <- training(cell_split)
cell_test <- testing(cell_split)
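# Sanity check (not in the tutorial): stratifying on `class` should keep
# the WS/PS proportions similar in the training and test sets.
cell_train %>% count(class) %>% mutate(prop = n / sum(n))
cell_test  %>% count(class) %>% mutate(prop = n / sum(n))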
# 4.0 TUNING HYPERPARAMETERS ----
## 4.1 identifies which hyperparameters we plan to tune ----
tune_spec <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune()
  ) %>%
  set_engine("rpart") %>%
  set_mode("classification")
tune_spec
## 4.2 dials::grid_regular() ----
tree_grid <- grid_regular(cost_complexity(),
                          tree_depth(),
                          levels = 5)
tree_grid
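# A regular grid crosses the levels of each parameter, so 5 levels of
# two parameters gives 5 x 5 = 25 candidate models.
nrow(tree_grid)                 # 25
tree_grid %>% count(tree_depth) # 5 cost_complexity values per depth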
## 4.3 Cross validation folds ----
set.seed(234)
cell_folds <- vfold_cv(cell_train)
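# vfold_cv() defaults to v = 10 folds. A stratified variant (an option,
# not what the tutorial uses) would be:
# cell_folds <- vfold_cv(cell_train, v = 10, strata = class)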
# 5.0 MODEL TUNING WITH A GRID ----
## 5.1 workflow -----
set.seed(345)
tree_wf <- workflow() %>%
  add_model(tune_spec) %>%
  add_formula(class ~ .)
## 5.2 add tuning grid ----
tree_res <-
  tree_wf %>%
  tune_grid(
    resamples = cell_folds,
    grid = tree_grid
  )
tree_res
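# Numeric complement to the plot below (not in the tutorial script):
# rank the candidate models by area under the ROC curve.
tree_res %>% show_best(metric = "roc_auc", n = 5)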
## 5.3 Plot ----
tree_res %>%
  collect_metrics() %>%
  mutate(tree_depth = factor(tree_depth)) %>%
  ggplot(aes(cost_complexity, mean, color = tree_depth)) +
  geom_line(linewidth = 1.5, alpha = 0.6) + # `size` for lines is deprecated in ggplot2 >= 3.4
  geom_point(size = 2) +
  facet_wrap(~ .metric, scales = "free", nrow = 2) +
  scale_x_log10(labels = scales::label_number()) +
  scale_color_viridis_d(option = "plasma", begin = .9, end = 0)
## 5.4 Select Best ----
best_tree <- tree_res %>%
  select_best(metric = "roc_auc")
best_tree
# 6.0 FINALIZING OUR MODEL ----
final_wf <-
  tree_wf %>%
  finalize_workflow(best_tree)
final_wf
# 7.0 EXPLORING RESULTS ----
## 7.1 fit final model to training data ----
final_tree <-
  final_wf %>%
  fit(data = cell_train)
final_tree
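# Optional visualization (assumes the rpart.plot package is installed):
# extract the underlying rpart object and draw the fitted tree.
final_tree %>%
  extract_fit_engine() %>%
  rpart.plot::rpart.plot(roundint = FALSE)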
## 7.2 variable importance ----
final_tree %>%
  extract_fit_parsnip() %>% # pull_workflow_fit() is deprecated in current workflows
  vip(geom = "point")
# 8.0 LAST FIT ----
## 8.1 test data ----
final_fit <-
  final_wf %>%
  last_fit(cell_split)
## 8.2 collect metrics ----
final_fit %>%
  collect_metrics()
## 8.3 plot roc curve ----
final_fit %>%
  collect_predictions() %>%
  roc_curve(class, .pred_PS) %>%
  autoplot()
## 8.4 Other hyperparameters? ----
args(decision_tree)
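# Sketch of an extension (hypothetical spec name, not in the tutorial):
# decision_tree() has a third tunable hyperparameter, min_n, the minimum
# number of data points in a node required for a further split.
tune_spec_min_n <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune(),
    min_n = tune()
  ) %>%
  set_engine("rpart") %>%
  set_mode("classification")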
## 8.5 explore more models: https://www.tidymodels.org/find/parsnip/#models ----