Last active
October 31, 2023 14:59
-
-
Save anddis/efe72216388db850e942ec71cdc6e3d4 to your computer and use it in GitHub Desktop.
rstanarm::stan_surv with save_warmup = TRUE (line 736)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## rstanarm::stan_surv patched so that MCMC warmup draws are saved
## (save_warmup = TRUE; see the sampling branch in the "Fit model" section).
## anddis 20231031
##
## See: https://github.com/stan-dev/rstanarm/issues/556
# Copy of rstanarm's stan_surv() that keeps MCMC warmup draws.
#
# The argument list produced by rstanarm:::set_sampling_args() is handed to
# rstan::sampling(); this copy sets `save_warmup` on that list afterwards
# (TRUE unless the caller supplies `save_warmup` through `...`).
# See https://github.com/stan-dev/rstanarm/issues/556.
#
# Because this function is defined outside the rstanarm namespace, every
# rstanarm internal helper is reached through `rstanarm:::` (or the exporting
# package's `::`), and errors/warnings use base stop()/warning() directly.
#
# Arguments:
#   formula          model formula; parsed together with `data` by
#                    rstanarm:::parse_formula_and_data().
#   data             data frame holding the model variables (required).
#   basehaz          baseline hazard; one of "exp", "exp-aft", "weibull",
#                    "weibull-aft", "gompertz", "ms", "bs".
#   basehaz_ops      optional baseline-hazard options (NULL when missing).
#   qnodes           number of quadrature nodes. Only used when quadrature is
#                    needed (time-varying effects, or no closed-form cumulative
#                    baseline hazard); otherwise a non-default value triggers a
#                    warning and is ignored.
#   prior, prior_intercept, prior_aux, prior_smooth, prior_covariance
#                    prior specifications; `prior_aux` falls back to
#                    rstanarm:::get_default_prior_for_aux(basehaz).
#   prior_PD         logical flag passed to Stan as `prior_PD`; when TRUE the
#                    priors (other than covariance) may not be NULL.
#   algorithm        "sampling" (rstan::sampling) or "meanfield"/"fullrank"
#                    (rstan::vb).
#   adapt_delta      forwarded to set_sampling_args() as `user_adapt_delta`.
#   ...              forwarded to rstan::sampling() / rstan::vb(); for MCMC a
#                    `save_warmup` entry overrides the TRUE default.
#
# Value: the result of rstanarm:::stansurv() (an object of class 'stansurv').
stan_surv <- function(formula,
                      data,
                      basehaz = "ms",
                      basehaz_ops,
                      qnodes = 15,
                      prior = normal(),
                      prior_intercept = normal(),
                      prior_aux,
                      prior_smooth = exponential(autoscale = FALSE),
                      prior_covariance = decov(),
                      prior_PD = FALSE,
                      algorithm = c("sampling", "meanfield", "fullrank"),
                      adapt_delta = 0.95, ...) {

  #-----------------------------
  # Pre-processing of arguments
  #-----------------------------

  if (!requireNamespace("survival", quietly = TRUE)) {
    stop("the 'survival' package must be installed to use this function.")
  }
  if (missing(basehaz_ops)) {
    basehaz_ops <- NULL
  }
  if (missing(data) || !inherits(data, "data.frame")) {
    stop("'data' must be a data frame.")
  }

  dots      <- list(...)
  algorithm <- match.arg(algorithm)

  formula <- rstanarm:::parse_formula_and_data(formula, data)
  data    <- formula$data
  formula[["data"]] <- NULL

  #----------------
  # Construct data
  #----------------

  #----- model frame stuff

  mf_stuff <- rstanarm:::make_model_frame(formula$tf_form, data,
                                          drop.unused.levels = TRUE)
  mf <- mf_stuff$mf # model frame
  mt <- mf_stuff$mt # model terms

  # factor levels of the model frame: handed to make_s() for the tve design
  # matrix and stored in the returned object (previously `xlevs` was used
  # below without ever being defined)
  xlevs <- .getXlevels(mt, mf)

  #----- dimensions and response vectors

  # entry and exit times for each row of data
  t_beg <- rstanarm:::make_t(mf, type = "beg") # entry time
  t_end <- rstanarm:::make_t(mf, type = "end") # exit time
  t_upp <- rstanarm:::make_t(mf, type = "upp") # upper time for interval censoring

  # ensure no event or censoring times are zero (leads to degenerate
  # estimate for log hazard for most baseline hazards, due to log(0))
  if (any(t_end <= 0, na.rm = TRUE) || any(t_upp <= 0, na.rm = TRUE)) {
    stop("All event and censoring times must be greater than 0.", call. = FALSE)
  }

  # event indicator for each row of data:
  # 0 = right censored, 1 = event, 2 = left censored, 3 = interval censored
  status <- rstanarm:::make_d(mf)
  if (anyNA(status) || any(status < 0 | status > 3)) {
    stop("Invalid status indicator in Surv object.", call. = FALSE)
  }

  # delayed entry indicator for each row of data
  delayed <- as.logical(t_beg != 0)

  # time variables for stan
  t_event <- rstanarm:::aa(t_end[status == 1]) # exact event time
  t_lcens <- rstanarm:::aa(t_end[status == 2]) # left censoring time
  t_rcens <- rstanarm:::aa(t_end[status == 0]) # right censoring time
  t_icenl <- rstanarm:::aa(t_end[status == 3]) # lower limit of interval censoring time
  t_icenu <- rstanarm:::aa(t_upp[status == 3]) # upper limit of interval censoring time
  t_delay <- rstanarm:::aa(t_beg[delayed])     # delayed entry time

  # log crude event rate; falls back to 0 when it is infinite
  # (e.g. zero events gives log(0) = -Inf)
  t_tmp <- sum(rowMeans(cbind(t_end, t_upp), na.rm = TRUE) - t_beg)
  d_tmp <- sum(status != 0)
  log_crude_event_rate <- log(d_tmp / t_tmp)
  if (is.infinite(log_crude_event_rate)) {
    log_crude_event_rate <- 0
  }

  # dimensions
  nevent <- sum(status == 1)
  nrcens <- sum(status == 0)
  nlcens <- sum(status == 2)
  nicens <- sum(status == 3)
  ndelay <- sum(delayed)

  #----- baseline hazard

  ok_basehaz <- c("exp",
                  "exp-aft",
                  "weibull",
                  "weibull-aft",
                  "gompertz",
                  "ms",
                  "bs")

  basehaz <- rstanarm:::handle_basehaz_surv(basehaz     = basehaz,
                                            basehaz_ops = basehaz_ops,
                                            ok_basehaz  = ok_basehaz,
                                            times       = t_end,
                                            status      = status,
                                            min_t       = min(t_beg),
                                            max_t       = max(c(t_end, t_upp), na.rm = TRUE))

  nvars <- basehaz$nvars # number of basehaz aux parameters

  # flag if intercept is required for baseline hazard
  has_intercept <- rstanarm:::ai(rstanarm:::has_intercept(basehaz))

  # flag if AFT specification
  is_aft <- rstanarm:::get_basehaz_name(basehaz) %in% c("exp-aft", "weibull-aft")

  #----- define dimensions and times for quadrature

  # flag if formula uses time-varying effects
  has_tve <- !is.null(formula$td_form)

  # flag if closed form available for cumulative baseline hazard
  has_closed_form <- rstanarm:::check_for_closed_form(basehaz)

  # flag for quadrature
  has_quadrature <- has_tve || !has_closed_form

  if (has_quadrature) { # model uses quadrature

    # standardised nodes and weights for quadrature
    qq <- rstanarm:::get_quadpoints(nodes = qnodes)
    qp <- qq$points
    qw <- qq$weights

    # internal helpers, qualified so they resolve outside the rstanarm
    # namespace; points/weights are unstandardised between 0 and each time
    uapply     <- rstanarm:::uapply
    unstd_qpts <- rstanarm:::unstandardise_qpts
    unstd_qwts <- rstanarm:::unstandardise_qwts

    # quadrature points, evaluated for each row of data
    qpts_event <- uapply(qp, unstd_qpts, 0, t_event)
    qpts_lcens <- uapply(qp, unstd_qpts, 0, t_lcens)
    qpts_rcens <- uapply(qp, unstd_qpts, 0, t_rcens)
    qpts_icenl <- uapply(qp, unstd_qpts, 0, t_icenl)
    qpts_icenu <- uapply(qp, unstd_qpts, 0, t_icenu)
    qpts_delay <- uapply(qp, unstd_qpts, 0, t_delay)

    # quadrature weights, evaluated for each row of data
    qwts_event <- uapply(qw, unstd_qwts, 0, t_event)
    qwts_lcens <- uapply(qw, unstd_qwts, 0, t_lcens)
    qwts_rcens <- uapply(qw, unstd_qwts, 0, t_rcens)
    qwts_icenl <- uapply(qw, unstd_qwts, 0, t_icenl)
    qwts_icenu <- uapply(qw, unstd_qwts, 0, t_icenu)
    qwts_delay <- uapply(qw, unstd_qwts, 0, t_delay)

    # times at events and all quadrature points
    cpts_list <- list(t_event,
                      qpts_event,
                      qpts_lcens,
                      qpts_rcens,
                      qpts_icenl,
                      qpts_icenu,
                      qpts_delay)
    idx_cpts <- rstanarm:::get_idx_array(lengths(cpts_list))
    cpts     <- unlist(cpts_list) # as vector

    # number of quadrature points
    qevent <- length(qwts_event)
    qlcens <- length(qwts_lcens)
    qrcens <- length(qwts_rcens)
    qicens <- length(qwts_icenl)
    qdelay <- length(qwts_delay)

  } else {

    # times at all different event types
    cpts_list <- list(t_event,
                      t_lcens,
                      t_rcens,
                      t_icenl,
                      t_icenu,
                      t_delay)
    idx_cpts <- rstanarm:::get_idx_array(lengths(cpts_list))
    cpts     <- unlist(cpts_list) # as vector

    # dud entries for stan
    qpts_event <- rep(0, 0)
    qpts_lcens <- rep(0, 0)
    qpts_rcens <- rep(0, 0)
    qpts_icenl <- rep(0, 0)
    qpts_icenu <- rep(0, 0)
    qpts_delay <- rep(0, 0)

    # warn user if qnodes is not equal to the default
    if (qnodes != 15) {
      warning("There is no quadrature required so 'qnodes' is being ignored.",
              call. = FALSE)
    }
  }

  #----- basis terms for baseline hazard

  if (!has_quadrature) {
    basis_event  <- rstanarm:::make_basis(t_event, basehaz)
    ibasis_event <- rstanarm:::make_basis(t_event, basehaz, integrate = TRUE)
    ibasis_lcens <- rstanarm:::make_basis(t_lcens, basehaz, integrate = TRUE)
    ibasis_rcens <- rstanarm:::make_basis(t_rcens, basehaz, integrate = TRUE)
    ibasis_icenl <- rstanarm:::make_basis(t_icenl, basehaz, integrate = TRUE)
    ibasis_icenu <- rstanarm:::make_basis(t_icenu, basehaz, integrate = TRUE)
    ibasis_delay <- rstanarm:::make_basis(t_delay, basehaz, integrate = TRUE)
  } else {
    basis_epts_event <- rstanarm:::make_basis(t_event,    basehaz)
    basis_qpts_event <- rstanarm:::make_basis(qpts_event, basehaz)
    basis_qpts_lcens <- rstanarm:::make_basis(qpts_lcens, basehaz)
    basis_qpts_rcens <- rstanarm:::make_basis(qpts_rcens, basehaz)
    basis_qpts_icenl <- rstanarm:::make_basis(qpts_icenl, basehaz)
    basis_qpts_icenu <- rstanarm:::make_basis(qpts_icenu, basehaz)
    basis_qpts_delay <- rstanarm:::make_basis(qpts_delay, basehaz)
  }

  #----- model frames for generating predictor matrices

  mf_event <- rstanarm:::keep_rows(mf, status == 1)
  mf_lcens <- rstanarm:::keep_rows(mf, status == 2)
  mf_rcens <- rstanarm:::keep_rows(mf, status == 0)
  mf_icens <- rstanarm:::keep_rows(mf, status == 3)
  mf_delay <- rstanarm:::keep_rows(mf, delayed)

  # mf_icens appears twice in both branches: once for the lower and once for
  # the upper interval-censoring limit, matching the layout of 'cpts'
  if (!has_quadrature) {
    # combined model frame, without quadrature
    mf_cpts <- rbind(mf_event,
                     mf_lcens,
                     mf_rcens,
                     mf_icens,
                     mf_icens,
                     mf_delay)
  } else {
    # combined model frame, with quadrature
    mf_cpts <- rbind(mf_event,
                     rstanarm:::rep_rows(mf_event, times = qnodes),
                     rstanarm:::rep_rows(mf_lcens, times = qnodes),
                     rstanarm:::rep_rows(mf_rcens, times = qnodes),
                     rstanarm:::rep_rows(mf_icens, times = qnodes),
                     rstanarm:::rep_rows(mf_icens, times = qnodes),
                     rstanarm:::rep_rows(mf_delay, times = qnodes))
  }

  if (has_tve) {
    # generate a model frame with time transformations for tve effects
    mf_tve <- rstanarm:::make_model_frame(formula$tt_frame,
                                          data.frame(times__ = cpts))$mf

    # NB next line avoids dropping terms attribute from 'mf_cpts'
    mf_cpts[, colnames(mf_tve)] <- mf_tve
  }

  #----- time-fixed predictor matrices

  ff        <- formula$fe_form
  x         <- rstanarm:::make_x(ff, mf)$x
  x_cpts    <- rstanarm:::make_x(ff, mf_cpts)$x
  x_centred <- sweep(x_cpts, 2, colMeans(x), FUN = "-")
  K         <- ncol(x_cpts)

  if (!has_quadrature) {
    # time-fixed predictor matrices, without quadrature
    # NB skip index 5 on purpose, since time fixed predictor matrix is
    # identical for lower and upper limits of interval censoring time
    x_event <- x_centred[idx_cpts[1, 1]:idx_cpts[1, 2], , drop = FALSE]
    x_lcens <- x_centred[idx_cpts[2, 1]:idx_cpts[2, 2], , drop = FALSE]
    x_rcens <- x_centred[idx_cpts[3, 1]:idx_cpts[3, 2], , drop = FALSE]
    x_icens <- x_centred[idx_cpts[4, 1]:idx_cpts[4, 2], , drop = FALSE]
    x_delay <- x_centred[idx_cpts[6, 1]:idx_cpts[6, 2], , drop = FALSE]
  } else {
    # time-fixed predictor matrices, with quadrature
    # NB skip index 6 on purpose, since time fixed predictor matrix is
    # identical for lower and upper limits of interval censoring time
    x_epts_event <- x_centred[idx_cpts[1, 1]:idx_cpts[1, 2], , drop = FALSE]
    x_qpts_event <- x_centred[idx_cpts[2, 1]:idx_cpts[2, 2], , drop = FALSE]
    x_qpts_lcens <- x_centred[idx_cpts[3, 1]:idx_cpts[3, 2], , drop = FALSE]
    x_qpts_rcens <- x_centred[idx_cpts[4, 1]:idx_cpts[4, 2], , drop = FALSE]
    x_qpts_icens <- x_centred[idx_cpts[5, 1]:idx_cpts[5, 2], , drop = FALSE]
    x_qpts_delay <- x_centred[idx_cpts[7, 1]:idx_cpts[7, 2], , drop = FALSE]
  }

  #----- time-varying predictor matrices

  if (has_tve) {

    # time-varying predictor matrix
    s_cpts     <- rstanarm:::make_s(formula, mf_cpts, xlevs = xlevs)
    smooth_map <- rstanarm:::get_smooth_name(s_cpts, type = "smooth_map")
    smooth_idx <- rstanarm:::get_idx_array(table(smooth_map))
    S          <- ncol(s_cpts) # number of tve coefficients

    # store some additional information in model formula
    # stating how many columns in the predictor matrix
    # each tve() term in the model formula corresponds to
    formula$tt_ncol <- attr(s_cpts, "tt_ncol")
    formula$tt_map  <- attr(s_cpts, "tt_map")

  } else {

    # dud entries if no tve() terms in model formula
    s_cpts          <- matrix(0, length(cpts), 0)
    smooth_idx      <- matrix(0, 0, 2)
    smooth_map      <- integer(0)
    S               <- 0L

    formula$tt_ncol <- integer(0)
    formula$tt_map  <- integer(0)
  }

  if (has_quadrature) {
    # time-varying predictor matrices, with quadrature
    s_epts_event <- s_cpts[idx_cpts[1, 1]:idx_cpts[1, 2], , drop = FALSE]
    s_qpts_event <- s_cpts[idx_cpts[2, 1]:idx_cpts[2, 2], , drop = FALSE]
    s_qpts_lcens <- s_cpts[idx_cpts[3, 1]:idx_cpts[3, 2], , drop = FALSE]
    s_qpts_rcens <- s_cpts[idx_cpts[4, 1]:idx_cpts[4, 2], , drop = FALSE]
    s_qpts_icenl <- s_cpts[idx_cpts[5, 1]:idx_cpts[5, 2], , drop = FALSE]
    s_qpts_icenu <- s_cpts[idx_cpts[6, 1]:idx_cpts[6, 2], , drop = FALSE]
    s_qpts_delay <- s_cpts[idx_cpts[7, 1]:idx_cpts[7, 2], , drop = FALSE]
  }

  #----- random effects predictor matrices

  has_bars <- as.logical(length(formula$bars))

  # use 'stan_glmer' approach
  if (has_bars) {
    group_unpadded <- lme4::mkReTrms(formula$bars, mf_cpts)
    group <- rstanarm:::pad_reTrms(Ztlist = group_unpadded$Ztlist,
                                   cnms   = group_unpadded$cnms,
                                   flist  = group_unpadded$flist)
    z_cpts <- group$Z
  } else {
    group  <- NULL
    z_cpts <- matrix(0, length(cpts), 0)
  }

  # rstan exports extract_sparse_parts(); qualify it so it resolves even when
  # rstan is not attached
  sparse_parts <- rstan::extract_sparse_parts

  if (!has_quadrature) {
    # random effects predictor matrices, without quadrature
    # NB skip index 5 on purpose, since the random effects predictor matrix
    # is identical for lower and upper limits of interval censoring time
    z_event <- z_cpts[idx_cpts[1, 1]:idx_cpts[1, 2], , drop = FALSE]
    z_lcens <- z_cpts[idx_cpts[2, 1]:idx_cpts[2, 2], , drop = FALSE]
    z_rcens <- z_cpts[idx_cpts[3, 1]:idx_cpts[3, 2], , drop = FALSE]
    z_icens <- z_cpts[idx_cpts[4, 1]:idx_cpts[4, 2], , drop = FALSE]
    z_delay <- z_cpts[idx_cpts[6, 1]:idx_cpts[6, 2], , drop = FALSE]

    parts_event <- sparse_parts(z_event)
    parts_lcens <- sparse_parts(z_lcens)
    parts_rcens <- sparse_parts(z_rcens)
    parts_icens <- sparse_parts(z_icens)
    parts_delay <- sparse_parts(z_delay)
  } else {
    # random effects predictor matrices, with quadrature
    # NB skip index 6 on purpose, since the random effects predictor matrix
    # is identical for lower and upper limits of interval censoring time
    z_epts_event <- z_cpts[idx_cpts[1, 1]:idx_cpts[1, 2], , drop = FALSE]
    z_qpts_event <- z_cpts[idx_cpts[2, 1]:idx_cpts[2, 2], , drop = FALSE]
    z_qpts_lcens <- z_cpts[idx_cpts[3, 1]:idx_cpts[3, 2], , drop = FALSE]
    z_qpts_rcens <- z_cpts[idx_cpts[4, 1]:idx_cpts[4, 2], , drop = FALSE]
    z_qpts_icens <- z_cpts[idx_cpts[5, 1]:idx_cpts[5, 2], , drop = FALSE]
    z_qpts_delay <- z_cpts[idx_cpts[7, 1]:idx_cpts[7, 2], , drop = FALSE]

    parts_epts_event <- sparse_parts(z_epts_event)
    parts_qpts_event <- sparse_parts(z_qpts_event)
    parts_qpts_lcens <- sparse_parts(z_qpts_lcens)
    parts_qpts_rcens <- sparse_parts(z_qpts_rcens)
    parts_qpts_icens <- sparse_parts(z_qpts_icens)
    parts_qpts_delay <- sparse_parts(z_qpts_delay)
  }

  #----- stan data

  standata <- rstanarm:::nlist(
    K, S,
    nvars,
    x_bar = rstanarm:::aa(colMeans(x)),
    has_intercept,
    has_quadrature,
    smooth_map,
    smooth_idx,
    type = basehaz$type,
    log_crude_event_rate =
      if (is_aft) -log_crude_event_rate else log_crude_event_rate,

    nevent       = if (has_quadrature) 0L else nevent,
    nlcens       = if (has_quadrature) 0L else nlcens,
    nrcens       = if (has_quadrature) 0L else nrcens,
    nicens       = if (has_quadrature) 0L else nicens,
    ndelay       = if (has_quadrature) 0L else ndelay,

    t_event      = if (has_quadrature) rep(0, 0) else t_event,
    t_lcens      = if (has_quadrature) rep(0, 0) else t_lcens,
    t_rcens      = if (has_quadrature) rep(0, 0) else t_rcens,
    t_icenl      = if (has_quadrature) rep(0, 0) else t_icenl,
    t_icenu      = if (has_quadrature) rep(0, 0) else t_icenu,
    t_delay      = if (has_quadrature) rep(0, 0) else t_delay,

    x_event      = if (has_quadrature) matrix(0, 0, K) else x_event,
    x_lcens      = if (has_quadrature) matrix(0, 0, K) else x_lcens,
    x_rcens      = if (has_quadrature) matrix(0, 0, K) else x_rcens,
    x_icens      = if (has_quadrature) matrix(0, 0, K) else x_icens,
    x_delay      = if (has_quadrature) matrix(0, 0, K) else x_delay,

    w_event      = if (has_quadrature || !has_bars || nevent == 0) double(0) else parts_event$w,
    w_lcens      = if (has_quadrature || !has_bars || nlcens == 0) double(0) else parts_lcens$w,
    w_rcens      = if (has_quadrature || !has_bars || nrcens == 0) double(0) else parts_rcens$w,
    w_icens      = if (has_quadrature || !has_bars || nicens == 0) double(0) else parts_icens$w,
    w_delay      = if (has_quadrature || !has_bars || ndelay == 0) double(0) else parts_delay$w,

    v_event      = if (has_quadrature || !has_bars || nevent == 0) integer(0) else parts_event$v - 1L,
    v_lcens      = if (has_quadrature || !has_bars || nlcens == 0) integer(0) else parts_lcens$v - 1L,
    v_rcens      = if (has_quadrature || !has_bars || nrcens == 0) integer(0) else parts_rcens$v - 1L,
    v_icens      = if (has_quadrature || !has_bars || nicens == 0) integer(0) else parts_icens$v - 1L,
    v_delay      = if (has_quadrature || !has_bars || ndelay == 0) integer(0) else parts_delay$v - 1L,

    u_event      = if (has_quadrature || !has_bars || nevent == 0) integer(0) else parts_event$u - 1L,
    u_lcens      = if (has_quadrature || !has_bars || nlcens == 0) integer(0) else parts_lcens$u - 1L,
    u_rcens      = if (has_quadrature || !has_bars || nrcens == 0) integer(0) else parts_rcens$u - 1L,
    u_icens      = if (has_quadrature || !has_bars || nicens == 0) integer(0) else parts_icens$u - 1L,
    u_delay      = if (has_quadrature || !has_bars || ndelay == 0) integer(0) else parts_delay$u - 1L,

    nnz_event    = if (has_quadrature || !has_bars || nevent == 0) 0L else length(parts_event$w),
    nnz_lcens    = if (has_quadrature || !has_bars || nlcens == 0) 0L else length(parts_lcens$w),
    nnz_rcens    = if (has_quadrature || !has_bars || nrcens == 0) 0L else length(parts_rcens$w),
    nnz_icens    = if (has_quadrature || !has_bars || nicens == 0) 0L else length(parts_icens$w),
    nnz_delay    = if (has_quadrature || !has_bars || ndelay == 0) 0L else length(parts_delay$w),

    basis_event  = if (has_quadrature) matrix(0, 0, nvars) else basis_event,
    ibasis_event = if (has_quadrature) matrix(0, 0, nvars) else ibasis_event,
    ibasis_lcens = if (has_quadrature) matrix(0, 0, nvars) else ibasis_lcens,
    ibasis_rcens = if (has_quadrature) matrix(0, 0, nvars) else ibasis_rcens,
    ibasis_icenl = if (has_quadrature) matrix(0, 0, nvars) else ibasis_icenl,
    ibasis_icenu = if (has_quadrature) matrix(0, 0, nvars) else ibasis_icenu,
    ibasis_delay = if (has_quadrature) matrix(0, 0, nvars) else ibasis_delay,

    qnodes       = if (!has_quadrature) 0L else qnodes,

    Nevent       = if (!has_quadrature) 0L else nevent,
    Nlcens       = if (!has_quadrature) 0L else nlcens,
    Nrcens       = if (!has_quadrature) 0L else nrcens,
    Nicens       = if (!has_quadrature) 0L else nicens,
    Ndelay       = if (!has_quadrature) 0L else ndelay,

    qevent       = if (!has_quadrature) 0L else qevent,
    qlcens       = if (!has_quadrature) 0L else qlcens,
    qrcens       = if (!has_quadrature) 0L else qrcens,
    qicens       = if (!has_quadrature) 0L else qicens,
    qdelay       = if (!has_quadrature) 0L else qdelay,

    epts_event   = if (!has_quadrature) rep(0, 0) else t_event,
    qpts_event   = if (!has_quadrature) rep(0, 0) else qpts_event,
    qpts_lcens   = if (!has_quadrature) rep(0, 0) else qpts_lcens,
    qpts_rcens   = if (!has_quadrature) rep(0, 0) else qpts_rcens,
    qpts_icenl   = if (!has_quadrature) rep(0, 0) else qpts_icenl,
    qpts_icenu   = if (!has_quadrature) rep(0, 0) else qpts_icenu,
    qpts_delay   = if (!has_quadrature) rep(0, 0) else qpts_delay,

    qwts_event   = if (!has_quadrature) rep(0, 0) else qwts_event,
    qwts_lcens   = if (!has_quadrature) rep(0, 0) else qwts_lcens,
    qwts_rcens   = if (!has_quadrature) rep(0, 0) else qwts_rcens,
    qwts_icenl   = if (!has_quadrature) rep(0, 0) else qwts_icenl,
    qwts_icenu   = if (!has_quadrature) rep(0, 0) else qwts_icenu,
    qwts_delay   = if (!has_quadrature) rep(0, 0) else qwts_delay,

    x_epts_event = if (!has_quadrature) matrix(0, 0, K) else x_epts_event,
    x_qpts_event = if (!has_quadrature) matrix(0, 0, K) else x_qpts_event,
    x_qpts_lcens = if (!has_quadrature) matrix(0, 0, K) else x_qpts_lcens,
    x_qpts_rcens = if (!has_quadrature) matrix(0, 0, K) else x_qpts_rcens,
    x_qpts_icens = if (!has_quadrature) matrix(0, 0, K) else x_qpts_icens,
    x_qpts_delay = if (!has_quadrature) matrix(0, 0, K) else x_qpts_delay,

    s_epts_event = if (!has_quadrature) matrix(0, 0, S) else s_epts_event,
    s_qpts_event = if (!has_quadrature) matrix(0, 0, S) else s_qpts_event,
    s_qpts_lcens = if (!has_quadrature) matrix(0, 0, S) else s_qpts_lcens,
    s_qpts_rcens = if (!has_quadrature) matrix(0, 0, S) else s_qpts_rcens,
    s_qpts_icenl = if (!has_quadrature) matrix(0, 0, S) else s_qpts_icenl,
    s_qpts_icenu = if (!has_quadrature) matrix(0, 0, S) else s_qpts_icenu,
    s_qpts_delay = if (!has_quadrature) matrix(0, 0, S) else s_qpts_delay,

    w_epts_event = if (!has_quadrature || !has_bars || qevent == 0) double(0) else parts_epts_event$w,
    w_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) double(0) else parts_qpts_event$w,
    w_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) double(0) else parts_qpts_lcens$w,
    w_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) double(0) else parts_qpts_rcens$w,
    w_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) double(0) else parts_qpts_icens$w,
    w_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) double(0) else parts_qpts_delay$w,

    v_epts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_epts_event$v - 1L,
    v_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_qpts_event$v - 1L,
    v_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) integer(0) else parts_qpts_lcens$v - 1L,
    v_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) integer(0) else parts_qpts_rcens$v - 1L,
    v_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) integer(0) else parts_qpts_icens$v - 1L,
    v_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) integer(0) else parts_qpts_delay$v - 1L,

    u_epts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_epts_event$u - 1L,
    u_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_qpts_event$u - 1L,
    u_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) integer(0) else parts_qpts_lcens$u - 1L,
    u_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) integer(0) else parts_qpts_rcens$u - 1L,
    u_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) integer(0) else parts_qpts_icens$u - 1L,
    u_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) integer(0) else parts_qpts_delay$u - 1L,

    nnz_epts_event = if (!has_quadrature || !has_bars || qevent == 0) 0L else length(parts_epts_event$w),
    nnz_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) 0L else length(parts_qpts_event$w),
    nnz_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) 0L else length(parts_qpts_lcens$w),
    nnz_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) 0L else length(parts_qpts_rcens$w),
    nnz_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) 0L else length(parts_qpts_icens$w),
    nnz_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) 0L else length(parts_qpts_delay$w),

    basis_epts_event = if (!has_quadrature) matrix(0, 0, nvars) else basis_epts_event,
    basis_qpts_event = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_event,
    basis_qpts_lcens = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_lcens,
    basis_qpts_rcens = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_rcens,
    basis_qpts_icenl = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_icenl,
    basis_qpts_icenu = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_icenu,
    basis_qpts_delay = if (!has_quadrature) matrix(0, 0, nvars) else basis_qpts_delay
  )

  #----- random-effects structure

  if (has_bars) {

    fl <- group$flist
    p  <- lengths(group$cnms)
    l  <- vapply(attr(fl, "assign"), function(i) nlevels(fl[[i]]), integer(1))

    standata$p <- as.array(p)        # num ranefs for each grouping factor
    standata$l <- as.array(l)        # num levels for each grouping factor
    standata$t <- length(l)          # num of grouping factors
    standata$q <- ncol(group$Z)      # p * l
    # TRUE when every grouping factor has only a random intercept
    standata$special_case <- all(vapply(
      group$cnms,
      function(nms) length(nms) == 1L && nms == "(Intercept)",
      logical(1)
    ))

  } else { # no random effects structure

    standata$p <- integer(0)
    standata$l <- integer(0)
    standata$t <- 0L
    standata$q <- 0L
    standata$special_case <- 0L

  }

  #----- priors and hyperparameters

  # valid priors
  ok_dists <- rstanarm:::nlist("normal",
                               student_t = "t",
                               "cauchy",
                               "hs",
                               "hs_plus",
                               "laplace",
                               "lasso") # disallow product normal
  ok_intercept_dists  <- ok_dists[1:3]
  ok_aux_dists        <- rstanarm:::get_ok_priors_for_aux(basehaz)
  ok_smooth_dists     <- c(ok_dists[1:3], "exponential")
  ok_covariance_dists <- c("decov")

  if (missing(prior_aux)) {
    prior_aux <- rstanarm:::get_default_prior_for_aux(basehaz)
  }

  # priors
  user_prior_stuff <- prior_stuff <-
    rstanarm:::handle_glm_prior(prior,
                                nvars = K,
                                default_scale = 2.5,
                                link = NULL,
                                ok_dists = ok_dists)

  user_prior_intercept_stuff <- prior_intercept_stuff <-
    rstanarm:::handle_glm_prior(prior_intercept,
                                nvars = 1,
                                default_scale = 20,
                                link = NULL,
                                ok_dists = ok_intercept_dists)

  user_prior_aux_stuff <- prior_aux_stuff <-
    rstanarm:::handle_glm_prior(prior_aux,
                                nvars = basehaz$nvars,
                                default_scale = rstanarm:::get_default_aux_scale(basehaz),
                                link = NULL,
                                ok_dists = ok_aux_dists)

  user_prior_smooth_stuff <- prior_smooth_stuff <-
    rstanarm:::handle_glm_prior(prior_smooth,
                                nvars = if (S) max(smooth_map) else 0,
                                default_scale = 1,
                                link = NULL,
                                ok_dists = ok_smooth_dists)

  # stop null priors when prior_PD is true
  if (prior_PD) {
    if (is.null(prior))
      stop("'prior' cannot be NULL if 'prior_PD' is TRUE.")
    if (is.null(prior_intercept) && has_intercept)
      stop("'prior_intercept' cannot be NULL if 'prior_PD' is TRUE.")
    if (is.null(prior_aux))
      stop("'prior_aux' cannot be NULL if 'prior_PD' is TRUE.")
    if (is.null(prior_smooth) && (S > 0))
      stop("'prior_smooth' cannot be NULL if 'prior_PD' is TRUE.")
  }

  # handle prior for random effects structure
  if (has_bars) {

    user_prior_b_stuff <- prior_b_stuff <-
      rstanarm:::handle_cov_prior(prior_covariance,
                                  cnms = group$cnms,
                                  ok_dists = ok_covariance_dists)

    if (is.null(prior_covariance))
      stop("'prior_covariance' cannot be NULL.")

  } else {
    user_prior_b_stuff <- NULL
    prior_b_stuff      <- NULL
    prior_covariance   <- NULL
  }

  # autoscaling of priors
  prior_stuff           <- rstanarm:::autoscale_prior(prior_stuff, predictors = x)
  prior_intercept_stuff <- rstanarm:::autoscale_prior(prior_intercept_stuff)
  prior_aux_stuff       <- rstanarm:::autoscale_prior(prior_aux_stuff)
  prior_smooth_stuff    <- rstanarm:::autoscale_prior(prior_smooth_stuff)

  # priors
  standata$prior_dist               <- prior_stuff$prior_dist
  standata$prior_dist_for_intercept <- prior_intercept_stuff$prior_dist
  standata$prior_dist_for_aux       <- prior_aux_stuff$prior_dist
  standata$prior_dist_for_smooth    <- prior_smooth_stuff$prior_dist
  standata$prior_dist_for_cov       <- prior_b_stuff$prior_dist

  # hyperparameters
  standata$prior_mean               <- prior_stuff$prior_mean
  standata$prior_scale              <- prior_stuff$prior_scale
  standata$prior_df                 <- prior_stuff$prior_df
  standata$prior_mean_for_intercept <- c(prior_intercept_stuff$prior_mean)
  standata$prior_scale_for_intercept <- c(prior_intercept_stuff$prior_scale)
  standata$prior_df_for_intercept   <- c(prior_intercept_stuff$prior_df)
  standata$prior_scale_for_aux      <- prior_aux_stuff$prior_scale
  standata$prior_df_for_aux         <- prior_aux_stuff$prior_df
  standata$prior_conc_for_aux       <- prior_aux_stuff$prior_concentration
  standata$prior_mean_for_smooth    <- prior_smooth_stuff$prior_mean
  standata$prior_scale_for_smooth   <- prior_smooth_stuff$prior_scale
  standata$prior_df_for_smooth      <- prior_smooth_stuff$prior_df
  standata$global_prior_scale       <- prior_stuff$global_prior_scale
  standata$global_prior_df          <- prior_stuff$global_prior_df
  standata$slab_df                  <- prior_stuff$slab_df
  standata$slab_scale               <- prior_stuff$slab_scale

  # hyperparameters for covariance
  if (has_bars) {
    standata$b_prior_shape      <- prior_b_stuff$prior_shape
    standata$b_prior_scale      <- prior_b_stuff$prior_scale
    standata$concentration      <- prior_b_stuff$prior_concentration
    standata$regularization     <- prior_b_stuff$prior_regularization
    standata$len_concentration  <- length(standata$concentration)
    standata$len_regularization <- length(standata$regularization)
    standata$len_theta_L        <- sum(choose(standata$p, 2), standata$p)
  } else { # no random effects structure
    standata$b_prior_shape      <- rep(0, 0)
    standata$b_prior_scale      <- rep(0, 0)
    standata$concentration      <- rep(0, 0)
    standata$regularization     <- rep(0, 0)
    standata$len_concentration  <- 0L
    standata$len_regularization <- 0L
    standata$len_theta_L        <- 0L
  }

  # any additional flags
  standata$prior_PD <- rstanarm:::ai(prior_PD)

  #---------------
  # Prior summary
  #---------------

  prior_info <- rstanarm:::summarize_jm_prior(
    user_priorEvent                     = user_prior_stuff,
    user_priorEvent_intercept           = user_prior_intercept_stuff,
    user_priorEvent_aux                 = user_prior_aux_stuff,
    adjusted_priorEvent_scale           = prior_stuff$prior_scale,
    adjusted_priorEvent_intercept_scale = prior_intercept_stuff$prior_scale,
    adjusted_priorEvent_aux_scale       = prior_aux_stuff$prior_scale,
    e_has_intercept                     = has_intercept,
    e_has_predictors                    = K > 0,
    basehaz                             = basehaz,
    user_prior_covariance               = prior_covariance,
    b_user_prior_stuff                  = user_prior_b_stuff,
    b_prior_stuff                       = prior_b_stuff
  )

  #-----------
  # Fit model
  #-----------

  # obtain stan model code
  stanfit <- rstanarm:::stanmodels$surv

  # specify parameters for stan to monitor
  stanpars <- c(if (standata$has_intercept) "alpha",
                if (standata$K)             "beta",
                if (standata$S)             "beta_tve",
                if (standata$S)             "smooth_sd",
                if (standata$nvars)         "aux",
                if (standata$t)             "b",
                if (standata$t)             "theta_L")

  # fit model using stan
  if (algorithm == "sampling") { # mcmc
    args <- rstanarm:::set_sampling_args(
      object           = stanfit,
      data             = standata,
      pars             = stanpars,
      prior            = prior,
      user_dots        = dots,
      user_adapt_delta = adapt_delta,
      show_messages    = FALSE)
    # Keep warmup draws (rstanarm issue #556). The value is set after
    # set_sampling_args() so it takes effect; a `save_warmup` supplied by the
    # caller through `...` is honoured instead of being overwritten.
    user_save_warmup <- dots[["save_warmup"]]
    args[["save_warmup"]] <- if (is.null(user_save_warmup)) TRUE else user_save_warmup
    stanfit <- do.call(rstan::sampling, args)
  } else { # meanfield or fullrank vb
    args <- rstanarm:::nlist(
      object = stanfit,
      data   = standata,
      pars   = stanpars,
      algorithm
    )
    args[names(dots)] <- dots
    stanfit <- do.call(rstan::vb, args)
  }
  rstanarm:::check_stanfit(stanfit)

  # replace 'theta_L' with the variance-covariance matrix
  if (has_bars) {
    stanfit <- rstanarm:::evaluate_Sigma(stanfit, group$cnms)
  }

  # define new parameter names
  nms_beta   <- colnames(x_cpts) # may be NULL
  nms_tve    <- rstanarm:::get_smooth_name(s_cpts, type = "smooth_coefs") # may be NULL
  nms_smooth <- rstanarm:::get_smooth_name(s_cpts, type = "smooth_sd")    # may be NULL
  nms_int    <- rstanarm:::get_int_name_basehaz(basehaz)
  nms_aux    <- rstanarm:::get_aux_name_basehaz(basehaz)
  nms_b      <- rstanarm:::get_b_names(group)       # may be NULL
  nms_vc     <- rstanarm:::get_varcov_names(group)  # may be NULL
  nms_all    <- c(nms_int,
                  nms_beta,
                  nms_tve,
                  nms_smooth,
                  nms_aux,
                  nms_b,
                  nms_vc,
                  "log-posterior")

  # substitute new parameter names into 'stanfit' object
  stanfit <- rstanarm:::replace_stanfit_nms(stanfit, nms_all)

  # return an object of class 'stansurv'
  fit <- rstanarm:::nlist(stanfit,
                          formula,
                          has_tve,
                          has_quadrature,
                          has_bars,
                          data,
                          model_frame      = mf,
                          terms            = mt,
                          xlevels          = xlevs,
                          x,
                          x_cpts,
                          s_cpts           = if (has_tve)  s_cpts else NULL,
                          z_cpts           = if (has_bars) z_cpts else NULL,
                          cnms             = if (has_bars) group_unpadded$cnms  else NULL,
                          flist            = if (has_bars) group_unpadded$flist else NULL,
                          t_beg,
                          t_end,
                          status,
                          event            = as.logical(status == 1),
                          delayed,
                          basehaz,
                          nobs             = nrow(mf),
                          nevents          = nevent,
                          nlcens,
                          nrcens,
                          nicens,
                          ncensor          = nlcens + nrcens + nicens,
                          ndelayed         = ndelay,
                          prior_info,
                          qnodes           = if (has_quadrature) qnodes else NULL,
                          algorithm,
                          stan_function    = "stan_surv",
                          rstanarm_version = utils::packageVersion("rstanarm"),
                          call             = match.call(expand.dots = TRUE))

  rstanarm:::stansurv(fit)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment