Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save sbfnk/569ad82641d73286a355317e53d911cf to your computer and use it in GitHub Desktop.
Save sbfnk/569ad82641d73286a355317e53d911cf to your computer and use it in GitHub Desktop.
This gist shows how to estimate a doubly censored (i.e., daily data) and right-truncated (i.e., due to the epidemic phase) distribution using the brms package.
# Load packages
library(brms)
library(cmdstanr)
library(data.table) # here we use the development version of data.table install it with data.table::update_dev_pkg
library(purrr)
library(bpmodels) # devtools::install_github("sbfnk/bpmodels")
# Use four cores for parallel computation.
options(mc.cores = 4)

# Simulation settings ----------------------------------------------------------

# Outbreak: the number of cases seeded at time 0 (which is not the same as that
# many daily cases in an ongoing epidemic), the exponential growth rate, and
# the last simulated time point.
max_t <- 20
growth_rate <- 0.1
init_cases <- 500

# Reproduction number.
R <- 1.5

# Log-scale mean and standard deviation of the lognormal delay to be recovered.
logsd <- 0.6
logmean <- 1.6

# Number of line-list records to draw. Fewer are kept in the end because some
# are lost to truncation.
samples <- 400

# Gamma distributed generation time --------------------------------------------

# Dispersion of the generation time. The gamma shape is 1 / tg_kappa, so this
# is the squared coefficient of variation of the distribution.
tg_kappa <- 0.5
tg_shape <- 1 / tg_kappa

# Mean generation time implied by R, growth_rate and tg_kappa, and the gamma
# scale that yields this mean for the shape above.
tg_mean <- (exp(tg_kappa * log(R)) - 1) / (tg_kappa * growth_rate)
tg_scale <- tg_mean / tg_shape

# Draw `n` generation times from the gamma distribution defined above.
tg <- function(n) rgamma(n = n, shape = tg_shape, scale = tg_scale)
# Simulate the outbreak with Poisson offspring (mean R) and generation times
# drawn by `tg`, up to time `max_t`. Only cases with time > 0 are kept
# (presumably this drops the initial cases seeded at time 0 - confirm against
# the chain_sim output) and the remaining cases are re-numbered from 1.
# Plain `[` chaining is used instead of the functional `DT()` helper so that
# this runs on released versions of data.table; seq_len() keeps the id column
# well-formed even if no case survives the filter.
linelist <- data.table(
  bpmodels::chain_sim(
    init_cases, "pois", tf = max_t, serial = tg, lambda = R
  )
)[time > 0, list(time, id = seq_len(.N))]

# Daily case counts: bin event times into whole days, count cases per day, and
# order by day.
cases <- linelist[, .(time = as.integer(floor(time)))][
  , .(cases = .N), by = time
]
setorder(cases, time)

plot(cases$cases)
# Simulate the observation process for the line list: sample `samples` primary
# event times without replacement and give each a lognormal delay.
# Columns are added by reference with `[` rather than the functional `DT()`
# helper so that this runs on released versions of data.table.
obs <- data.table(
  time = sample(linelist$time, samples, replace = FALSE),
  delay = rlnorm(samples, logmean, logsd)
)
# Add a new ID (seq_len() is safe for an empty table, unlike 1:.N)
obs[, id := seq_len(.N)]
# Continuous time at which the secondary event occurs
obs[, obs_delay := time + delay]
# Delay as seen in daily data: difference between the two event days
obs[, daily_delay := floor(obs_delay) - floor(time)]
# Lower bound of the censoring interval (one day less, floored at 0)
obs[, daily_delay_m1 := pmax(daily_delay - 1, 0)]
# Upper bound of the censoring interval (one day more)
obs[, daily_delay_p1 := daily_delay + 1]
# Time each primary event could be observed for (including the last day)
obs[, obs_time := max_t + 1 - floor(time)]
# The exact primary event time within the day is unknown, so use the midday
# point as the average across the day
obs[, censored_obs_time := obs_time - 0.5]
# Censoring type flag consumed by brms `cens()`
obs[, censored := "interval"]
# Make event based data for latent modelling: the day of the primary event,
# the day of the secondary event, and the final time point `max_t` repeated on
# every row. The columns are added by reference with `[` rather than the
# functional `DT()` helper so that this runs on released versions of
# data.table.
obs <- obs[
  , `:=`(
    ptime = floor(time),
    stime = floor(obs_delay),
    max_t = max_t
  )
]
# Truncate observations: keep only records whose secondary event day falls on
# or before `max_t`. Plain `[` subsetting is used rather than the functional
# `DT()` helper so that this runs on released versions of data.table.
truncated_obs <- obs[stime <= max_t]

# The lognormal family in brms does not support 0, so additionally drop records
# whose lower censoring bound (daily_delay_m1) is below 1, i.e. keep only daily
# delays of two days or more.
# This seems like it could be improved
double_truncated_obs <- truncated_obs[daily_delay_m1 >= 1]
# Fit lognormal model with no corrections for censoring or truncation.
# Same-day events have daily_delay == 0, which the lognormal family cannot
# take as a response, so those rows are dropped before fitting.
naive_model <- brm(
  bf(daily_delay ~ 1, sigma ~ 1),
  data = truncated_obs[daily_delay >= 1],
  family = lognormal(), backend = "cmdstanr", adapt_delta = 0.9
)
# We see that the log mean is biased by the truncation
# the sigma_intercept needs to be exponentiated to return the log sd
summary(naive_model)
# Data for the truncation-adjusted model.
# The response must lie within the truncation bounds used below, so records
# with a daily delay of 0 (below lb = 1, and unsupported by the lognormal
# family) are removed. The upper bound must also be at least 1 so that the
# truncation interval is not empty.
# This seems like it could be improved
double_truncated_obs_trunc <- truncated_obs[
  daily_delay >= 1 & censored_obs_time >= 1
]
# Adjust for truncation: the delay is bounded below by 1 and above by the
# (midday-censored) time each record could be observed for.
trunc_model <- brm(
  bf(daily_delay | trunc(lb = 1, ub = censored_obs_time) ~ 1, sigma ~ 1),
  data = double_truncated_obs_trunc, family = lognormal(),
  backend = "cmdstanr", adapt_delta = 0.9
)
# Getting closer to recovering our simulated estimates
summary(trunc_model)
# Account for daily censoring only: each delay is treated as interval censored
# between daily_delay_m1 and daily_delay_p1.
censor_model <- brm(
  formula = bf(
    daily_delay_m1 | cens(censored, daily_delay_p1) ~ 1,
    sigma ~ 1
  ),
  family = lognormal(),
  data = double_truncated_obs,
  adapt_delta = 0.9,
  backend = "cmdstanr"
)
# Estimates are further off than with the truncation adjustment, but closer
# than those of the naive model.
summary(censor_model)
# Account for both double interval censoring and right truncation in a single
# model: the interval-censored delay is additionally truncated between 1 and
# the (midday-censored) observation time.
censor_trunc_model <- brm(
  formula = bf(
    daily_delay_m1 |
      trunc(lb = 1, ub = censored_obs_time) + cens(censored, daily_delay_p1) ~
      1,
    sigma ~ 1
  ),
  family = lognormal(),
  data = double_truncated_obs,
  backend = "cmdstanr"
)
# This should recover the underlying distribution. Some bias may remain as the
# growth rate increases and with short delays, because the observation time is
# itself censored.
summary(censor_trunc_model)
# Model censoring as a latent process (WIP)
# For this model we need to use a custom brms family and so
# the code is significantly more complex.
#' Fit a lognormal model with latent censoring windows and truncation
#'
#' Builds the "latent_lognormal" custom brms family, the Stan function that
#' implements its log density, and uniform(0, 1) priors on the coefficients of
#' the two window parameters, then hands all three to `fn` together with any
#' further arguments.
#'
#' @param fn Function used to build or fit the model; called as
#'   `fn(family = , stanvars = , prior = , ...)`. Defaults to `brm`.
#' @param ... Additional arguments forwarded unchanged to `fn` (for example
#'   the model formula, the data and backend settings).
#' @return Whatever `fn` returns.
fit_latent_lognormal <- function(fn = brm, ...) {
  # Stan log density: shift the primary (y) and secondary (sevent) event by
  # their latent windows, evaluate the lognormal density of the resulting
  # delay and subtract the log CDF at the time available for observation
  # (end_t minus the shifted primary event).
  stan_code <- "
real latent_lognormal_lpdf(real y, real mu, real sigma, real pwindow,
real swindow, real sevent,
real end_t) {
real p = y + pwindow;
real s = sevent + swindow;
real d = s - p;
real obs_time = end_t - p;
return lognormal_lpdf(d | mu, sigma) - lognormal_lcdf(obs_time | mu, sigma);
}
"
  stan_functions <- stanvar(block = "functions", scode = stan_code)

  # Family definition: mu is unbounded, sigma is positive on a log link, and
  # both window parameters are constrained to [0, 1]. The two `vreal` data
  # columns are made available to the density for each observation.
  family_def <- custom_family(
    "latent_lognormal",
    dpars = c("mu", "sigma", "pwindow", "swindow"),
    links = c("identity", "log", "identity", "identity"),
    lb = c(NA, 0, 0, 0),
    ub = c(NA, NA, 1, 1),
    type = "real",
    vars = c("vreal1[n]", "vreal2[n]")
  )

  # Shared priors: uniform on [0, 1] for the primary and secondary windows.
  window_priors <- c(
    prior(uniform(0, 1), class = "b", dpar = "pwindow", lb = 0, ub = 1),
    prior(uniform(0, 1), class = "b", dpar = "swindow", lb = 0, ub = 1)
  )

  fit <- fn(
    family = family_def, stanvars = stan_functions, prior = window_priors, ...
  )
  fit
}
# Fit latent lognormal model.
# The response is the primary event day (`ptime`); the secondary event day
# (`stime`) and the final time point (`max_t`) are passed through `vreal()` to
# the custom likelihood. These are the event columns created on `obs` for
# latent modelling; no `primary_event` / `secondary_event` columns exist in
# the data. Each observation gets its own primary and secondary window
# parameter via the per-id factor terms.
latent_model <- fit_latent_lognormal(
  bf(
    ptime | vreal(stime, max_t) ~ 1,
    sigma ~ 1,
    pwindow ~ 0 + as.factor(id),
    swindow ~ 0 + as.factor(id)
  ),
  data = truncated_obs, backend = "cmdstanr", fn = brm,
  adapt_delta = 0.95
)
# Should also see parameter recovery using this method though
# run-times are much higher and the model is somewhat unstable.
summary(latent_model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment