Created
July 30, 2020 03:16
-
-
Save topepo/c17a504b070a4bc87e847c0d851ae21a to your computer and use it in GitHub Desktop.
Example code for using tidymodels, recipes, parsnip, and DALEX
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tidymodels) | |
library(plsmod) | |
library(DALEX) | |
theme_set(theme_bw()) | |
## ── Attaching packages ───────────────────────────────── tidymodels 0.1.1 ── | |
## ✓ broom 0.7.0 ✓ recipes 0.1.13 | |
## ✓ dials 0.0.8 ✓ rsample 0.0.7 | |
## ✓ dplyr 1.0.0 ✓ tibble 3.0.3 | |
## ✓ ggplot2 3.3.2 ✓ tidyr 1.1.0 | |
## ✓ infer 0.5.2 ✓ tune 0.1.1 | |
## ✓ modeldata 0.0.2 ✓ workflows 0.1.2 | |
## ✓ parsnip 0.1.2 ✓ yardstick 0.0.7 | |
## ✓ purrr 0.3.4 | |
## ----------------------------------------------------------------------------- | |
# Info on data at https://bookdown.org/max/FES/chicago-intro.html
data(Chicago, package = "modeldata")

# Time-ordered split: the first `prop` fraction of the rows becomes the
# training set and the remaining, most recent rows (intended to be the last
# two weeks) are held out as the test set.
data_split <- initial_time_split(Chicago, prop = 0.997543)
chi_train <- training(data_split)
# The held-out rows must be extracted with `testing()`; calling `training()`
# here again would make `chi_test` a copy of `chi_train`.
chi_test <- testing(data_split)

# Our resampling scheme will emulate this using rolling forecasting origin
# resampling with
#  - moving (non-cumulative) analysis sets of 364 * 15 rows, i.e. about
#    15 years assuming one row per day, and
#  - an assessment set of the 14 rows (two weeks) that follow each analysis
#    set, matching the size of the test set.
# `skip = 7 * 2` thins the resamples so that consecutive analysis windows
# start roughly two weeks apart instead of one row apart.
chi_folds <- rolling_origin(
  chi_train,
  initial = 364 * 15,
  assess = 7 * 2,
  skip = 7 * 2,
  cumulative = FALSE
)
## ----------------------------------------------------------------------------- | |
# Feature engineering: most of the work is turning the date into indicator
# columns.

# Holiday names passed to step_holiday(); one indicator is requested for each.
us_holidays <- c(
  "USChristmasDay",
  "USColumbusDay",
  "USCPulaskisBirthday",
  "USDecorationMemorialDay",
  "USElectionDay",
  "USGoodFriday",
  "USInaugurationDay",
  "USIndependenceDay",
  "USLaborDay",
  "USLincolnsBirthday",
  "USMemorialDay",
  "USMLKingsBirthday",
  "USNewYearsDay",
  "USPresidentsDay",
  "USThanksgivingDay",
  "USVeteransDay",
  "USWashingtonsBirthday"
)

# Only the first five training rows are handed to recipe() as the template.
# NOTE(review): presumably this just keeps the recipe object small and the
# full data is supplied when the workflow is fit -- confirm.
chi_rec <-
  recipe(ridership ~ ., data = slice(chi_train, 1:5)) %>%
  # Derive date-based features, then holiday indicators, from `date`.
  step_date(date) %>%
  step_holiday(date, holidays = us_holidays) %>%
  # The raw date column is dropped once the derived features exist.
  step_rm(date) %>%
  # Nominal columns become dummy/indicator variables.
  step_dummy(all_nominal()) %>%
  # Remove zero-variance predictors, then normalize what is left.
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
## ----------------------------------------------------------------------------- | |
# The lagged station predictors are highly correlated with one another, so a
# partial least squares (PLS) regression is used for the fits. Both the number
# of components and the number of retained terms are left as tuning
# parameters.
pls_mod <- pls(num_comp = tune(), num_terms = tune())
pls_mod <- set_engine(pls_mod, "mixOmics")
pls_mod <- set_mode(pls_mod, "regression")

# Bundle the preprocessing recipe and the model specification together.
pls_wflow <- add_recipe(workflow(), chi_rec)
pls_wflow <- add_model(pls_wflow, pls_mod)
## ----------------------------------------------------------------------------- | |
# Candidate values: every combination of 1-20 PLS components with 10, 15, or
# 20 retained terms (the sparsity setting).
pls_grid <- crossing(
  num_comp = 1:20,
  num_terms = c(10, 15, 20)
)

# Keep the assessment-set predictions so they can be plotted afterwards.
ctrl <- control_grid(save_pred = TRUE)

# Evaluate the whole grid on the rolling-origin resamples.
pls_res <- tune_grid(
  pls_wflow,
  resamples = chi_folds,
  grid = pls_grid,
  control = ctrl
)

# RMSE profile across the tuning grid.
autoplot(pls_res, metric = "rmse")
## ----------------------------------------------------------------------------- | |
# Choose the tuning parameters with the smallest RMSE, then draw diagnostic
# plots for that configuration.
pls_param <- select_best(pls_res, metric = "rmse")

# First date of each resample's assessment set, labelled with the resample id
# so it can be joined to the per-resample metrics.
assess_start_date <-
  chi_folds$splits %>%
  map_dfr(function(split) slice(select(assessment(split), date), 1)) %>%
  bind_cols(select(chi_folds, id))

# Per-resample RMSE for the selected parameters, with the assessment start
# date attached.
rmse_over_time <-
  collect_metrics(pls_res, summarize = FALSE) %>%
  filter(.metric == "rmse") %>%
  inner_join(pls_param, by = c("num_terms", "num_comp", ".config")) %>%
  inner_join(assess_start_date, by = "id")

# RMSE of the selected model as the forecasting origin moves through time.
rmse_over_time %>%
  ggplot(aes(x = date, y = .estimate)) +
  geom_point() +
  geom_line() +
  scale_x_date(date_labels = "%b %d %Y")

# Observed versus predicted ridership on the assessment sets; the dashed line
# is drawn by geom_abline() with its default slope and intercept.
ggplot(
  collect_predictions(pls_res, parameters = pls_param),
  aes(x = ridership, y = .pred)
) +
  geom_abline(lty = 2) +
  geom_point(alpha = .2) +
  coord_obs_pred()
## ----------------------------------------------------------------------------- | |
# Plug the selected PLS parameters into the workflow and refit it on the
# entire training set.
pls_fit <- fit(
  finalize_workflow(pls_wflow, pls_param),
  data = chi_train
)
## ----------------------------------------------------------------------------- | |
# Prediction helper supplied to DALEX::explain() as `predict_function`.
#
# x: a fitted model object that has a predict() method accepting `new_data`.
# y: the data to predict on (passed straight through as `new_data`).
#
# Returns the `.pred` column of the prediction result as a plain vector.
tm_reg_pred <- function(x, y) {
  preds <- predict(x, new_data = y)
  pull(preds, .pred)
}
# Build the DALEX explainer around the fitted workflow. The predictors are the
# training columns without the outcome, the outcome is passed separately, and
# predictions go through the tm_reg_pred() wrapper.
chi_expl <- explain(
  pls_fit,
  data = select(chi_train, -ridership),
  y = chi_train[["ridership"]],
  predict_function = tm_reg_pred
)
## ----------------------------------------------------------------------------- | |
# Permutation-based variable importance, summarised per variable.
# The seed makes the random permutations reproducible.
set.seed(2983)

# NOTE(review): `n_sample = 20` appears to limit each round to 20 sampled
# rows; confirm this was intended rather than asking for 20 permutation
# rounds.
perm_var_imps <-
  variable_importance(chi_expl, type = "ratio", n_sample = 20) %>%
  as_tibble() %>%
  # Drop the bookkeeping rows whose names start with an underscore.
  filter(!grepl("^_", variable)) %>%
  group_by(variable) %>%
  # `.groups = "drop"` already returns an ungrouped tibble, so no separate
  # ungroup() step is needed.
  summarize(
    loss = mean(dropout_loss),
    lower = quantile(dropout_loss, probs = .05),
    upper = quantile(dropout_loss, probs = .95),
    .groups = "drop"
  ) %>%
  # Order the factor levels by loss so the plots sort their y-axis by
  # importance.
  mutate(variable = reorder(variable, loss))

# The 20 variables with the largest mean loss, computed once and shared by
# both plots. slice_max() keeps ties and returns rows in descending order of
# `loss`, replacing the superseded arrange(desc(loss)) + top_n(20, loss) pair.
top_var_imps <- slice_max(perm_var_imps, loss, n = 20)

# Bar chart of the mean loss.
top_var_imps %>%
  ggplot(aes(y = variable, x = loss)) +
  geom_bar(stat = "identity") +
  labs(y = NULL, x = "Loss in model performance")

# Same ranking shown as points with 5%-95% quantile intervals.
top_var_imps %>%
  ggplot(aes(y = variable, x = loss)) +
  geom_errorbar(aes(xmin = lower, xmax = upper), width = .2) +
  geom_point() +
  labs(y = NULL, x = "Loss in model performance")
# Partial dependence profile for the Clark/Lake station predictor.
ggplot(
  variable_effect(
    chi_expl,
    variables = "Clark_Lake",
    type = "partial_dependency"
  ),
  aes(x = `_x_`, y = `_yhat_`)
) +
  geom_path() +
  labs(x = "Clark and Lake station entries", y = "Predicted Ridership")

# The same profile for `date`.
# Doesn't seem to work: possibly dates are not handled?
ggplot(
  variable_effect(
    chi_expl,
    variables = "date",
    type = "partial_dependency"
  ),
  aes(x = `_x_`, y = `_yhat_`)
) +
  geom_path() +
  labs(x = "Date", y = "Predicted Ridership")

# Break-down of the prediction for the first row of `chi_test`.
# The date is reported as a number here; unclear whether that is only a
# formatting issue or whether the date was converted to numeric (hopefully
# not).
plot(variable_attribution(chi_expl, slice(chi_test, 1)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment