Simon Couch
```r
# with tidymodels/container#12 and tidymodels/workflows#225
library(tidymodels)
```

```
── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
✔ broom        1.0.5.9000      ✔ recipes      1.0.10.9000
✔ dials        1.2.1           ✔ rsample      1.2.1
✔ dplyr        1.1.4           ✔ tibble       3.2.1
✔ ggplot2      3.5.1           ✔ tidyr        1.3.1
✔ infer        1.0.6.9000      ✔ tune         1.2.1
✔ modeldata    1.3.0           ✔ workflows    1.1.4.9000
✔ parsnip      1.2.1.9001      ✔ workflowsets 1.1.0
✔ purrr        1.0.2           ✔ yardstick    1.3.1
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ purrr::discard() masks scales::discard()
✖ dplyr::filter()  masks stats::filter()
✖ dplyr::lag()     masks stats::lag()
✖ recipes::step()  masks stats::step()
• Use suppressPackageStartupMessages() to eliminate package startup messages
```
Adding postprocessors to workflows raises new questions about how data is allotted during workflow fitting.
```r
# create example data
y <- seq(0, 7, .1)
dat <- data.frame(y = y, x = y + (y-3)^2)
plot(dat$x, dat$y)

# construct workflows
post <- container::container("regression", "regression")
post <- container::adjust_numeric_calibration(post, "linear")
wflow_simple <- workflow(y ~ ., parsnip::linear_reg())
wflow_post <- add_container(wflow_simple, post)

# train workflow
wf_simple_fit <- fit(wflow_simple, dat)
wf_post_fit <- fit(wflow_post, dat)
```
`wf_simple_fit` has trained the preprocessor and model on `dat`. `wf_post_fit` has trained both of those as well as the postprocessor on `dat`.
Note that (calibration) postprocessors are trained on model predictions. So, in this case, when fitting `wf_post_fit`:

- The workflow trains the preprocessor and model on `dat` as usual.
- Then, workflows re-predicts `dat` with the preprocessor and model (identical to the output that would be returned from `predict(wf_simple_fit, dat)`) and trains the postprocessor on those re-predictions.
Confirming this is what actually happens:

```r
wflow_simple_preds <- augment(wf_simple_fit, dat)
post_trained <- fit(post, wflow_simple_preds, y, .pred)
wflow_manual_preds <- predict(post_trained, wflow_simple_preds)
wflow_post_preds <- predict(wf_post_fit, dat)
all(wflow_manual_preds[".pred"] == wflow_post_preds)
```

```
[1] TRUE
```
What we actually want is to train the postprocessor on predictions that the preprocessor and model generate for data that pair wasn't trained on.
Note: this is not an issue for workflows that don't have postprocessors, or for workflows whose postprocessors don't require training. As of now, calibration adjustments (`adjust_*_calibration()`) are the only postprocessors that require training.
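The following base R sketch shows why in-sample re-prediction is a problem. It uses a plain `lm()` as a stand-in for the linear calibrator. That is an assumption: container's calibrator may be fit differently, for example with a smoother. For a least-squares model with an intercept, regressing the outcome on the model's own fitted values always yields an intercept of 0 and a slope of 1. A calibrator trained on those re-predictions therefore learns the identity map and adjusts nothing.

```r
# Same example data as above
y <- seq(0, 7, .1)
dat <- data.frame(y = y, x = y + (y - 3)^2)

# Fit the "model", then re-predict the very rows it was trained on
model <- lm(y ~ x, data = dat)
dat$.pred <- predict(model, newdata = dat)

# Stand-in linear calibrator trained on those in-sample re-predictions
calibrator <- lm(y ~ .pred, data = dat)
coef(calibrator)  # intercept ~0 and slope ~1, up to floating-point error
```

Because the least-squares residuals are orthogonal to the fitted values, the calibrator recovers exactly what it was given. The model's real miscalibration on new data stays invisible to it.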
We haven't encountered this problem directly in tune because, for computational efficiency reasons, it trains the preprocessor and model (and now the postprocessor) separately rather than using `fit.workflow()`, so we can just pass different data to each. workflows will need to address this problem somehow, though.
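The following is a minimal sketch of why stage-wise fitting sidesteps the issue: when each stage is its own call, each call can receive its own rows. `fit_model_stage()` and `fit_post_stage()` are hypothetical helpers loosely in the spirit of the `.fit_*()` functions. They are not their real signatures, and `lm()` again stands in for both the model and the calibrator.

```r
# Hypothetical stage helpers (NOT workflows' real .fit_*() signatures)
fit_model_stage <- function(data) lm(y ~ x, data = data)

fit_post_stage <- function(model, data) {
  # Predict rows the model has not seen, then train the stand-in calibrator
  data$.pred <- predict(model, newdata = data)
  lm(y ~ .pred, data = data)
}

y <- seq(0, 7, .1)
dat <- data.frame(y = y, x = y + (y - 3)^2)

# Separate calls let each stage be trained on different rows
set.seed(1)
model_rows <- sample.int(nrow(dat), size = 53)
model <- fit_model_stage(dat[model_rows, ])
post <- fit_post_stage(model, dat[-model_rows, ])
```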
Option 1: split the data internally. This approach would mirror tune's current approach. Internally, tidymodels takes care of splitting data up for training the two components and (in an interface we haven't figured out yet in tune) shows curious users how that split was determined.

So, the user still just supplies `dat` to `fit.workflow()`, and workflows does a fancy version of `initial_split(dat)` internally. One portion trains the preprocessor and model. That partially trained workflow then predicts on the second portion, and those predictions are used to train the postprocessor.
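Here is a base R sketch of the internal-split idea. `fit_with_internal_split()` is a hypothetical helper, not a workflows function. A simple random split stands in for the "fancy version of `initial_split()`", and `lm()` stands in for both the model and the calibrator. The returned object records which rows trained which component, so the split stays auditable.

```r
# Hypothetical helper; a random split stands in for workflows' internal logic
fit_with_internal_split <- function(data, prop = 3 / 4) {
  n <- nrow(data)
  model_rows <- sample.int(n, size = floor(prop * n))
  post_rows <- setdiff(seq_len(n), model_rows)

  # 1. Train the "preprocessor and model" on the first portion only
  model <- lm(y ~ x, data = data[model_rows, ])

  # 2. Predict the held-out portion with the partially trained fit
  held_out <- data[post_rows, ]
  held_out$.pred <- predict(model, newdata = held_out)

  # 3. Train the stand-in calibrator on those out-of-sample predictions
  post <- lm(y ~ .pred, data = held_out)

  # Record which rows trained which component
  list(model = model, post = post,
       split = list(model = model_rows, post = post_rows))
}

y <- seq(0, 7, .1)
dat <- data.frame(y = y, x = y + (y - 3)^2)

set.seed(1)
wf_fit <- fit_with_internal_split(dat)
lengths(wf_fit$split)  # model: 53 rows, post: 18 rows
```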
Some notable pros here:

- Users can still just supply `data` as usual. I think that `data` as a type-stable argument here is actually really important, as both users and we can programmatically pass `data` to fit a workflow without accounting for the edge case that the workflow has a postprocessor and that postprocessor has a calibrator (and thus needs training).
- tidymodels protects users from re-prediction under the hood.
Cons here:

- Users (and we) can't control which data ends up in which split.
  - This is especially a bummer (and even dangerous) when `data` is the output of a resampling function that affects the statistical independence of rows. For example, if `data` is `training(bootstraps(dat)$splits[[1]])`, then `data` will contain duplicate rows that could end up in both the training set for the preprocessor/model and the one for the postprocessor. The story is similarly problematic for time series.
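The next base R sketch illustrates the bootstrap leak. It draws one bootstrap resample, splits it naively the way an internal split might, and counts how many original rows end up training both components. For comparison, it also shows a split keyed on the original row id. That keeps copies of a row together and so cannot leak. The `id` column and the 3/4 proportion are assumptions made for illustration.

```r
y <- seq(0, 7, .1)
dat <- data.frame(id = seq_along(y), y = y, x = y + (y - 3)^2)

set.seed(2024)
# One bootstrap resample: rows drawn with replacement, so some rows repeat
boot <- dat[sample.int(nrow(dat), replace = TRUE), ]

# Naive internal split of the resample, ignoring duplicates
n <- nrow(boot)
model_rows <- sample.int(n, size = floor(3 / 4 * n))
naive_shared <- intersect(boot$id[model_rows], boot$id[-model_rows])
length(naive_shared)  # original rows that train BOTH components

# For comparison: split on the original row id so copies stay together
ids <- unique(boot$id)
model_ids <- ids[sample.int(length(ids), size = floor(3 / 4 * length(ids)))]
grouped_model <- boot[boot$id %in% model_ids, ]
grouped_post <- boot[!(boot$id %in% model_ids), ]
length(intersect(grouped_model$id, grouped_post$id))  # always 0
```

The naive count depends on the seed, but it is almost never zero. Only a split that knows about the duplication can guarantee independence, and an internal split the user can't control has no way to know.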
It is worth noting that this approach can be reproducible and auditable. On auditability, note the current contents of the fitted workflow's `post` element:

```r
# currently:
names(wf_post_fit$post)
```

```
[1] "actions" "post"
```
We could just add a `split` slot to document the splits that were used to train each component. Doing so would make that split accessible in tuning results via `control_*(extract)`.
Option 2: require a split. Require the user to pass a split to the workflow as `data` when the workflow contains a postprocessor that requires training.
Pros:

- Allows for fine-grained control of which data is used to train which elements of the workflow.

Cons:

- I can certainly imagine users just supplying `initial_split(data)` to that argument even in problematic situations like those mentioned above. :/
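To illustrate the fine-grained control this option allows, the next sketch has the user supply a time-ordered split. Earlier rows train the model and later rows train the calibrator, so no future information reaches the model. The plain list of row indices is only a stand-in for an rsample `rsplit`, and `fit_with_split()` is a hypothetical helper.

```r
# Hypothetical helper; `split` is a plain list standing in for an rsplit
fit_with_split <- function(data, split) {
  model <- lm(y ~ x, data = data[split$model, ])
  held_out <- data[split$post, ]
  held_out$.pred <- predict(model, newdata = held_out)
  list(model = model, post = lm(y ~ .pred, data = held_out), split = split)
}

y <- seq(0, 7, .1)
dat <- data.frame(y = y, x = y + (y - 3)^2)

# Treat row order as time: earlier rows -> model, later rows -> calibrator
cutoff <- floor(3 / 4 * nrow(dat))
time_split <- list(model = seq_len(cutoff), post = (cutoff + 1):nrow(dat))

wf_fit <- fit_with_split(dat, time_split)
max(wf_fit$split$model) < min(wf_fit$split$post)  # TRUE
```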
I think that a combination of 1 and 2 is likely most helpful here.

- `fit.workflow()` can take in data-frame `data` and will make an internal split when it needs to. (Notably, as a happy path, I'd argue there's no need to message or warn here.)
- `fit.workflow()` can also take `rsplit` `data` when a postprocessor that requires training is present.
  - We'd likely want to warn or error here when the workflow doesn't actually require `data` to be supplied as an `rsplit`.
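A minimal sketch of the combined interface follows. `fit_combined()`, its `needs_post_training` argument, and the list-shaped split are all hypothetical illustrations, not workflows or rsample APIs. A data frame takes the silent happy path with an internal split. A user-supplied split is used exactly as given. A split supplied when nothing needs separate training data raises an error; the error is one of the two choices the bullet above leaves open.

```r
# Hypothetical sketch; not a workflows or rsample API
fit_combined <- function(data, needs_post_training = TRUE, prop = 3 / 4) {
  # A data frame is also a list, so check for "list but not data frame"
  is_split <- is.list(data) && !is.data.frame(data)

  if (!needs_post_training) {
    if (is_split) stop("`data` is a split, but no postprocessor needs training.")
    return(list(model = lm(y ~ x, data = data), post = NULL, split = NULL))
  }

  if (is_split) {
    # User-supplied split: respect it exactly
    split <- data[c("model", "post")]
    data <- data$data
  } else {
    # Happy path: make the internal split silently
    n <- nrow(data)
    model_rows <- sample.int(n, size = floor(prop * n))
    split <- list(model = model_rows, post = setdiff(seq_len(n), model_rows))
  }

  model <- lm(y ~ x, data = data[split$model, ])
  held_out <- data[split$post, ]
  held_out$.pred <- predict(model, newdata = held_out)
  list(model = model, post = lm(y ~ .pred, data = held_out), split = split)
}

y <- seq(0, 7, .1)
dat <- data.frame(y = y, x = y + (y - 3)^2)

set.seed(1)
auto_fit <- fit_combined(dat)  # data frame: internal split
user_fit <- fit_combined(list(data = dat, model = 1:53, post = 54:71))
```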
One last point that seems worth thinking through is the possibility of a “partially trained” workflow. Postprocessors, theoretically, could be added, updated, removed, and/or trained without any need for changing the underlying preprocessor and model fit. That said, the same is technically true for preprocessors in the context of model fits, but the package currently doesn’t accommodate workflows where the preprocessor is trained but the model isn’t.
My thought here is that the `.fit_*()` functions, which allow for training only the preprocessor, model, and now postprocessor (and are the functions used by tune to do so), are the sole interface to partially train workflows and likely ought to stay this way.
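To make the "partially trained" idea concrete, here is a small sketch. The object has a fitted model and an empty postprocessor slot. `train_post_only()` is a hypothetical helper that fills in (or later retrains) only that slot and leaves the model fit untouched. Names and structure are illustrative, and `lm()` stands in for both components.

```r
y <- seq(0, 7, .1)
dat <- data.frame(y = y, x = y + (y - 3)^2)

# Hypothetical partially trained object: model fitted, postprocessor not yet
partial <- list(model = lm(y ~ x, data = dat[1:53, ]), post = NULL)

# Train (or retrain) only the postprocessor; the model fit is not touched
train_post_only <- function(wf, data) {
  data$.pred <- predict(wf$model, newdata = data)
  wf$post <- lm(y ~ .pred, data = data)
  wf
}

full <- train_post_only(partial, dat[54:71, ])
identical(coef(full$model), coef(partial$model))  # TRUE: model was not refit
```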