@simonpcouch
Created May 11, 2022 17:05

This issue came up in a conversation with @mine-cetinkaya-rundel about teaching introductory stats / modeling courses using tidymodels. I feel that, in some ways, parsnip’s guardrails re: augment make teaching broom’s principles fussier than they ought to be. Fitting a model and passing it to each tidier:

library(tidyverse)
library(tidymodels)
library(palmerpenguins)

penguins <- drop_na(penguins)

penguins_tr <- penguins[1:200,]
penguins_te <- penguins[201:nrow(penguins),]

pars_fit <- linear_reg() %>%
  set_engine("lm") %>%
  fit(body_mass_g ~ flipper_length_mm, data = penguins)

# works
tidy(pars_fit)
#> # A tibble: 2 × 5
#>   term              estimate std.error statistic   p.value
#>   <chr>                <dbl>     <dbl>     <dbl>     <dbl>
#> 1 (Intercept)        -5872.     310.       -18.9 1.18e- 54
#> 2 flipper_length_mm     50.2      1.54      32.6 3.13e-105

# works
glance(pars_fit)
#> # A tibble: 1 × 12
#>   r.squared adj.r.squared sigma statistic   p.value    df logLik   AIC   BIC
#>       <dbl>         <dbl> <dbl>     <dbl>     <dbl> <dbl>  <dbl> <dbl> <dbl>
#> 1     0.762         0.761  393.     1060. 3.13e-105     1 -2461. 4928. 4940.
#> # … with 3 more variables: deviance <dbl>, df.residual <int>, nobs <int>

# oh-
augment(pars_fit)
#> Error in augment.model_fit(pars_fit): argument "new_data" is missing, with no default

I understand that the intention here was to guard folks against predicting on the training set, and maybe the conclusion is that the teaching moment re: predicting on training data needs to happen this early on. I feel, though, that the current approach 1) comes across as dogmatic and, in some cases, 2) ends up encouraging the opposite behavior.

re: 1) broom’s augment methods distinguish between `data` (i.e. training data) and `newdata` in determining their output, and when supplied the former rather than the latter, they return additional fit information that is only well-defined for the training data. Augmenting with the training data accommodates discussing these values:

lm_fit <- lm(body_mass_g ~ flipper_length_mm, data = penguins_tr)

# default: retrieve `penguins_tr` and pass as `data`
augment(lm_fit)
#> # A tibble: 200 × 8
#>    body_mass_g flipper_length_… .fitted .resid    .hat .sigma .cooksd .std.resid
#>          <int>            <int>   <dbl>  <dbl>   <dbl>  <dbl>   <dbl>      <dbl>
#>  1        3750              181   3273.  477.  0.0127    415. 8.58e-3     1.16  
#>  2        3800              186   3523.  277.  0.00862   416. 1.96e-3     0.671 
#>  3        3250              195   3972. -722.  0.00511   413. 7.81e-3    -1.74  
#>  4        3450              193   3872. -422.  0.00547   415. 2.85e-3    -1.02  
#>  5        3650              190   3722.  -72.2 0.00645   416. 9.89e-5    -0.175 
#>  6        3625              181   3273.  352.  0.0127    415. 4.67e-3     0.853 
#>  7        4675              195   3972.  703.  0.00511   413. 7.42e-3     1.70  
#>  8        3200              182   3323. -123.  0.0117    416. 5.31e-4    -0.299 
#>  9        3800              191   3772.   27.9 0.00606   416. 1.39e-5     0.0675
#> 10        4400              198   4121.  279.  0.00504   416. 1.15e-3     0.674 
#> # … with 190 more rows

# pass `penguins_tr` as `data` explicitly
augment(lm_fit, data = penguins_tr)
#> # A tibble: 200 × 14
#>    species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>   <fct>              <dbl>         <dbl>             <int>       <int>
#>  1 Adelie  Torgersen           39.1          18.7               181        3750
#>  2 Adelie  Torgersen           39.5          17.4               186        3800
#>  3 Adelie  Torgersen           40.3          18                 195        3250
#>  4 Adelie  Torgersen           36.7          19.3               193        3450
#>  5 Adelie  Torgersen           39.3          20.6               190        3650
#>  6 Adelie  Torgersen           38.9          17.8               181        3625
#>  7 Adelie  Torgersen           39.2          19.6               195        4675
#>  8 Adelie  Torgersen           41.1          17.6               182        3200
#>  9 Adelie  Torgersen           38.6          21.2               191        3800
#> 10 Adelie  Torgersen           34.6          21.1               198        4400
#> # … with 190 more rows, and 8 more variables: sex <fct>, year <int>,
#> #   .fitted <dbl>, .resid <dbl>, .hat <dbl>, .sigma <dbl>, .cooksd <dbl>,
#> #   .std.resid <dbl>

# pass `penguins_te` as `newdata`
augment(lm_fit, newdata = penguins_te)
#> # A tibble: 133 × 10
#>    species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>   <fct>           <dbl>         <dbl>             <int>       <int>
#>  1 Gentoo  Biscoe           45            15.4               220        5050
#>  2 Gentoo  Biscoe           43.8          13.9               208        4300
#>  3 Gentoo  Biscoe           45.5          15                 220        5000
#>  4 Gentoo  Biscoe           43.2          14.5               208        4450
#>  5 Gentoo  Biscoe           50.4          15.3               224        5550
#>  6 Gentoo  Biscoe           45.3          13.8               208        4200
#>  7 Gentoo  Biscoe           46.2          14.9               221        5300
#>  8 Gentoo  Biscoe           45.7          13.9               214        4400
#>  9 Gentoo  Biscoe           54.3          15.7               231        5650
#> 10 Gentoo  Biscoe           45.8          14.2               219        4700
#> # … with 123 more rows, and 4 more variables: sex <fct>, year <int>,
#> #   .fitted <dbl>, .resid <dbl>

Note the missing columns in the newdata output.
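
To see exactly which columns are dropped, one could compare the two sets of column names directly; per the outputs above, the difference is the fit-specific `.hat`, `.sigma`, `.cooksd`, and `.std.resid` columns:

# compare column names of the `data` and `newdata` outputs (not run above)
setdiff(
  names(augment(lm_fit, data = penguins_tr)),
  names(augment(lm_fit, newdata = penguins_te))
)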

re: 2) If one tries to work with parsnip’s interface to get these values, there’s no interface to `data`, so one must supply the training data as `new_data`. As a result, they don’t get this fit information, and the `.fitted` output is renamed to `.pred`.

augment(pars_fit, new_data = penguins_tr)
#> # A tibble: 200 × 10
#>    species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>   <fct>              <dbl>         <dbl>             <int>       <int>
#>  1 Adelie  Torgersen           39.1          18.7               181        3750
#>  2 Adelie  Torgersen           39.5          17.4               186        3800
#>  3 Adelie  Torgersen           40.3          18                 195        3250
#>  4 Adelie  Torgersen           36.7          19.3               193        3450
#>  5 Adelie  Torgersen           39.3          20.6               190        3650
#>  6 Adelie  Torgersen           38.9          17.8               181        3625
#>  7 Adelie  Torgersen           39.2          19.6               195        4675
#>  8 Adelie  Torgersen           41.1          17.6               182        3200
#>  9 Adelie  Torgersen           38.6          21.2               191        3800
#> 10 Adelie  Torgersen           34.6          21.1               198        4400
#> # … with 190 more rows, and 4 more variables: sex <fct>, year <int>,
#> #   .pred <dbl>, .resid <dbl>

augment(pars_fit, data = penguins_tr, new_data = penguins_tr)
#> # A tibble: 200 × 10
#>    species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>   <fct>              <dbl>         <dbl>             <int>       <int>
#>  1 Adelie  Torgersen           39.1          18.7               181        3750
#>  2 Adelie  Torgersen           39.5          17.4               186        3800
#>  3 Adelie  Torgersen           40.3          18                 195        3250
#>  4 Adelie  Torgersen           36.7          19.3               193        3450
#>  5 Adelie  Torgersen           39.3          20.6               190        3650
#>  6 Adelie  Torgersen           38.9          17.8               181        3625
#>  7 Adelie  Torgersen           39.2          19.6               195        4675
#>  8 Adelie  Torgersen           41.1          17.6               182        3200
#>  9 Adelie  Torgersen           38.6          21.2               191        3800
#> 10 Adelie  Torgersen           34.6          21.1               198        4400
#> # … with 190 more rows, and 4 more variables: sex <fct>, year <int>,
#> #   .pred <dbl>, .resid <dbl>

I’m on board for guardrails, but I think we can be more accommodating here and use prompts to encourage the kind of behavior we want to see.

I think a better approach could be to allow `new_data` to be omitted, but warn/message about it:

i.e. give this output for `augment(pars_fit)`:

augment(lm_fit)
#> # A tibble: 200 × 8
#>    body_mass_g flipper_length_… .fitted .resid    .hat .sigma .cooksd .std.resid
#>          <int>            <int>   <dbl>  <dbl>   <dbl>  <dbl>   <dbl>      <dbl>
#>  1        3750              181   3273.  477.  0.0127    415. 8.58e-3     1.16  
#>  2        3800              186   3523.  277.  0.00862   416. 1.96e-3     0.671 
#>  3        3250              195   3972. -722.  0.00511   413. 7.81e-3    -1.74  
#>  4        3450              193   3872. -422.  0.00547   415. 2.85e-3    -1.02  
#>  5        3650              190   3722.  -72.2 0.00645   416. 9.89e-5    -0.175 
#>  6        3625              181   3273.  352.  0.0127    415. 4.67e-3     0.853 
#>  7        4675              195   3972.  703.  0.00511   413. 7.42e-3     1.70  
#>  8        3200              182   3323. -123.  0.0117    416. 5.31e-4    -0.299 
#>  9        3800              191   3772.   27.9 0.00606   416. 1.39e-5     0.0675
#> 10        4400              198   4121.  279.  0.00504   416. 1.15e-3     0.674 
#> # … with 190 more rows

with the prompt:

#> Adding information about the model fit using the training data. Predict with new data using the `new_data` argument to assess predictive performance. 

If the user passes the training data to `new_data`, we can detect it with `model.frame(x)` and warn then, too:

#> The training data was passed as `new_data`. Model predictions may appear overly performant; please interpret cautiously. See `?repredicting` to learn more.

Here, `?repredicting` would be a new, dedicated doc topic on predicting on the training data.
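
For concreteness, a rough sketch of what this could look like (the messages are from above, but the method internals and the training-data check are placeholders, not parsnip’s actual implementation):

# hypothetical sketch of a more permissive augment.model_fit(); an
# illustration of the proposal, not parsnip's actual internals
augment.model_fit <- function(x, new_data = NULL, ...) {
  if (is.null(new_data)) {
    rlang::inform(
      paste0(
        "Adding information about the model fit using the training data. ",
        "Predict with new data using the `new_data` argument to assess ",
        "predictive performance."
      )
    )
    # fall back to the engine fit's own augment() method, which can
    # retrieve the training data itself (e.g. via model.frame() for lm)
    return(generics::augment(x$fit))
  }

  # placeholder heuristic: does `new_data` look like the training data?
  # a real check via model.frame(x$fit) would need more care than this
  training <- stats::model.frame(x$fit)
  looks_like_training <-
    nrow(new_data) == nrow(training) &&
    all(names(training) %in% names(new_data))
  if (looks_like_training) {
    rlang::warn(
      paste0(
        "The training data was passed as `new_data`. Model predictions may ",
        "appear overly performant; please interpret cautiously. ",
        "See `?repredicting` to learn more."
      )
    )
  }

  # ... existing augment.model_fit() machinery for predicting on `new_data` ...
}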

Let me know if y’all would welcome a PR here. :)

Created on 2022-05-11 by the reprex package (v2.0.1)

@mine-cetinkaya-rundel

The first prompt

Adding information about the model fit using the training data...

would presumably appear even if one never did a test/train split, i.e., with

augment(pars_fit)

which uses the full dataset. In that case, the user could see the term "training data" when they have never thought about their modeling in a testing/training context. How about

Adding information about the model fit using the data that was used to fit the model (e.g., training data) ...

or maybe even without the parenthetical (e.g., ...)

Given that glance() will happily report R-squared on these data (which relies on re-predicting, after all), I think we could do this without any message as well, for consistency. It does mean dropping guardrails further, so there might be resistance to that. I'm not opposed to messaging; I just worry about messaging before learners are ready to interpret the message. Replacing (or defining) the jargon as I suggested above can help with that.
