Skip to content

Instantly share code, notes, and snippets.

@anddis
Last active October 31, 2023 14:59
Show Gist options
  • Save anddis/efe72216388db850e942ec71cdc6e3d4 to your computer and use it in GitHub Desktop.
Save anddis/efe72216388db850e942ec71cdc6e3d4 to your computer and use it in GitHub Desktop.
rstanarm::stan_surv with save_warmup = TRUE (line 736)
## rstanarm::stan_surv with save_warmup = TRUE (line 741)
## anddis 20231031
##
## See: https://github.com/stan-dev/rstanarm/issues/556
# Fit a Bayesian survival model with rstanarm's internal 'surv' Stan model.
#
# Copy of rstanarm::stan_surv (see
# https://github.com/stan-dev/rstanarm/issues/556) whose purpose is to keep
# the warmup draws when algorithm = "sampling" (save_warmup = TRUE is forced
# on the sampler arguments). Because this function is defined outside the
# rstanarm namespace, every rstanarm internal it uses is reached through
# `rstanarm:::`; unqualified internal helpers would not be found.
#
# Arguments (as in rstanarm::stan_surv):
#   formula          model formula, parsed together with 'data' by
#                    rstanarm:::parse_formula_and_data().
#   data             a data frame (required; anything else is an error).
#   basehaz          baseline hazard, one of "exp", "exp-aft", "weibull",
#                    "weibull-aft", "gompertz", "ms" (default) or "bs".
#   basehaz_ops      optional baseline-hazard options (NULL when missing).
#   qnodes           number of quadrature nodes (default 15); a warning is
#                    issued if it is changed but no quadrature is needed.
#   prior, prior_intercept, prior_aux, prior_smooth, prior_covariance
#                    prior specifications; when 'prior_aux' is missing it
#                    defaults to rstanarm:::get_default_prior_for_aux(basehaz).
#   prior_PD         logical flag passed to Stan as 'prior_PD'; when TRUE the
#                    priors above may not be NULL.
#   algorithm        "sampling" (rstan::sampling) or "meanfield"/"fullrank"
#                    (rstan::vb).
#   adapt_delta      passed to rstanarm:::set_sampling_args() as
#                    'user_adapt_delta' (sampling only).
#   ...              further arguments for the sampler / variational fit.
#
# Returns the 'stansurv' object built by rstanarm:::stansurv().
stan_surv <- function(formula,
                      data,
                      basehaz = "ms",
                      basehaz_ops,
                      qnodes = 15,
                      prior = normal(),
                      prior_intercept = normal(),
                      prior_aux,
                      prior_smooth = exponential(autoscale = FALSE),
                      prior_covariance = decov(),
                      prior_PD = FALSE,
                      algorithm = c("sampling", "meanfield", "fullrank"),
                      adapt_delta = 0.95, ...) {

  #-----------------------------
  # Pre-processing of arguments
  #-----------------------------

  if (!requireNamespace("survival"))
    stop("the 'survival' package must be installed to use this function.")

  if (missing(basehaz_ops))
    basehaz_ops <- NULL

  if (missing(data) || !inherits(data, "data.frame"))
    stop("'data' must be a data frame.")

  dots      <- list(...)
  algorithm <- match.arg(algorithm)

  formula <- rstanarm:::parse_formula_and_data(formula, data)
  data    <- formula$data
  formula[["data"]] <- NULL

  #----------------
  # Construct data
  #----------------

  #----- model frame stuff

  mf_stuff <- rstanarm:::make_model_frame(formula$tf_form, data,
                                          drop.unused.levels = TRUE)

  mf <- mf_stuff$mf # model frame
  mt <- mf_stuff$mt # model terms

  #----- dimensions and response vectors

  # entry and exit times for each row of data
  t_beg <- rstanarm:::make_t(mf, type = "beg") # entry time
  t_end <- rstanarm:::make_t(mf, type = "end") # exit time
  t_upp <- rstanarm:::make_t(mf, type = "upp") # upper time for interval censoring

  # ensure no event or censoring times are zero (leads to degenerate
  # estimate for log hazard for most baseline hazards, due to log(0))
  check1 <- any(t_end <= 0, na.rm = TRUE)
  check2 <- any(t_upp <= 0, na.rm = TRUE)
  if (check1 || check2)
    stop("All event and censoring times must be greater than 0.",
         call. = FALSE)

  # event indicator for each row of data; the subsetting below treats
  # 0 = right censored, 1 = event, 2 = left censored, 3 = interval censored
  status <- rstanarm:::make_d(mf)
  if (any(is.na(status)))
    stop("Invalid status indicator in Surv object.", call. = FALSE)
  if (any(status < 0 | status > 3))
    stop("Invalid status indicator in Surv object.", call. = FALSE)

  # delayed entry indicator for each row of data
  delayed <- as.logical(t_beg != 0)

  # time variables for stan
  t_event <- rstanarm:::aa(t_end[status == 1]) # exact event time
  t_lcens <- rstanarm:::aa(t_end[status == 2]) # left censoring time
  t_rcens <- rstanarm:::aa(t_end[status == 0]) # right censoring time
  t_icenl <- rstanarm:::aa(t_end[status == 3]) # lower limit of interval censoring time
  t_icenu <- rstanarm:::aa(t_upp[status == 3]) # upper limit of interval censoring time
  t_delay <- rstanarm:::aa(t_beg[delayed])     # delayed entry time

  # calculate log crude event rate
  t_tmp <- sum(rowMeans(cbind(t_end, t_upp), na.rm = TRUE) - t_beg)
  d_tmp <- sum(status != 0)
  log_crude_event_rate <- log(d_tmp / t_tmp)
  if (is.infinite(log_crude_event_rate))
    log_crude_event_rate <- 0 # avoids error when there are zero events

  # dimensions
  nevent <- sum(status == 1)
  nrcens <- sum(status == 0)
  nlcens <- sum(status == 2)
  nicens <- sum(status == 3)
  ndelay <- sum(delayed)

  #----- baseline hazard

  ok_basehaz <- c("exp",
                  "exp-aft",
                  "weibull",
                  "weibull-aft",
                  "gompertz",
                  "ms",
                  "bs")
  basehaz <- rstanarm:::handle_basehaz_surv(
    basehaz     = basehaz,
    basehaz_ops = basehaz_ops,
    ok_basehaz  = ok_basehaz,
    times       = t_end,
    status      = status,
    min_t       = min(t_beg),
    max_t       = max(c(t_end, t_upp), na.rm = TRUE)
  )
  nvars <- basehaz$nvars # number of basehaz aux parameters

  # flag if intercept is required for baseline hazard
  has_intercept <- rstanarm:::ai(rstanarm:::has_intercept(basehaz))

  # flag if AFT specification
  is_aft <- rstanarm:::get_basehaz_name(basehaz) %in% c("exp-aft", "weibull-aft")

  #----- define dimensions and times for quadrature

  # flag if formula uses time-varying effects
  has_tve <- !is.null(formula$td_form)

  # flag if closed form available for cumulative baseline hazard
  has_closed_form <- rstanarm:::check_for_closed_form(basehaz)

  # flag for quadrature
  has_quadrature <- has_tve || !has_closed_form

  if (has_quadrature) { # model uses quadrature

    # standardised nodes and weights for quadrature
    qq <- rstanarm:::get_quadpoints(nodes = qnodes)
    qp <- qq$points
    qw <- qq$weights

    # quadrature points, evaluated for each row of data
    # (the unstandardise helpers are rstanarm internals, hence `:::`)
    qpts_event <- rstanarm:::uapply(qp, rstanarm:::unstandardise_qpts, 0, t_event)
    qpts_lcens <- rstanarm:::uapply(qp, rstanarm:::unstandardise_qpts, 0, t_lcens)
    qpts_rcens <- rstanarm:::uapply(qp, rstanarm:::unstandardise_qpts, 0, t_rcens)
    qpts_icenl <- rstanarm:::uapply(qp, rstanarm:::unstandardise_qpts, 0, t_icenl)
    qpts_icenu <- rstanarm:::uapply(qp, rstanarm:::unstandardise_qpts, 0, t_icenu)
    qpts_delay <- rstanarm:::uapply(qp, rstanarm:::unstandardise_qpts, 0, t_delay)

    # quadrature weights, evaluated for each row of data
    qwts_event <- rstanarm:::uapply(qw, rstanarm:::unstandardise_qwts, 0, t_event)
    qwts_lcens <- rstanarm:::uapply(qw, rstanarm:::unstandardise_qwts, 0, t_lcens)
    qwts_rcens <- rstanarm:::uapply(qw, rstanarm:::unstandardise_qwts, 0, t_rcens)
    qwts_icenl <- rstanarm:::uapply(qw, rstanarm:::unstandardise_qwts, 0, t_icenl)
    qwts_icenu <- rstanarm:::uapply(qw, rstanarm:::unstandardise_qwts, 0, t_icenu)
    qwts_delay <- rstanarm:::uapply(qw, rstanarm:::unstandardise_qwts, 0, t_delay)

    # times at events and all quadrature points
    cpts_list <- list(t_event,
                      qpts_event,
                      qpts_lcens,
                      qpts_rcens,
                      qpts_icenl,
                      qpts_icenu,
                      qpts_delay)
    idx_cpts <- rstanarm:::get_idx_array(lengths(cpts_list))
    cpts     <- unlist(cpts_list) # as vector

    # number of quadrature points
    qevent <- length(qwts_event)
    qlcens <- length(qwts_lcens)
    qrcens <- length(qwts_rcens)
    qicens <- length(qwts_icenl)
    qdelay <- length(qwts_delay)

  } else {

    # times at all different event types
    cpts_list <- list(t_event,
                      t_lcens,
                      t_rcens,
                      t_icenl,
                      t_icenu,
                      t_delay)
    idx_cpts <- rstanarm:::get_idx_array(lengths(cpts_list))
    cpts     <- unlist(cpts_list) # as vector

    # dud entries for stan
    qpts_event <- rep(0, 0)
    qpts_lcens <- rep(0, 0)
    qpts_rcens <- rep(0, 0)
    qpts_icenl <- rep(0, 0)
    qpts_icenu <- rep(0, 0)
    qpts_delay <- rep(0, 0)

    if (qnodes != 15) # warn user if qnodes is not equal to the default
      warning("There is no quadrature required so 'qnodes' is being ignored.",
              call. = FALSE)
  }

  #----- basis terms for baseline hazard

  if (!has_quadrature) {
    basis_event  <- rstanarm:::make_basis(t_event, basehaz)
    ibasis_event <- rstanarm:::make_basis(t_event, basehaz, integrate = TRUE)
    ibasis_lcens <- rstanarm:::make_basis(t_lcens, basehaz, integrate = TRUE)
    ibasis_rcens <- rstanarm:::make_basis(t_rcens, basehaz, integrate = TRUE)
    ibasis_icenl <- rstanarm:::make_basis(t_icenl, basehaz, integrate = TRUE)
    ibasis_icenu <- rstanarm:::make_basis(t_icenu, basehaz, integrate = TRUE)
    ibasis_delay <- rstanarm:::make_basis(t_delay, basehaz, integrate = TRUE)
  } else {
    basis_epts_event <- rstanarm:::make_basis(t_event,    basehaz)
    basis_qpts_event <- rstanarm:::make_basis(qpts_event, basehaz)
    basis_qpts_lcens <- rstanarm:::make_basis(qpts_lcens, basehaz)
    basis_qpts_rcens <- rstanarm:::make_basis(qpts_rcens, basehaz)
    basis_qpts_icenl <- rstanarm:::make_basis(qpts_icenl, basehaz)
    basis_qpts_icenu <- rstanarm:::make_basis(qpts_icenu, basehaz)
    basis_qpts_delay <- rstanarm:::make_basis(qpts_delay, basehaz)
  }

  #----- model frames for generating predictor matrices

  mf_event <- rstanarm:::keep_rows(mf, status == 1)
  mf_lcens <- rstanarm:::keep_rows(mf, status == 2)
  mf_rcens <- rstanarm:::keep_rows(mf, status == 0)
  mf_icens <- rstanarm:::keep_rows(mf, status == 3)
  mf_delay <- rstanarm:::keep_rows(mf, delayed)

  if (!has_quadrature) {
    # combined model frame, without quadrature
    # (mf_icens appears twice: once for the lower and once for the upper
    #  interval censoring time, matching the order of 'cpts_list')
    mf_cpts <- rbind(mf_event,
                     mf_lcens,
                     mf_rcens,
                     mf_icens,
                     mf_icens,
                     mf_delay)
  } else {
    # combined model frame, with quadrature
    mf_cpts <- rbind(mf_event,
                     rstanarm:::rep_rows(mf_event, times = qnodes),
                     rstanarm:::rep_rows(mf_lcens, times = qnodes),
                     rstanarm:::rep_rows(mf_rcens, times = qnodes),
                     rstanarm:::rep_rows(mf_icens, times = qnodes),
                     rstanarm:::rep_rows(mf_icens, times = qnodes),
                     rstanarm:::rep_rows(mf_delay, times = qnodes))
  }

  if (has_tve) {
    # generate a model frame with time transformations for tve effects
    mf_tve <- rstanarm:::make_model_frame(formula$tt_frame,
                                          data.frame(times__ = cpts))$mf

    # NB next line avoids dropping terms attribute from 'mf_cpts'
    mf_cpts[, colnames(mf_tve)] <- mf_tve
  }

  #----- time-fixed predictor matrices

  ff        <- formula$fe_form
  x         <- rstanarm:::make_x(ff, mf)$x
  x_cpts    <- rstanarm:::make_x(ff, mf_cpts)$x
  x_centred <- sweep(x_cpts, 2, colMeans(x), FUN = "-")
  K         <- ncol(x_cpts)

  if (!has_quadrature) {
    # time-fixed predictor matrices, without quadrature
    # NB skip index 5 on purpose, since time fixed predictor matrix is
    # identical for lower and upper limits of interval censoring time
    x_event <- x_centred[idx_cpts[1, 1]:idx_cpts[1, 2], , drop = FALSE]
    x_lcens <- x_centred[idx_cpts[2, 1]:idx_cpts[2, 2], , drop = FALSE]
    x_rcens <- x_centred[idx_cpts[3, 1]:idx_cpts[3, 2], , drop = FALSE]
    x_icens <- x_centred[idx_cpts[4, 1]:idx_cpts[4, 2], , drop = FALSE]
    x_delay <- x_centred[idx_cpts[6, 1]:idx_cpts[6, 2], , drop = FALSE]
  } else {
    # time-fixed predictor matrices, with quadrature
    # NB skip index 6 on purpose, since time fixed predictor matrix is
    # identical for lower and upper limits of interval censoring time
    x_epts_event <- x_centred[idx_cpts[1, 1]:idx_cpts[1, 2], , drop = FALSE]
    x_qpts_event <- x_centred[idx_cpts[2, 1]:idx_cpts[2, 2], , drop = FALSE]
    x_qpts_lcens <- x_centred[idx_cpts[3, 1]:idx_cpts[3, 2], , drop = FALSE]
    x_qpts_rcens <- x_centred[idx_cpts[4, 1]:idx_cpts[4, 2], , drop = FALSE]
    x_qpts_icens <- x_centred[idx_cpts[5, 1]:idx_cpts[5, 2], , drop = FALSE]
    x_qpts_delay <- x_centred[idx_cpts[7, 1]:idx_cpts[7, 2], , drop = FALSE]
  }

  #----- time-varying predictor matrices

  if (has_tve) {

    # time-varying predictor matrix
    # NOTE(review): 'xlevs' is not defined in this function, so if
    # rstanarm:::make_s() forces this argument it is looked up in the
    # enclosing (e.g. global) environment -- confirm against upstream rstanarm.
    s_cpts     <- rstanarm:::make_s(formula, mf_cpts, xlevs = xlevs)
    smooth_map <- rstanarm:::get_smooth_name(s_cpts, type = "smooth_map")
    smooth_idx <- rstanarm:::get_idx_array(table(smooth_map))
    S          <- ncol(s_cpts) # number of tve coefficients

    # store some additional information in model formula
    # stating how many columns in the predictor matrix
    # each tve() term in the model formula corresponds to
    formula$tt_ncol <- attr(s_cpts, "tt_ncol")
    formula$tt_map  <- attr(s_cpts, "tt_map")

  } else {

    # dud entries if no tve() terms in model formula
    s_cpts          <- matrix(0, length(cpts), 0)
    smooth_idx      <- matrix(0, 0, 2)
    smooth_map      <- integer(0)
    S               <- 0L
    formula$tt_ncol <- integer(0)
    formula$tt_map  <- integer(0)
  }

  if (has_quadrature) {
    # time-varying predictor matrices, with quadrature
    s_epts_event <- s_cpts[idx_cpts[1, 1]:idx_cpts[1, 2], , drop = FALSE]
    s_qpts_event <- s_cpts[idx_cpts[2, 1]:idx_cpts[2, 2], , drop = FALSE]
    s_qpts_lcens <- s_cpts[idx_cpts[3, 1]:idx_cpts[3, 2], , drop = FALSE]
    s_qpts_rcens <- s_cpts[idx_cpts[4, 1]:idx_cpts[4, 2], , drop = FALSE]
    s_qpts_icenl <- s_cpts[idx_cpts[5, 1]:idx_cpts[5, 2], , drop = FALSE]
    s_qpts_icenu <- s_cpts[idx_cpts[6, 1]:idx_cpts[6, 2], , drop = FALSE]
    s_qpts_delay <- s_cpts[idx_cpts[7, 1]:idx_cpts[7, 2], , drop = FALSE]
  }

  #----- random effects predictor matrices

  has_bars <- as.logical(length(formula$bars))

  # use 'stan_glmer' approach
  if (has_bars) {
    group_unpadded <- lme4::mkReTrms(formula$bars, mf_cpts)
    group <- rstanarm:::pad_reTrms(Ztlist = group_unpadded$Ztlist,
                                   cnms   = group_unpadded$cnms,
                                   flist  = group_unpadded$flist)
    z_cpts <- group$Z
  } else {
    group  <- NULL
    z_cpts <- matrix(0, length(cpts), 0)
  }

  # NB extract_sparse_parts() is exported by rstan; it is not an rstanarm
  # internal, so it is qualified with `rstan::` (rstanarm::: cannot reach
  # functions that rstanarm only imports).
  if (!has_quadrature) {
    # random effects predictor matrices, without quadrature
    # NB skip index 5 on purpose, since the random effects predictor matrix
    # is identical for lower and upper limits of interval censoring time
    z_event <- z_cpts[idx_cpts[1, 1]:idx_cpts[1, 2], , drop = FALSE]
    z_lcens <- z_cpts[idx_cpts[2, 1]:idx_cpts[2, 2], , drop = FALSE]
    z_rcens <- z_cpts[idx_cpts[3, 1]:idx_cpts[3, 2], , drop = FALSE]
    z_icens <- z_cpts[idx_cpts[4, 1]:idx_cpts[4, 2], , drop = FALSE]
    z_delay <- z_cpts[idx_cpts[6, 1]:idx_cpts[6, 2], , drop = FALSE]
    parts_event <- rstan::extract_sparse_parts(z_event)
    parts_lcens <- rstan::extract_sparse_parts(z_lcens)
    parts_rcens <- rstan::extract_sparse_parts(z_rcens)
    parts_icens <- rstan::extract_sparse_parts(z_icens)
    parts_delay <- rstan::extract_sparse_parts(z_delay)
  } else {
    # random effects predictor matrices, with quadrature
    # NB skip index 6 on purpose, since the random effects predictor matrix
    # is identical for lower and upper limits of interval censoring time
    z_epts_event <- z_cpts[idx_cpts[1, 1]:idx_cpts[1, 2], , drop = FALSE]
    z_qpts_event <- z_cpts[idx_cpts[2, 1]:idx_cpts[2, 2], , drop = FALSE]
    z_qpts_lcens <- z_cpts[idx_cpts[3, 1]:idx_cpts[3, 2], , drop = FALSE]
    z_qpts_rcens <- z_cpts[idx_cpts[4, 1]:idx_cpts[4, 2], , drop = FALSE]
    z_qpts_icens <- z_cpts[idx_cpts[5, 1]:idx_cpts[5, 2], , drop = FALSE]
    z_qpts_delay <- z_cpts[idx_cpts[7, 1]:idx_cpts[7, 2], , drop = FALSE]
    parts_epts_event <- rstan::extract_sparse_parts(z_epts_event)
    parts_qpts_event <- rstan::extract_sparse_parts(z_qpts_event)
    parts_qpts_lcens <- rstan::extract_sparse_parts(z_qpts_lcens)
    parts_qpts_rcens <- rstan::extract_sparse_parts(z_qpts_rcens)
    parts_qpts_icens <- rstan::extract_sparse_parts(z_qpts_icens)
    parts_qpts_delay <- rstan::extract_sparse_parts(z_qpts_delay)
  }

  #----- stan data

  # Entries belonging to the branch (quadrature / no quadrature) that is not
  # used are filled with zero-length "dud" values of the right type/shape.
  standata <- rstanarm:::nlist(

    K, S,
    nvars,
    x_bar = rstanarm:::aa(colMeans(x)),
    has_intercept,
    has_quadrature,
    smooth_map,
    smooth_idx,
    type = basehaz$type,

    log_crude_event_rate =
      if (is_aft) -log_crude_event_rate else log_crude_event_rate,

    nevent = if (has_quadrature) 0L else nevent,
    nlcens = if (has_quadrature) 0L else nlcens,
    nrcens = if (has_quadrature) 0L else nrcens,
    nicens = if (has_quadrature) 0L else nicens,
    ndelay = if (has_quadrature) 0L else ndelay,

    t_event = if (has_quadrature) rep(0, 0) else t_event,
    t_lcens = if (has_quadrature) rep(0, 0) else t_lcens,
    t_rcens = if (has_quadrature) rep(0, 0) else t_rcens,
    t_icenl = if (has_quadrature) rep(0, 0) else t_icenl,
    t_icenu = if (has_quadrature) rep(0, 0) else t_icenu,
    t_delay = if (has_quadrature) rep(0, 0) else t_delay,

    x_event = if (has_quadrature) matrix(0, 0, K) else x_event,
    x_lcens = if (has_quadrature) matrix(0, 0, K) else x_lcens,
    x_rcens = if (has_quadrature) matrix(0, 0, K) else x_rcens,
    x_icens = if (has_quadrature) matrix(0, 0, K) else x_icens,
    x_delay = if (has_quadrature) matrix(0, 0, K) else x_delay,

    w_event = if (has_quadrature || !has_bars || nevent == 0) double(0) else parts_event$w,
    w_lcens = if (has_quadrature || !has_bars || nlcens == 0) double(0) else parts_lcens$w,
    w_rcens = if (has_quadrature || !has_bars || nrcens == 0) double(0) else parts_rcens$w,
    w_icens = if (has_quadrature || !has_bars || nicens == 0) double(0) else parts_icens$w,
    w_delay = if (has_quadrature || !has_bars || ndelay == 0) double(0) else parts_delay$w,

    v_event = if (has_quadrature || !has_bars || nevent == 0) integer(0) else parts_event$v - 1L,
    v_lcens = if (has_quadrature || !has_bars || nlcens == 0) integer(0) else parts_lcens$v - 1L,
    v_rcens = if (has_quadrature || !has_bars || nrcens == 0) integer(0) else parts_rcens$v - 1L,
    v_icens = if (has_quadrature || !has_bars || nicens == 0) integer(0) else parts_icens$v - 1L,
    v_delay = if (has_quadrature || !has_bars || ndelay == 0) integer(0) else parts_delay$v - 1L,

    u_event = if (has_quadrature || !has_bars || nevent == 0) integer(0) else parts_event$u - 1L,
    u_lcens = if (has_quadrature || !has_bars || nlcens == 0) integer(0) else parts_lcens$u - 1L,
    u_rcens = if (has_quadrature || !has_bars || nrcens == 0) integer(0) else parts_rcens$u - 1L,
    u_icens = if (has_quadrature || !has_bars || nicens == 0) integer(0) else parts_icens$u - 1L,
    u_delay = if (has_quadrature || !has_bars || ndelay == 0) integer(0) else parts_delay$u - 1L,

    nnz_event = if (has_quadrature || !has_bars || nevent == 0) 0L else length(parts_event$w),
    nnz_lcens = if (has_quadrature || !has_bars || nlcens == 0) 0L else length(parts_lcens$w),
    nnz_rcens = if (has_quadrature || !has_bars || nrcens == 0) 0L else length(parts_rcens$w),
    nnz_icens = if (has_quadrature || !has_bars || nicens == 0) 0L else length(parts_icens$w),
    nnz_delay = if (has_quadrature || !has_bars || ndelay == 0) 0L else length(parts_delay$w),

    basis_event  = if (has_quadrature) matrix(0, 0, nvars) else basis_event,
    ibasis_event = if (has_quadrature) matrix(0, 0, nvars) else ibasis_event,
    ibasis_lcens = if (has_quadrature) matrix(0, 0, nvars) else ibasis_lcens,
    ibasis_rcens = if (has_quadrature) matrix(0, 0, nvars) else ibasis_rcens,
    ibasis_icenl = if (has_quadrature) matrix(0, 0, nvars) else ibasis_icenl,
    ibasis_icenu = if (has_quadrature) matrix(0, 0, nvars) else ibasis_icenu,
    ibasis_delay = if (has_quadrature) matrix(0, 0, nvars) else ibasis_delay,

    qnodes = if (!has_quadrature) 0L else qnodes,

    Nevent = if (!has_quadrature) 0L else nevent,
    Nlcens = if (!has_quadrature) 0L else nlcens,
    Nrcens = if (!has_quadrature) 0L else nrcens,
    Nicens = if (!has_quadrature) 0L else nicens,
    Ndelay = if (!has_quadrature) 0L else ndelay,

    qevent = if (!has_quadrature) 0L else qevent,
    qlcens = if (!has_quadrature) 0L else qlcens,
    qrcens = if (!has_quadrature) 0L else qrcens,
    qicens = if (!has_quadrature) 0L else qicens,
    qdelay = if (!has_quadrature) 0L else qdelay,

    epts_event = if (!has_quadrature) rep(0, 0) else t_event,
    qpts_event = if (!has_quadrature) rep(0, 0) else qpts_event,
    qpts_lcens = if (!has_quadrature) rep(0, 0) else qpts_lcens,
    qpts_rcens = if (!has_quadrature) rep(0, 0) else qpts_rcens,
    qpts_icenl = if (!has_quadrature) rep(0, 0) else qpts_icenl,
    qpts_icenu = if (!has_quadrature) rep(0, 0) else qpts_icenu,
    qpts_delay = if (!has_quadrature) rep(0, 0) else qpts_delay,

    qwts_event = if (!has_quadrature) rep(0, 0) else qwts_event,
    qwts_lcens = if (!has_quadrature) rep(0, 0) else qwts_lcens,
    qwts_rcens = if (!has_quadrature) rep(0, 0) else qwts_rcens,
    qwts_icenl = if (!has_quadrature) rep(0, 0) else qwts_icenl,
    qwts_icenu = if (!has_quadrature) rep(0, 0) else qwts_icenu,
    qwts_delay = if (!has_quadrature) rep(0, 0) else qwts_delay,

    x_epts_event = if (!has_quadrature) matrix(0, 0, K) else x_epts_event,
    x_qpts_event = if (!has_quadrature) matrix(0, 0, K) else x_qpts_event,
    x_qpts_lcens = if (!has_quadrature) matrix(0, 0, K) else x_qpts_lcens,
    x_qpts_rcens = if (!has_quadrature) matrix(0, 0, K) else x_qpts_rcens,
    x_qpts_icens = if (!has_quadrature) matrix(0, 0, K) else x_qpts_icens,
    x_qpts_delay = if (!has_quadrature) matrix(0, 0, K) else x_qpts_delay,

    s_epts_event = if (!has_quadrature) matrix(0, 0, S) else s_epts_event,
    s_qpts_event = if (!has_quadrature) matrix(0, 0, S) else s_qpts_event,
    s_qpts_lcens = if (!has_quadrature) matrix(0, 0, S) else s_qpts_lcens,
    s_qpts_rcens = if (!has_quadrature) matrix(0, 0, S) else s_qpts_rcens,
    s_qpts_icenl = if (!has_quadrature) matrix(0, 0, S) else s_qpts_icenl,
    s_qpts_icenu = if (!has_quadrature) matrix(0, 0, S) else s_qpts_icenu,
    s_qpts_delay = if (!has_quadrature) matrix(0, 0, S) else s_qpts_delay,

    w_epts_event = if (!has_quadrature || !has_bars || qevent == 0) double(0) else parts_epts_event$w,
    w_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) double(0) else parts_qpts_event$w,
    w_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) double(0) else parts_qpts_lcens$w,
    w_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) double(0) else parts_qpts_rcens$w,
    w_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) double(0) else parts_qpts_icens$w,
    w_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) double(0) else parts_qpts_delay$w,

    v_epts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_epts_event$v - 1L,
    v_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_qpts_event$v - 1L,
    v_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) integer(0) else parts_qpts_lcens$v - 1L,
    v_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) integer(0) else parts_qpts_rcens$v - 1L,
    v_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) integer(0) else parts_qpts_icens$v - 1L,
    v_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) integer(0) else parts_qpts_delay$v - 1L,

    u_epts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_epts_event$u - 1L,
    u_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_qpts_event$u - 1L,
    u_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) integer(0) else parts_qpts_lcens$u - 1L,
    u_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) integer(0) else parts_qpts_rcens$u - 1L,
    u_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) integer(0) else parts_qpts_icens$u - 1L,
    u_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) integer(0) else parts_qpts_delay$u - 1L,

    nnz_epts_event = if (!has_quadrature || !has_bars || qevent == 0) 0L else length(parts_epts_event$w),
    nnz_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) 0L else length(parts_qpts_event$w),
    nnz_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) 0L else length(parts_qpts_lcens$w),
    nnz_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) 0L else length(parts_qpts_rcens$w),
    nnz_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) 0L else length(parts_qpts_icens$w),
    nnz_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) 0L else length(parts_qpts_delay$w),

    basis_epts_event = if (!has_quadrature) matrix(0, 0, nvars) else basis_epts_event,
    basis_qpts_event = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_event,
    basis_qpts_lcens = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_lcens,
    basis_qpts_rcens = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_rcens,
    basis_qpts_icenl = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_icenl,
    basis_qpts_icenu = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_icenu,
    basis_qpts_delay = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_delay
  )

  #----- random-effects structure

  if (has_bars) {

    fl <- group$flist
    p  <- lengths(group$cnms)
    l  <- sapply(attr(fl, "assign"), function(i) nlevels(fl[[i]]))
    t  <- length(l)
    standata$p <- as.array(p) # num ranefs for each grouping factor
    standata$l <- as.array(l) # num levels for each grouping factor
    standata$t <- t           # num of grouping factors
    standata$q <- ncol(group$Z) # p * l
    standata$special_case <- all(sapply(group$cnms, rstanarm:::intercept_only))

  } else { # no random effects structure

    standata$p <- integer(0)
    standata$l <- integer(0)
    standata$t <- 0L
    standata$q <- 0L
    standata$special_case <- 0L

  }

  #----- priors and hyperparameters

  # valid priors
  ok_dists <- rstanarm:::nlist("normal",
                               student_t = "t",
                               "cauchy",
                               "hs",
                               "hs_plus",
                               "laplace",
                               "lasso") # disallow product normal
  ok_intercept_dists  <- ok_dists[1:3]
  ok_aux_dists        <- rstanarm:::get_ok_priors_for_aux(basehaz)
  ok_smooth_dists     <- c(ok_dists[1:3], "exponential")
  ok_covariance_dists <- c("decov")

  if (missing(prior_aux))
    prior_aux <- rstanarm:::get_default_prior_for_aux(basehaz)

  # priors
  user_prior_stuff <- prior_stuff <-
    rstanarm:::handle_glm_prior(prior,
                                nvars = K,
                                default_scale = 2.5,
                                link = NULL,
                                ok_dists = ok_dists)

  user_prior_intercept_stuff <- prior_intercept_stuff <-
    rstanarm:::handle_glm_prior(prior_intercept,
                                nvars = 1,
                                default_scale = 20,
                                link = NULL,
                                ok_dists = ok_intercept_dists)

  user_prior_aux_stuff <- prior_aux_stuff <-
    rstanarm:::handle_glm_prior(prior_aux,
                                nvars = basehaz$nvars,
                                default_scale = rstanarm:::get_default_aux_scale(basehaz),
                                link = NULL,
                                ok_dists = ok_aux_dists)

  user_prior_smooth_stuff <- prior_smooth_stuff <-
    rstanarm:::handle_glm_prior(prior_smooth,
                                nvars = if (S) max(smooth_map) else 0,
                                default_scale = 1,
                                link = NULL,
                                ok_dists = ok_smooth_dists)

  # stop null priors when prior_PD is true
  if (prior_PD) {
    if (is.null(prior))
      stop("'prior' cannot be NULL if 'prior_PD' is TRUE.")
    if (is.null(prior_intercept) && has_intercept)
      stop("'prior_intercept' cannot be NULL if 'prior_PD' is TRUE.")
    if (is.null(prior_aux))
      stop("'prior_aux' cannot be NULL if 'prior_PD' is TRUE.")
    if (is.null(prior_smooth) && (S > 0))
      stop("'prior_smooth' cannot be NULL if 'prior_PD' is TRUE.")
  }

  # handle prior for random effects structure
  if (has_bars) {

    user_prior_b_stuff <- prior_b_stuff <-
      rstanarm:::handle_cov_prior(prior_covariance,
                                  cnms = group$cnms,
                                  ok_dists = ok_covariance_dists)

    if (is.null(prior_covariance))
      stop("'prior_covariance' cannot be NULL.")

  } else {
    user_prior_b_stuff <- NULL
    prior_b_stuff      <- NULL
    prior_covariance   <- NULL
  }

  # autoscaling of priors
  prior_stuff           <- rstanarm:::autoscale_prior(prior_stuff, predictors = x)
  prior_intercept_stuff <- rstanarm:::autoscale_prior(prior_intercept_stuff)
  prior_aux_stuff       <- rstanarm:::autoscale_prior(prior_aux_stuff)
  prior_smooth_stuff    <- rstanarm:::autoscale_prior(prior_smooth_stuff)

  # priors
  standata$prior_dist               <- prior_stuff$prior_dist
  standata$prior_dist_for_intercept <- prior_intercept_stuff$prior_dist
  standata$prior_dist_for_aux       <- prior_aux_stuff$prior_dist
  standata$prior_dist_for_smooth    <- prior_smooth_stuff$prior_dist
  standata$prior_dist_for_cov       <- prior_b_stuff$prior_dist

  # hyperparameters
  standata$prior_mean                <- prior_stuff$prior_mean
  standata$prior_scale               <- prior_stuff$prior_scale
  standata$prior_df                  <- prior_stuff$prior_df
  standata$prior_mean_for_intercept  <- c(prior_intercept_stuff$prior_mean)
  standata$prior_scale_for_intercept <- c(prior_intercept_stuff$prior_scale)
  standata$prior_df_for_intercept    <- c(prior_intercept_stuff$prior_df)
  standata$prior_scale_for_aux       <- prior_aux_stuff$prior_scale
  standata$prior_df_for_aux          <- prior_aux_stuff$prior_df
  standata$prior_conc_for_aux        <- prior_aux_stuff$prior_concentration
  standata$prior_mean_for_smooth     <- prior_smooth_stuff$prior_mean
  standata$prior_scale_for_smooth    <- prior_smooth_stuff$prior_scale
  standata$prior_df_for_smooth       <- prior_smooth_stuff$prior_df
  standata$global_prior_scale        <- prior_stuff$global_prior_scale
  standata$global_prior_df           <- prior_stuff$global_prior_df
  standata$slab_df                   <- prior_stuff$slab_df
  standata$slab_scale                <- prior_stuff$slab_scale

  # hyperparameters for covariance
  if (has_bars) {
    standata$b_prior_shape      <- prior_b_stuff$prior_shape
    standata$b_prior_scale      <- prior_b_stuff$prior_scale
    standata$concentration      <- prior_b_stuff$prior_concentration
    standata$regularization     <- prior_b_stuff$prior_regularization
    standata$len_concentration  <- length(standata$concentration)
    standata$len_regularization <- length(standata$regularization)
    standata$len_theta_L        <- sum(choose(standata$p, 2), standata$p)
  } else { # no random effects structure
    standata$b_prior_shape      <- rep(0, 0)
    standata$b_prior_scale      <- rep(0, 0)
    standata$concentration      <- rep(0, 0)
    standata$regularization     <- rep(0, 0)
    standata$len_concentration  <- 0L
    standata$len_regularization <- 0L
    standata$len_theta_L        <- 0L
  }

  # any additional flags
  standata$prior_PD <- rstanarm:::ai(prior_PD)

  #---------------
  # Prior summary
  #---------------

  prior_info <- rstanarm:::summarize_jm_prior(
    user_priorEvent                     = user_prior_stuff,
    user_priorEvent_intercept           = user_prior_intercept_stuff,
    user_priorEvent_aux                 = user_prior_aux_stuff,
    adjusted_priorEvent_scale           = prior_stuff$prior_scale,
    adjusted_priorEvent_intercept_scale = prior_intercept_stuff$prior_scale,
    adjusted_priorEvent_aux_scale       = prior_aux_stuff$prior_scale,
    e_has_intercept                     = has_intercept,
    e_has_predictors                    = K > 0,
    basehaz                             = basehaz,
    user_prior_covariance               = prior_covariance,
    b_user_prior_stuff                  = user_prior_b_stuff,
    b_prior_stuff                       = prior_b_stuff
  )

  #-----------
  # Fit model
  #-----------

  # obtain stan model code
  stanfit <- rstanarm:::stanmodels$surv

  # specify parameters for stan to monitor
  stanpars <- c(if (standata$has_intercept) "alpha",
                if (standata$K)             "beta",
                if (standata$S)             "beta_tve",
                if (standata$S)             "smooth_sd",
                if (standata$nvars)         "aux",
                if (standata$t)             "b",
                if (standata$t)             "theta_L")

  # fit model using stan
  if (algorithm == "sampling") { # mcmc
    args <- rstanarm:::set_sampling_args(
      object = stanfit,
      data = standata,
      pars = stanpars,
      prior = prior,
      user_dots = list(...),
      user_adapt_delta = adapt_delta,
      show_messages = FALSE)
    # The change this copy exists for: keep the warmup draws. This overrides
    # whatever 'save_warmup' value set_sampling_args() produced.
    args[["save_warmup"]] <- TRUE
    stanfit <- do.call(rstan::sampling, args)
  } else { # meanfield or fullrank vb
    args <- rstanarm:::nlist(
      object = stanfit,
      data = standata,
      pars = stanpars,
      algorithm
    )
    args[names(dots)] <- dots
    stanfit <- do.call(rstan::vb, args)
  }
  rstanarm:::check_stanfit(stanfit)

  # replace 'theta_L' with the variance-covariance matrix
  if (has_bars)
    stanfit <- rstanarm:::evaluate_Sigma(stanfit, group$cnms)

  # define new parameter names
  nms_beta   <- colnames(x_cpts) # may be NULL
  nms_tve    <- rstanarm:::get_smooth_name(s_cpts, type = "smooth_coefs") # may be NULL
  nms_smooth <- rstanarm:::get_smooth_name(s_cpts, type = "smooth_sd")    # may be NULL
  nms_int    <- rstanarm:::get_int_name_basehaz(basehaz)
  nms_aux    <- rstanarm:::get_aux_name_basehaz(basehaz)
  nms_b      <- rstanarm:::get_b_names(group)      # may be NULL
  nms_vc     <- rstanarm:::get_varcov_names(group) # may be NULL
  nms_all    <- c(nms_int,
                  nms_beta,
                  nms_tve,
                  nms_smooth,
                  nms_aux,
                  nms_b,
                  nms_vc,
                  "log-posterior")

  # substitute new parameter names into 'stanfit' object
  stanfit <- rstanarm:::replace_stanfit_nms(stanfit, nms_all)

  # return an object of class 'stansurv'
  fit <- rstanarm:::nlist(stanfit,
                          formula,
                          has_tve,
                          has_quadrature,
                          has_bars,
                          data,
                          model_frame      = mf,
                          terms            = mt,
                          xlevels          = .getXlevels(mt, mf),
                          x,
                          x_cpts,
                          s_cpts           = if (has_tve)  s_cpts else NULL,
                          z_cpts           = if (has_bars) z_cpts else NULL,
                          cnms             = if (has_bars) group_unpadded$cnms  else NULL,
                          flist            = if (has_bars) group_unpadded$flist else NULL,
                          t_beg,
                          t_end,
                          status,
                          event            = as.logical(status == 1),
                          delayed,
                          basehaz,
                          nobs             = nrow(mf),
                          nevents          = nevent,
                          nlcens,
                          nrcens,
                          nicens,
                          ncensor          = nlcens + nrcens + nicens,
                          ndelayed         = ndelay,
                          prior_info,
                          qnodes           = if (has_quadrature) qnodes else NULL,
                          algorithm,
                          stan_function    = "stan_surv",
                          rstanarm_version = utils::packageVersion("rstanarm"),
                          call             = match.call(expand.dots = TRUE))

  rstanarm:::stansurv(fit)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment