@topepo
Created July 30, 2020 03:16
Example code for using tidymodels, recipes, parsnip, and DALEX
library(tidymodels)
library(plsmod)
library(DALEX)
theme_set(theme_bw())
## ── Attaching packages ───────────────────────────────── tidymodels 0.1.1 ──
## ✓ broom 0.7.0 ✓ recipes 0.1.13
## ✓ dials 0.0.8 ✓ rsample 0.0.7
## ✓ dplyr 1.0.0 ✓ tibble 3.0.3
## ✓ ggplot2 3.3.2 ✓ tidyr 1.1.0
## ✓ infer 0.5.2 ✓ tune 0.1.1
## ✓ modeldata 0.0.2 ✓ workflows 0.1.2
## ✓ parsnip 0.1.2 ✓ yardstick 0.0.7
## ✓ purrr 0.3.4
## -----------------------------------------------------------------------------
# Info on data at https://bookdown.org/max/FES/chicago-intro.html
data(Chicago, package = "modeldata")
# The last two weeks are the test set
data_split <- initial_time_split(Chicago, prop = 0.997543)
chi_train <- training(data_split)
chi_test <- testing(data_split)
# Our resampling scheme emulates this with rolling forecasting origin resampling:
#  - moving analysis sets of 15 years that shift forward roughly two weeks at a time
#  - an assessment set of the two weeks immediately following each analysis set
chi_folds <- rolling_origin(
  chi_train,
  initial = 364 * 15,
  assess = 7 * 2,
  skip = 7 * 2,
  cumulative = FALSE
)
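
# Not part of the original gist: a quick check of how many resamples this scheme
# yields and the span of the first assessment set.
nrow(chi_folds)
chi_folds$splits[[1]] %>%
  assessment() %>%
  summarize(first_day = min(date), last_day = max(date))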
## -----------------------------------------------------------------------------
# Generate features, mostly converting the date into indicator variables
us_holidays <- c(
  'USChristmasDay', 'USColumbusDay', 'USCPulaskisBirthday',
  'USDecorationMemorialDay', 'USElectionDay', 'USGoodFriday',
  'USInaugurationDay', 'USIndependenceDay', 'USLaborDay',
  'USLincolnsBirthday', 'USMemorialDay', 'USMLKingsBirthday',
  'USNewYearsDay', 'USPresidentsDay', 'USThanksgivingDay',
  'USVeteransDay', 'USWashingtonsBirthday'
)
# Only the column names and types are needed to declare the recipe, so a small
# slice of the training data is enough here
chi_rec <-
  recipe(ridership ~ ., data = chi_train %>% slice(1:5)) %>%
  step_date(date) %>%
  step_holiday(date, holidays = us_holidays) %>%
  step_rm(date) %>%
  step_dummy(all_nominal()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
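
# Not part of the original gist: prep() estimates the recipe on the training set
# and juice() returns the processed training data, which is a quick way to
# inspect the indicator features being generated.
chi_rec %>%
  prep(training = chi_train) %>%
  juice() %>%
  glimpse()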
## -----------------------------------------------------------------------------
# Since the lagged station data are highly correlated, use PLS to do the fits
pls_mod <-
  pls(num_comp = tune(), num_terms = tune()) %>%
  set_engine("mixOmics") %>%
  set_mode("regression")

pls_wflow <-
  workflow() %>%
  add_recipe(chi_rec) %>%
  add_model(pls_mod)
## -----------------------------------------------------------------------------
# Tune over the number of PLS components and amount of sparsity.
pls_grid <- crossing(num_comp = 1:20, num_terms = c(10, 15, 20))
ctrl <- control_grid(save_pred = TRUE)
pls_res <-
  pls_wflow %>%
  tune_grid(resamples = chi_folds, grid = pls_grid, control = ctrl)
autoplot(pls_res, metric = "rmse")
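
# Not part of the original gist: show_best() gives a numeric companion to the
# autoplot() above (the five configurations with the lowest resampled RMSE).
show_best(pls_res, metric = "rmse", n = 5)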
## -----------------------------------------------------------------------------
# Pick the parameters with the smallest RMSE and make diagnostic plots
pls_param <- select_best(pls_res, metric = "rmse")
assess_start_date <-
  map_dfr(chi_folds$splits, ~ assessment(.x) %>% select(date) %>% slice(1)) %>%
  bind_cols(chi_folds %>% select(id))

rmse_over_time <-
  pls_res %>%
  collect_metrics(summarize = FALSE) %>%
  filter(.metric == "rmse") %>%
  inner_join(pls_param, by = c("num_terms", "num_comp", ".config")) %>%
  inner_join(assess_start_date, by = "id")

ggplot(rmse_over_time, aes(x = date, y = .estimate)) +
  geom_point() +
  geom_line() +
  scale_x_date(date_labels = "%b %d %Y")

pls_res %>%
  collect_predictions(parameters = pls_param) %>%
  ggplot(aes(x = ridership, y = .pred)) +
  geom_abline(lty = 2) +
  geom_point(alpha = .2) +
  coord_obs_pred()
## -----------------------------------------------------------------------------
# Substitute in the final PLS parameters and re-fit
pls_fit <-
  pls_wflow %>%
  finalize_workflow(pls_param) %>%
  fit(data = chi_train)
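
# Not part of the original gist: a minimal sketch of test-set performance for the
# finalized workflow, using the two-week hold-out created by initial_time_split().
pls_fit %>%
  predict(new_data = chi_test) %>%
  bind_cols(chi_test %>% select(date, ridership)) %>%
  rmse(truth = ridership, estimate = .pred)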
## -----------------------------------------------------------------------------
# DALEX needs a prediction function that returns a plain numeric vector
tm_reg_pred <- function(x, y) {
  predict(x, new_data = y) %>% pull(.pred)
}

chi_expl <-
  explain(
    pls_fit,
    data = chi_train %>% select(-ridership),
    y = chi_train$ridership,
    predict_function = tm_reg_pred
  )
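
# Not part of the original gist: a quick residual-based check that the explainer
# is wired up correctly (computed on the training data passed to explain()).
model_performance(chi_expl)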
## -----------------------------------------------------------------------------
set.seed(2983)
perm_var_imps <-
  variable_importance(chi_expl, type = "ratio", n_sample = 20) %>%
  as_tibble() %>%
  filter(!grepl("^_", variable)) %>%
  group_by(variable) %>%
  summarize(
    loss = mean(dropout_loss),
    lower = quantile(dropout_loss, probs = .05),
    upper = quantile(dropout_loss, probs = .95),
    .groups = "drop"
  ) %>%
  mutate(variable = reorder(variable, loss))
perm_var_imps %>%
  arrange(desc(loss)) %>%
  top_n(20, loss) %>%
  ggplot(aes(y = variable, x = loss)) +
  geom_bar(stat = "identity") +
  labs(y = NULL, x = "Loss in model performance")

perm_var_imps %>%
  arrange(desc(loss)) %>%
  top_n(20, loss) %>%
  ggplot(aes(y = variable, x = loss)) +
  geom_errorbar(aes(xmin = lower, xmax = upper), width = .2) +
  geom_point() +
  labs(y = NULL, x = "Loss in model performance")
variable_effect(chi_expl, variables = "Clark_Lake", type = "partial_dependency") %>%
ggplot(aes(x = `_x_`, y = `_yhat_`)) +
geom_path() +
labs(x = "Clark and Lake station entries", y = "Predicted Ridership")
# Doesn't seem to work: variable_effect() may not handle date columns
variable_effect(chi_expl, variables = "date", type = "partial_dependency") %>%
  ggplot(aes(x = `_x_`, y = `_yhat_`)) +
  geom_path() +
  labs(x = "Date", y = "Predicted Ridership")

# The x-axis reports the date as numeric; it is unclear whether this is only a
# formatting issue or whether the date was converted to numeric (hopefully not)
variable_attribution(chi_expl, chi_test %>% slice(1)) %>%
  plot()