Skip to content

Instantly share code, notes, and snippets.

@topepo
Created May 6, 2021 21:57
Show Gist options
  • Save topepo/00a7cb628efc84db88aff95fd26ee966 to your computer and use it in GitHub Desktop.
Save topepo/00a7cb628efc84db88aff95fd26ee966 to your computer and use it in GitHub Desktop.
Regression diagnostic plots for shinymodels
library(tidymodels)
library(rules)
tidymodels_prefer()
theme_set(theme_bw())
library(doMC)
# Register a parallel backend for resampling. Use up to 20 workers, but never
# more than the machine reports; detectCores() may return NA on some
# platforms, hence na.rm = TRUE and the floor of one worker.
# NOTE(review): doMC uses fork-based parallelism, which does not parallelize
# on Windows -- confirm if portability matters.
n_cores <- parallel::detectCores()
registerDoMC(cores = max(1L, min(20L, n_cores, na.rm = TRUE)))
# ------------------------------------------------------------------------------
# Chicago ridership data; the outcome modelled below is `ridership`.
data(Chicago)

# Random train/test split using rsample's default proportions.
set.seed(1)
chi_split <- initial_split(Chicago)
chi_train <- chi_split %>% training()
chi_test <- chi_split %>% testing()

# Month-based sliding windows over the training set; each assessment set
# spans the following 100 months.
# NOTE(review): initial_split() samples rows at random, so test days are
# interleaved with training days; initial_time_split() is the usual choice
# for time-ordered data -- confirm the random split is intended.
# NOTE(review): with the default lookback = 0 each analysis set covers a
# single month -- confirm that is the intended window.
set.seed(2)
chi_folds <- sliding_period(
  chi_train,
  index = date,
  period = "month",
  assess_stop = 100,
  step = 1
)
chi_folds
# ------------------------------------------------------------------------------
# Cubist rule-based regression with 25 committees and a 9-nearest-neighbour
# adjustment of the predictions.
cubist_spec <-
  cubist_rules(committees = 25, neighbors = 9) %>%
  set_engine(engine = "Cubist")

# Derive date-based and holiday features from `date`, then keep the raw
# `date` column as an identifier instead of a predictor.
chi_rec <-
  recipe(ridership ~ ., data = chi_train) %>%
  step_date(date) %>%
  step_holiday(date) %>%
  update_role(date, new_role = "id")

# Bundle preprocessing and model into a single workflow.
cubist_wflow <-
  workflow() %>%
  add_recipe(chi_rec) %>%
  add_model(cubist_spec)
# ------------------------------------------------------------------------------
# Keep the hold-out predictions from every resample so they can be plotted.
ctrl_rs <- control_resamples(save_pred = TRUE)

cubist_res <-
  cubist_wflow %>%
  fit_resamples(resamples = chi_folds, control = ctrl_rs)

# augment() attaches the hold-out predictions (`.pred`) and residuals
# (`.resid`) to the training rows. Rows that never fall in any assessment
# set (with these folds, presumably the earliest month, which only ever
# serves as analysis data) come back with NA for both. Drop them once here
# instead of letting every diagnostic plot silently discard them with
# "Removed ... rows containing missing values" warnings.
cubist_in_sample_predictions <-
  augment(cubist_res) %>%
  dplyr::filter(!is.na(.pred)) %>%
  # for demo, add the day of the week
  mutate(day = lubridate::wday(date, label = TRUE, abbr = FALSE))

# Refit on the full training set and predict the held-out test set.
cubist_test_res <-
  cubist_wflow %>%
  last_fit(split = chi_split)

cubist_test_predictions <- augment(cubist_test_res)
# ------------------------------------------------------------------------------
# Observed ridership against the hold-out predictions; the dashed identity
# line marks perfect agreement and coord_obs_pred() gives both axes the same
# range.
ggplot(cubist_in_sample_predictions, aes(x = ridership, y = .pred)) +
  geom_abline(linetype = 2) +
  geom_point(alpha = .3) +
  coord_obs_pred() +
  labs(title = "Observed vs predicted")

# Residuals against the predictions, with a dashed reference line at zero.
ggplot(cubist_in_sample_predictions, aes(x = .pred, y = .resid)) +
  geom_hline(yintercept = 0, linetype = 2) +
  geom_point(alpha = .3) +
  labs(title = "Residuals vs predicted")
# ------------------------------------------------------------------------------
# Normal Q-Q plot of the residuals; the points follow the dashed line when the
# residuals are roughly Gaussian.
ggplot(cubist_in_sample_predictions, aes(sample = .resid)) +
  stat_qq_line(linetype = 2) +
  stat_qq(alpha = .2) +
  labs(title = "Normal probability plot")
# ------------------------------------------------------------------------------
# Residuals against a numeric predictor, to look for structure the model has
# not captured. The dashed zero line matches the residuals-vs-predicted plot.
cubist_in_sample_predictions %>%
  ggplot(aes(x = Clark_Lake, y = .resid)) +
  geom_hline(yintercept = 0, lty = 2) +
  geom_point(alpha = .3) +
  ggtitle("Residuals vs numeric predictor")

# Residuals by day of the week; reorder() sorts the days by mean residual.
cubist_in_sample_predictions %>%
  ggplot(aes(y = reorder(day, .resid), x = .resid)) +
  geom_vline(xintercept = 0, lty = 2) +
  geom_point(alpha = .3) +
  ylab("Day") +
  ggtitle("Residuals vs factor predictor")
@topepo
Copy link
Author

topepo commented May 27, 2021

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(rules)
#> 
#> Attaching package: 'rules'
#> The following object is masked from 'package:dials':
#> 
#>     max_rules
tidymodels_prefer()
theme_set(theme_bw())

library(doMC)
#> Loading required package: foreach
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
#> Loading required package: iterators
#> Loading required package: parallel
# Use up to 20 workers, but never more than the machine reports;
# detectCores() may return NA on some platforms, hence na.rm = TRUE and the
# floor of one worker.
# NOTE(review): doMC uses fork-based parallelism, which does not parallelize
# on Windows -- confirm if portability matters.
n_cores <- parallel::detectCores()
registerDoMC(cores = max(1L, min(20L, n_cores, na.rm = TRUE)))
# Chicago ridership data; the outcome modelled below is `ridership`.
data(Chicago)

# Random train/test split using rsample's default proportions.
set.seed(1)
chi_split <- initial_split(Chicago)
chi_train <- chi_split %>% training()
chi_test <- chi_split %>% testing()

# Month-based sliding windows over the training set; each assessment set
# spans the following 100 months.
# NOTE(review): the random initial_split() interleaves test days with
# training days; initial_time_split() is the usual choice for time-ordered
# data -- confirm. With the default lookback = 0 each analysis set is a
# single month (9-27 rows in the printout below) -- confirm that is intended.
set.seed(2)
chi_folds <- sliding_period(
  chi_train,
  index = date,
  period = "month",
  assess_stop = 100,
  step = 1
)
chi_folds
#> # Sliding period resampling 
#> # A tibble: 88 x 2
#>    splits            id     
#>    <list>            <chr>  
#>  1 <split [9/2272]>  Slice01
#>  2 <split [21/2277]> Slice02
#>  3 <split [20/2285]> Slice03
#>  4 <split [25/2281]> Slice04
#>  5 <split [20/2284]> Slice05
#>  6 <split [22/2285]> Slice06
#>  7 <split [27/2283]> Slice07
#>  8 <split [24/2284]> Slice08
#>  9 <split [18/2290]> Slice09
#> 10 <split [26/2286]> Slice10
#> # … with 78 more rows
# Cubist rule-based regression with 25 committees and a 9-nearest-neighbour
# adjustment of the predictions.
cubist_spec <-
  cubist_rules(committees = 25, neighbors = 9) %>%
  set_engine(engine = "Cubist")

# Derive date-based and holiday features from `date`, then keep the raw
# `date` column as an identifier instead of a predictor.
chi_rec <-
  recipe(ridership ~ ., data = chi_train) %>%
  step_date(date) %>%
  step_holiday(date) %>%
  update_role(date, new_role = "id")

# Bundle preprocessing and model into a single workflow.
cubist_wflow <-
  workflow() %>%
  add_recipe(chi_rec) %>%
  add_model(cubist_spec)
# Keep the hold-out predictions from every resample so they can be plotted.
ctrl_rs <- control_resamples(save_pred = TRUE)

cubist_res <-
  cubist_wflow %>%
  fit_resamples(resamples = chi_folds, control = ctrl_rs)

# augment() attaches the hold-out predictions (`.pred`) and residuals
# (`.resid`) to the training rows. As its warning shows, 9 training rows never
# fall in an assessment set and come back with NA for both; drop them once
# here so the plots below do not each discard them with a
# "Removed 9 rows containing missing values" warning.
cubist_in_sample_predictions <-
  augment(cubist_res) %>%
  dplyr::filter(!is.na(.pred)) %>%
  # for demo, add the day of the week
  mutate(day = lubridate::wday(date, label = TRUE, abbr = FALSE))
#> Warning: The orginal data had 4274 rows but there were 4265 hold-out
#> predictions.

# Refit on the full training set and predict the held-out test set.
cubist_test_res <-
  cubist_wflow %>%
  last_fit(split = chi_split)

cubist_test_predictions <- augment(cubist_test_res)

# Observed vs hold-out predictions, with the identity line for reference.
cubist_in_sample_predictions %>%
  ggplot(aes(x = ridership, y = .pred)) +
  geom_abline(lty = 2) +
  geom_point(alpha = .3) +
  coord_obs_pred() +
  ggtitle("Observed vs predicted")

# Residuals vs predictions, with a reference line at zero.
cubist_in_sample_predictions %>%
  ggplot(aes(x = .pred, y = .resid)) +
  geom_hline(yintercept = 0, lty = 2) +
  geom_point(alpha = .3) +
  ggtitle("Residuals vs predicted")

# Normal Q-Q plot of the residuals.
cubist_in_sample_predictions %>%
  ggplot(aes(sample = .resid)) +
  stat_qq_line(lty = 2) +
  stat_qq(alpha = .2) +
  ggtitle("Normal probability plot")

# Residuals against a numeric predictor.
cubist_in_sample_predictions %>%
  ggplot(aes(x = Clark_Lake, y = .resid)) +
  geom_hline(yintercept = 0, lty = 2) +
  geom_point(alpha = .3) +
  ggtitle("Residuals vs numeric predictor")

# Residuals by day of the week; reorder() sorts the days by mean residual.
cubist_in_sample_predictions %>%
  ggplot(aes(y = reorder(day, .resid), x = .resid)) +
  geom_vline(xintercept = 0, lty = 2) +
  geom_point(alpha = .3) +
  ylab("Day") +
  ggtitle("Residuals vs factor predictor")

Created on 2021-05-27 by the reprex package (v1.0.0.9000)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment