Skip to content

Instantly share code, notes, and snippets.

@chrishanretty
Created April 4, 2024 11:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chrishanretty/ba0b46b767a169fa1f82ddfe69246f71 to your computer and use it in GitHub Desktop.
Save chrishanretty/ba0b46b767a169fa1f82ddfe69246f71 to your computer and use it in GitHub Desktop.
Predictions from predict.bam and marginaleffects::predictions with and without discretization
### PURPOSE OF THIS CODE: estimate a large model in bam and speed-test
### predictions using mgcv::predict.bam and
### marginaleffects::predictions with and without discretization and
### parallelization
### ##################################################################
### load libraries
### ##################################################################
library(mgcv)
library(marginaleffects)
library(nycflights13)
library(tictoc)
data("flights")
### Some small config
## Threads requested for the discretized fit and predictions. The
## original setting of 12 is kept as the upper bound, but it is capped
## at the number of cores this machine reports so smaller machines are
## not oversubscribed. detectCores() can return NA; na.rm = TRUE then
## falls back to 12.
my_threads <- min(12L, parallel::detectCores(), na.rm = TRUE)
## Fix the random-number seed for reproducibility.
set.seed(3)
### ##################################################################
### transform data to help in modelling departure delay
### ##################################################################
## Calendar features, built in one pipeline:
##   date     - a Date assembled from the year / month / day columns
##   date.num - days elapsed since the earliest date in the data, so
##              the date can enter the model as a smooth term
##   wday     - day of the week as an integer (0 = Sunday ... 6 = Saturday)
flights <- flights |>
  transform(date = as.Date(ISOdate(year, month, day))) |>
  transform(date.num = as.numeric(difftime(date, min(date), units = "days"))) |>
  transform(wday = as.integer(format(date, "%w")))
## Handle time of departure, again converting to numeric.
##   time     - clock-time stamp built from hour and minute. It is parsed
##              in UTC so a daylight-saving gap in the local time zone
##              cannot turn a valid clock time into NA.
##   time.dt  - minutes since midnight as a difftime. Computed
##              arithmetically from hour and minute rather than by
##              subtracting two parsed timestamps, so the value does not
##              depend on the local time zone or on the date the script
##              happens to be run.
##   time.num - the same quantity as a plain number, for use in a smooth.
flights <- flights |>
  transform(time = as.POSIXct(paste(hour, minute, sep = ":"),
                              format = "%H:%M", tz = "UTC")) |>
  transform(time.dt = as.difftime(60 * hour + minute, units = "mins")) |>
  transform(time.num = as.numeric(time.dt, units = "mins"))
### Outcome clean-up: floor the departure delay at zero. Early
### departures (negative values) and missing values both become 0;
### with na.rm = TRUE, pmax() ignores an NA and returns the 0 bound.
flights <- transform(flights, dep_delay = pmax(dep_delay, 0, na.rm = TRUE))
### Categorical predictors: carrier, destination and origin are turned
### into factors (the model uses them as random effects or as a
### parametric term).
flights <- transform(
  flights,
  carrier = factor(carrier),
  dest = factor(dest),
  origin = factor(origin)
)
### ##################################################################
### Estimate models with and without discretization
### ##################################################################
### Both fits use the same specification: smooths of date, time of day
### and distance, a cyclic smooth of day of week, random effects for
### carrier and destination, and a parametric origin effect. They
### differ only in whether bam() discretizes the covariates (and, for
### the discretized fit, in using several threads).
###
### wday takes the values 0..6 and a cyclic ("cc") smooth is constrained
### to take the same value at its two end knots. Without explicit knots
### the end knots sit at the data range, which would force day 0 and
### day 6 to share one value and make the cycle six days long. Passing
### end knots at -0.5 and 6.5 gives a seven-day period in which day 6
### and day 0 are neighbours rather than the same point.
###
### Takes about 135 seconds single-threaded
### (timing as quoted for the original specification)
tic()
m_base <- bam(dep_delay ~ s(date.num, bs = "cr") +
                s(wday, bs = "cc", k = 3) +
                s(time.num, bs = "cr") +
                s(carrier, bs = "re") +
                origin +
                s(distance, bs = "cr") +
                s(dest, bs = "re"),
              data = flights,
              family = poisson,
              knots = list(wday = c(-0.5, 6.5)),
              discrete = FALSE)
toc()
### Takes about 8 seconds
### (timing as quoted for the original specification)
tic()
m_discrete <- bam(dep_delay ~ s(date.num, bs = "cr") +
                    s(wday, bs = "cc", k = 3) +
                    s(time.num, bs = "cr") +
                    s(carrier, bs = "re") +
                    origin +
                    s(distance, bs = "cr") +
                    s(dest, bs = "re"),
                  data = flights,
                  family = poisson,
                  knots = list(wday = c(-0.5, 6.5)),
                  discrete = TRUE,
                  nthreads = my_threads)
toc()
### ##################################################################
### generate predictions, w/ and w/o SEs, w/ and w/o discretization
### ##################################################################
### Cases 1-4 use the model fitted without discretization. Every case
### is bracketed by tic()/toc(); the elapsed time stored in e_<n> is
### the stop stamp minus the start stamp returned by toc().
### NOTE(review): predict() is called with its default `type`, and
### predictions() with its own default; confirm both produce
### predictions on the same scale before comparing timings.

### Case 1: mgcv, w/o SEs, w/o discretization
### takes around 3 seconds
tic()
p1 <- predict(m_base, se.fit = FALSE)
tot <- toc()
e_1 <- tot[["toc"]] - tot[["tic"]]

### Case 2: marginaleffects, w/o SEs, w/o discretization
### takes around 3 seconds
tic()
p2 <- predictions(m_base, vcov = FALSE)
tot <- toc()
e_2 <- tot[["toc"]] - tot[["tic"]]

### Case 3: mgcv, w/ SEs, w/o discretization
### takes around 17 seconds
tic()
p3 <- predict(m_base, se.fit = TRUE)
tot <- toc()
e_3 <- tot[["toc"]] - tot[["tic"]]

### Case 4: marginaleffects, w/ SEs, w/o discretization
### takes around 356 seconds
tic()
p4 <- predictions(m_base, vcov = TRUE)
tot <- toc()
e_4 <- tot[["toc"]] - tot[["tic"]]
### Cases 5-8 repeat the four comparisons on the discretized model,
### passing discrete = TRUE and the thread count to each prediction
### call. Timing works as before: e_<n> is the stop stamp minus the
### start stamp returned by toc().

### Case 5: mgcv, w/o SEs, w/ discretization
### takes around 1/3rd of a second
tic()
p5 <- predict(
  m_discrete,
  se.fit = FALSE,
  discrete = TRUE,
  nthreads = my_threads
)
tot <- toc()
e_5 <- tot[["toc"]] - tot[["tic"]]

### Case 6: marginaleffects, w/o SEs, w/ discretization
### takes around 0.8 seconds
tic()
p6 <- predictions(
  m_discrete,
  vcov = FALSE,
  discrete = TRUE,
  nthreads = my_threads
)
tot <- toc()
e_6 <- tot[["toc"]] - tot[["tic"]]

### Case 7: mgcv, w/ SEs, w/ discretization
### takes around 3.6 seconds
tic()
p7 <- predict(
  m_discrete,
  se.fit = TRUE,
  discrete = TRUE,
  nthreads = my_threads
)
tot <- toc()
e_7 <- tot[["toc"]] - tot[["tic"]]

### Case 8: marginaleffects, w/ SEs, w/ discretization
### takes around 82 seconds
tic()
p8 <- predictions(
  m_discrete,
  vcov = TRUE,
  discrete = TRUE,
  nthreads = my_threads
)
tot <- toc()
e_8 <- tot[["toc"]] - tot[["tic"]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment