Skip to content

Instantly share code, notes, and snippets.

@erikcs
Created April 23, 2024 06:09
Show Gist options
  • Save erikcs/cb8325fefe8bdfad6fc230015ddfe9cb to your computer and use it in GitHub Desktop.
Save erikcs/cb8325fefe8bdfad6fc230015ddfe9cb to your computer and use it in GitHub Desktop.
#' Causal survival forest with replaceable survival nuisance estimators
#'
#' Trains a forest of class `causal_survival_forest` through grf's internal
#' helpers (`grf:::`). The survival-based nuisance components (`Y.hat`, `S.hat`,
#' `C.hat`) are computed inline so they can be swapped for custom estimators;
#' the `TODO bcjaeger` notes mark the replacement points.
#'
#' @param X Covariates. Missing values are allowed (`validate_X(allow.na = TRUE)`).
#' @param Y Observed follow-up times; must be non-negative.
#' @param W Treatment assignment; must be binary (0/1), otherwise an error is
#'   raised before any forest is trained.
#' @param D Event indicator, 0 (censored) or 1 (event observed). At least one
#'   observation must have `D = 1`.
#' @param W.hat Propensity estimates. `NULL` fits a `grf::regression_forest`
#'   and uses its out-of-bag predictions; a scalar is recycled to `nrow(X)`;
#'   otherwise the length must equal `nrow(X)`.
#' @param target Estimand: `"RMST"` (uses f(T) = min(T, horizon)) or
#'   `"survival.probability"` (uses f(T) = 1{T > horizon}).
#' @param horizon Required numeric scalar defining the estimand.
#' @param failure.times Optional time grid for the survival curves; must start
#'   on or before `min(Y)`. Defaults to the sorted unique values of `Y` (after
#'   truncation at `horizon` when `target = "RMST"`).
#' @param num.trees Trees in the final forest. The propensity forest uses
#'   `max(50, num.trees / 4)` and the survival nuisance forests use
#'   `max(50, min(num.trees / 4, 500))`.
#' @param sample.weights,clusters,equalize.cluster.weights Validated with grf's
#'   helpers and passed on to the nuisance forests and the final forest.
#' @param sample.fraction,mtry,alpha Passed to every forest trained here.
#' @param imbalance.penalty Passed to the propensity forest and the final forest.
#' @param min.node.size,honesty,honesty.fraction,honesty.prune.leaves,stabilize.splits,ci.group.size,compute.oob.predictions
#'   Passed to the final forest only; the nuisance forests use hard-coded values.
#' @param tune.parameters Passed to the propensity regression forest only.
#' @param num.threads Validated with `grf:::validate_num_threads` and passed to
#'   every forest.
#' @param seed Seed shared by every forest trained here.
#'
#' @return A list of class `c("causal_survival_forest", "grf")` holding the
#'   trained forest plus `seed`, `_psi`, `X.orig`, `Y.orig`, `W.orig`, `D.orig`,
#'   `Y.hat`, `W.hat`, `sample.weights`, `clusters`, `equalize.cluster.weights`,
#'   `has.missing.values`, `target` and `horizon`. `Y.orig` and `D.orig` are the
#'   horizon-truncated versions of `Y` and `D` used for training.
causal_survival_forest.custom <- function(
    X, Y, W, D,
    W.hat = NULL,
    target = c("RMST", "survival.probability"),
    horizon = NULL,
    failure.times = NULL,
    num.trees = 2000,
    sample.weights = NULL,
    clusters = NULL,
    equalize.cluster.weights = FALSE,
    sample.fraction = 0.5,
    mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
    min.node.size = 5,
    honesty = TRUE,
    honesty.fraction = 0.5,
    honesty.prune.leaves = TRUE,
    alpha = 0.05,
    imbalance.penalty = 0,
    stabilize.splits = TRUE,
    ci.group.size = 2,
    tune.parameters = "none",
    compute.oob.predictions = TRUE,
    num.threads = NULL,
    seed = runif(1, 0, .Machine$integer.max)) {
  # ---- Argument validation ----
  target <- match.arg(target)
  if (is.null(horizon) || !is.numeric(horizon) || length(horizon) != 1) {
    stop("The `horizon` argument defining the estimand is required.")
  }
  has.missing.values <- grf:::validate_X(X, allow.na = TRUE)
  grf:::validate_sample_weights(sample.weights, X)
  Y <- grf:::validate_observations(Y, X)
  W <- grf:::validate_observations(W, X)
  D <- grf:::validate_observations(D, X)
  clusters <- grf:::validate_clusters(clusters, X)
  samples.per.cluster <- grf:::validate_equalize_cluster_weights(
    equalize.cluster.weights, clusters, sample.weights
  )
  num.threads <- grf:::validate_num_threads(num.threads)
  if (any(Y < 0)) {
    stop("The event times must be non-negative.")
  }
  if (!all(D %in% c(0, 1))) {
    stop("The censor values can only be 0 or 1.")
  }
  if (sum(D) == 0) {
    stop("All observations are censored.")
  }

  # ---- Transform the outcome according to the estimand ----
  if (target == "RMST") {
    # f(T) <- min(T, horizon)
    # D must be updated before Y, since the mask is computed from the original Y.
    D[Y >= horizon] <- 1
    Y[Y >= horizon] <- horizon
    fY <- Y
  } else {
    # f(T) <- 1{T > horizon}
    fY <- as.numeric(Y > horizon)
  }

  # ---- Time grid on which the survival curves are evaluated ----
  if (is.null(failure.times)) {
    Y.grid <- sort(unique(Y))
  } else if (min(Y) < min(failure.times)) {
    stop("If provided, `failure.times` should be a grid starting on or before min(Y).")
  } else {
    Y.grid <- failure.times
  }
  if (length(Y.grid) <= 2) {
    stop("The number of distinct event times should be more than 2.")
  }
  if (horizon < min(Y.grid)) {
    stop("`horizon` cannot be before the first event.")
  }
  if (nrow(X) > 5000 && length(Y.grid) / nrow(X) > 0.1) {
    warning(paste0("The number of events are more than 10% of the sample size. ",
                   "To reduce the computational burden of fitting survival and ",
                   "censoring curves, consider discretizing the event values `Y` or ",
                   "supplying a coarser grid with the `failure.times` argument. "), immediate. = TRUE)
  }

  # Fail fast: the m(X) identity used below needs a binary W, so reject anything
  # else before spending time on fitting the nuisance forests.
  # (The continuous-W path, which would fit a separate survival forest on X
  # alone to estimate E[f(T) | X], is ignored here for simplicity's sake.)
  if (!all(W %in% c(0, 1))) {
    stop("Custom survival models + continuous treatment not implemented")
  }

  # ---- Propensity scores e(X) = E[W | X] ----
  if (is.null(W.hat)) {
    forest.W <- grf::regression_forest(X, W, num.trees = max(50, num.trees / 4),
                                       sample.weights = sample.weights, clusters = clusters,
                                       equalize.cluster.weights = equalize.cluster.weights,
                                       sample.fraction = sample.fraction, mtry = mtry,
                                       min.node.size = 5, honesty = TRUE,
                                       honesty.fraction = 0.5, honesty.prune.leaves = TRUE,
                                       alpha = alpha, imbalance.penalty = imbalance.penalty,
                                       ci.group.size = 1, tune.parameters = tune.parameters,
                                       compute.oob.predictions = TRUE,
                                       num.threads = num.threads, seed = seed)
    W.hat <- predict(forest.W)$predictions
  } else if (length(W.hat) == 1) {
    W.hat <- rep(W.hat, nrow(X))
  } else if (length(W.hat) != nrow(X)) {
    stop("W.hat has incorrect length.")
  }
  W.centered <- W - W.hat

  # Settings shared by the survival forest and the censoring forest.
  args.nuisance <- list(failure.times = failure.times,
                        num.trees = max(50, min(num.trees / 4, 500)),
                        sample.weights = sample.weights,
                        clusters = clusters,
                        equalize.cluster.weights = equalize.cluster.weights,
                        sample.fraction = sample.fraction,
                        mtry = mtry,
                        min.node.size = 15,
                        honesty = TRUE,
                        honesty.fraction = 0.5,
                        honesty.prune.leaves = TRUE,
                        alpha = alpha,
                        prediction.type = "Nelson-Aalen", # to guarantee non-zero estimates.
                        compute.oob.predictions = TRUE,
                        num.threads = num.threads,
                        seed = seed)

  # ---- Survival-based nuisance components (https://arxiv.org/abs/2001.09887) ----
  # m(x) relies on the survival function conditional on only X, while Q(x)
  # relies on the conditioning (X, W). Instead of fitting two separate survival
  # forests, the forest fit on (X, W) is reused to compute m(X) via the identity
  #   E[f(T) | X] = e(X) E[f(T) | X, W = 1] + (1 - e(X)) E[f(T) | X, W = 0]
  # (for this to work W has to be binary).
  # TODO bcjaeger 1): you can comment out the below if you want to use another
  # survival function estimator. If it is easier, another way to do the below
  # would be to use two survival forests: first fit
  #   forest1 = survival_forest(X, Y, D)
  # and use it to estimate Y.hat (for RMST:
  # expected_survival(forest1.predictions, forest1.grid)), then fit
  #   forest2 = survival_forest(cbind(X, W), Y, D)
  # and use it to estimate S.hat (= predict(forest2)).
  sf.survival <- do.call(grf::survival_forest, c(list(X = cbind(X, W), Y = Y, D = D), args.nuisance))

  # Computing OOB estimates for modified training samples is not a workflow grf
  # has implemented, so it is done with a manual workaround: temporarily remove
  # the precomputed predictions, overwrite the treatment column (the last column
  # of X.orig), predict, and restore everything afterwards.
  .predictions <- sf.survival[["predictions"]]
  sf.survival[["predictions"]] <- NULL
  # The survival function conditioning on being treated, S(t, x, 1), estimated with an "S-learner".
  sf.survival[["X.orig"]][, ncol(X) + 1] <- rep(1, nrow(X))
  S1.hat <- predict(sf.survival, num.threads = num.threads)$predictions
  # The survival function conditioning on being a control unit, S(t, x, 0), estimated with an "S-learner".
  sf.survival[["X.orig"]][, ncol(X) + 1] <- rep(0, nrow(X))
  S0.hat <- predict(sf.survival, num.threads = num.threads)$predictions
  sf.survival[["X.orig"]][, ncol(X) + 1] <- W
  sf.survival[["predictions"]] <- .predictions

  if (target == "RMST") {
    # TODO bcjaeger 2): remove the above "sf.survival" usage and replace Y.hat
    # with your n-length estimates of m(X) (RMST)
    Y.hat <- W.hat * grf:::expected_survival(S1.hat, sf.survival$failure.times) +
      (1 - W.hat) * grf:::expected_survival(S0.hat, sf.survival$failure.times)
  } else {
    horizonS.index <- findInterval(horizon, sf.survival$failure.times)
    if (horizonS.index == 0) {
      # The horizon lies before the first grid point: the estimate is 1 for every unit.
      Y.hat <- rep(1, nrow(X))
    } else {
      # TODO bcjaeger 3): replace Y.hat with your n-length estimates of m(X)
      # (Survival probability)
      Y.hat <- W.hat * S1.hat[, horizonS.index] + (1 - W.hat) * S0.hat[, horizonS.index]
    }
  }

  # The conditional survival function S(t, x, w) used to construct Q(x).
  # TODO bcjaeger 4): replace S.hat with your predictions of the
  # n * length(Y.grid)-sized matrix of survival curve estimates on the time
  # grid "Y.grid".
  S.hat <- predict(sf.survival, failure.times = Y.grid)$predictions
  if (!identical(dim(S.hat), c(length(Y), length(Y.grid)))) stop("Wrong S.hat prediction dims")

  # The conditional survival function for the censoring process S_C(t, x, w),
  # obtained by flipping the event indicator.
  # TODO bcjaeger 5): replace C.hat with your estimates of the censoring
  # process matrix (same grid as above)
  sf.censor <- do.call(grf::survival_forest, c(list(X = cbind(X, W), Y = Y, D = 1 - D), args.nuisance))
  C.hat <- predict(sf.censor, failure.times = Y.grid)$predictions
  if (!identical(dim(C.hat), c(length(Y), length(Y.grid)))) stop("Wrong C.hat prediction dims")

  if (target == "survival.probability") {
    # Evaluate psi up to horizon
    D[Y > horizon] <- 1
    Y[Y > horizon] <- horizon
  }

  Y.index <- findInterval(Y, Y.grid) # (invariance: Y.index > 0)
  C.Y.hat <- C.hat[cbind(seq_along(Y.index), Y.index)] # Pick out P[Ci > Yi | Xi, Wi]

  # ---- Warn when the censoring survival probabilities get close to zero ----
  if (target == "RMST" && any(C.Y.hat <= 0.05)) {
    warning(paste("Estimated censoring probabilities go as low as:", round(min(C.Y.hat), 5),
                  "- an identifying assumption is that there exists a fixed positive constant M",
                  "such that the probability of observing an event past the maximum follow-up time ",
                  "is at least M (i.e. P(T > horizon | X) > M).",
                  "This warning appears when M is less than 0.05, at which point causal survival forest",
                  "can not be expected to deliver reliable estimates."), immediate. = TRUE)
  } else if (target == "RMST" && any(C.Y.hat < 0.2)) {
    warning(paste("Estimated censoring probabilities are lower than 0.2",
                  "- an identifying assumption is that there exists a fixed positive constant M",
                  "such that the probability of observing an event past the maximum follow-up time ",
                  "is at least M (i.e. P(T > horizon | X) > M)."))
  } else if (target == "survival.probability" && any(C.Y.hat <= 0.001)) {
    # A smaller horizon raises P[C > min(Y, horizon)], so that is the remedy
    # (consistent with the milder warning in the next branch).
    warning(paste("Estimated censoring probabilities go as low as:", round(min(C.Y.hat), 5),
                  "- forest estimates will likely be very unstable, a smaller target `horizon`",
                  "is recommended."), immediate. = TRUE)
  } else if (target == "survival.probability" && any(C.Y.hat < 0.05)) {
    warning(paste("Estimated censoring probabilities are lower than 0.05",
                  "and forest estimates may not be stable. Using a smaller target `horizon`",
                  "may help."))
  }

  # ---- Build the numerator/denominator inputs and train the final forest ----
  psi <- grf:::compute_psi(S.hat, C.hat, C.Y.hat, Y.hat, W.centered,
                           D, fY, Y.index, Y.grid, target, horizon)
  grf:::validate_observations(psi[["numerator"]], X)
  grf:::validate_observations(psi[["denominator"]], X)

  data <- grf:::create_train_matrices(X,
                                      treatment = W.centered,
                                      survival.numerator = psi[["numerator"]],
                                      survival.denominator = psi[["denominator"]],
                                      censor = D,
                                      sample.weights = sample.weights)
  args <- list(num.trees = num.trees,
               clusters = clusters,
               samples.per.cluster = samples.per.cluster,
               sample.fraction = sample.fraction,
               mtry = mtry,
               min.node.size = min.node.size,
               honesty = honesty,
               honesty.fraction = honesty.fraction,
               honesty.prune.leaves = honesty.prune.leaves,
               alpha = alpha,
               imbalance.penalty = imbalance.penalty,
               stabilize.splits = stabilize.splits,
               ci.group.size = ci.group.size,
               compute.oob.predictions = compute.oob.predictions,
               num.threads = num.threads,
               seed = seed)

  forest <- grf:::do.call.rcpp(grf:::causal_survival_train, c(data, args))
  class(forest) <- c("causal_survival_forest", "grf")

  # Attach the training data and nuisance estimates to the returned object.
  forest[["seed"]] <- seed
  forest[["_psi"]] <- psi
  forest[["X.orig"]] <- X
  forest[["Y.orig"]] <- Y
  forest[["W.orig"]] <- W
  forest[["D.orig"]] <- D
  forest[["Y.hat"]] <- Y.hat
  forest[["W.hat"]] <- W.hat
  forest[["sample.weights"]] <- sample.weights
  forest[["clusters"]] <- clusters
  forest[["equalize.cluster.weights"]] <- equalize.cluster.weights
  forest[["has.missing.values"]] <- has.missing.values
  forest[["target"]] <- target
  forest[["horizon"]] <- horizon

  forest
}
# Usage demo. Guarded by `if (FALSE)` so that sourcing this file never runs it;
# execute the statements interactively to compare grf's causal survival forest
# with the custom variant on the same simulated data.
if (FALSE) {
  horizon <- 1
  n.obs <- 500
  n.covariates <- 5

  # Simulated data: uniform covariates and a coin-flip binary treatment.
  X <- matrix(runif(n.obs * n.covariates), nrow = n.obs, ncol = n.covariates)
  W <- rbinom(n.obs, size = 1, prob = 0.5)

  # Event times are capped at the horizon; censoring times are 2 * U(0, 1).
  event.time <- pmin(rexp(n.obs) * X[, 1] + W, horizon)
  censoring.time <- 2 * runif(n.obs)

  # Observed follow-up time (rounded to 2 decimals) and event indicator.
  Y <- round(pmin(event.time, censoring.time), digits = 2)
  D <- as.integer(event.time <= censoring.time)

  # grf causal survival forest
  forest.grf <- grf::causal_survival_forest(X, Y, W, D, horizon = horizon, seed = 42)
  grf::average_treatment_effect(forest.grf)
  head(predict(forest.grf))

  # your custom CS forest
  forest.custom <- causal_survival_forest.custom(X, Y, W, D, horizon = horizon, seed = 42)
  grf::average_treatment_effect(forest.custom)
  head(predict(forest.custom))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment