library(tidyverse)
library(tidymodels)
devtools::source_gist("https://gist.github.com/brshallo/3db2cd25172899f91b196a90d5980690")
#> i Sourcing https://gist.githubusercontent.com/brshallo/3db2cd25172899f91b196a90d5980690/raw/5d6731b63fd75e09e7a5e1e33134c389f6209652/predict-interval.R
#> i SHA-1 hash of file is 3b41b16d53af745f880d21b4056fe18412b347e6
# Simulate 1000 observations where y follows x plus noise whose magnitude is
# proportional to x (scaled by max(x)), so the spread of y grows with x.
# The seed is fixed so the random draws -- and therefore the fitted model and
# the plot below -- are identical on every run of this example.
set.seed(123)
data <- tibble(
  x = abs(rnorm(1000)),
  # `x` here refers to the column created on the previous line.
  y = x + rnorm(1000, sd = 0.7) * x / max(x)
)
# Recipe: y is the outcome, every other column of `data` is a predictor.
rec <- recipe(y ~ ., data = data)

# Linear regression specification using the "lm" engine.
mod <- set_mode(
  set_engine(parsnip::linear_reg(), "lm"),
  "regression"
)

# Combine the recipe and the model specification into one workflow object.
workflow <- add_model(
  add_recipe(workflows::workflow(), rec),
  mod
)
# Run the gist's `prep_interval()` helper on the workflow and the data; the
# result is what `predict_interval()` consumes below.
prepped_int <- prep_interval(workflow, data)

# Request predictions at probs 0.2, 0.5 and 0.975 for every row of `data`,
# attach them next to the original columns, and plot: observations as points,
# the 0.5 column in red, the 0.2 and 0.975 columns in blue.
ggplot(
  bind_cols(
    data,
    predict_interval(prepped_int, data, probs = c(0.2, 0.50, 0.975))
  ),
  aes(x = x)
) +
  geom_point(aes(y = y)) +
  geom_line(aes(y = probs_0.500), colour = "red") +
  geom_line(aes(y = probs_0.200), colour = "blue") +
  geom_line(aes(y = probs_0.975), colour = "blue")
#> `summarise()` has grouped output by '.id'. You can override using the `.groups` argument.
# Created on 2021-10-05 by the reprex package (v2.0.0)