tmle_mlr3
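# Targeted maximum likelihood estimation (TMLE) of the marginal mean outcome
# under each level of a (possibly multi-valued) treatment, using
# `mlr3superlearner` for the outcome, treatment, and censoring regressions.
# Missing outcomes are handled with inverse-probability-of-censoring weights,
# and standard errors are computed from the estimated influence curves
# (clustered on `id` when supplied).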
tmle_mlr3 <- function(data,
                      trt,
                      covar_trt,
                      covar_outcome,
                      covar_cens,
                      outcome,
                      id = NULL,
                      g_learners = "glm",
                      Q_learners = "glm",
                      c_learners = "glm",
                      outcome_type = c("binomial", "continuous"),
                      .mlr3superlearner_folds = 10,
                      .trim = 0.99) {
  require("mlr3superlearner")
  require("data.table") # needed for as.data.table() and data.table subsetting below
  # Truncate (trim) extreme inverse-probability weights at the .trim quantile
  trim <- \(x, trim = .trim) pmin(x, quantile(x, trim))

  folds <- 1
  folded <- origami::make_folds(data, V = folds,
                                cluster_ids = {
                                  if (is.null(id)) 1:nrow(data)
                                  else data[[id]]
                                })

  # With a single fold, fit and evaluate on the full sample
  if (folds == 1) {
    folded[[1]]$training_set <- folded[[1]]$validation_set
  }

  # Indicator that the outcome is observed (not missing)
  obs <- !is.na(data[[outcome]])

  # Scale a continuous outcome to [0, 1] for the logistic fluctuation
  if (match.arg(outcome_type) == "continuous") {
    bounds <- c(min(data[[outcome]], na.rm = TRUE), max(data[[outcome]], na.rm = TRUE))
    data[[outcome]] <- (data[[outcome]] - bounds[1]) / (bounds[2] - bounds[1])
  }

  # Dummy-code the (possibly multi-valued) treatment, one column per level
  trt_factor <- as.data.table(model.matrix(~ factor(data[[trt]]) - 1))
  lvls <- levels(factor(data[[trt]]))
  names(trt_factor) <- paste0(trt, ".", lvls)
  for (i in 1:folds) {
    train <- as.data.table(data[folded[[i]]$training_set, ])
    train_trt_factor <- as.data.table(trt_factor[folded[[i]]$training_set, ])
    valid <- as.data.table(data[folded[[i]]$validation_set, ])
    valid_trt_factor <- as.data.table(trt_factor[folded[[i]]$validation_set, ])

    # Validation-set treatment matrices: "A" keeps the observed treatment;
    # each named level sets every unit to that level (counterfactual data sets)
    valids <- vector("list", length(lvls) + 1)
    names(valids) <- c("A", lvls)
    valids[["A"]] <- valid_trt_factor
    for (lvl in lvls) {
      current <- paste0(trt, ".", lvl)
      other <- setdiff(paste0(trt, ".", lvls), current)
      valids[[lvl]] <- data.table::copy(valid_trt_factor)
      valids[[lvl]][[current]] <- 1
      valids[[lvl]][, (other) := lapply(.SD, function(x) rep(0, length(x))), .SDcols = other]
    }
    # Outcome regression (Q): fit on observed outcomes, then predict under the
    # observed treatment and under each counterfactual treatment level
    Qs <- mlr3superlearner::mlr3superlearner(
      data = cbind(train[, c(..covar_outcome, ..outcome, ..id)], train_trt_factor)[obs, ],
      target = outcome,
      library = Q_learners,
      outcome_type = match.arg(outcome_type),
      folds = .mlr3superlearner_folds,
      newdata = lapply(valids, function(x) cbind(valid[, c(..covar_outcome, ..id)], x)),
      group = id
    )$preds
    # Inverse-probability-of-censoring weights; both default to 1 when the
    # outcome is fully observed
    Cs <- matrix(nrow = nrow(data), ncol = 1, data = 1)
    prob_observed <- matrix(nrow = nrow(data), ncol = 1, data = 1)
    if (!all(obs)) {
      prob_observed[folded[[i]]$validation_set, 1] <-
        mlr3superlearner::mlr3superlearner(
          data = cbind(train[, c(..covar_cens, ..id)],
                       data.table::data.table(tmp_cens = as.numeric(obs))[folded[[i]]$training_set, ]),
          target = "tmp_cens",
          library = c_learners,
          outcome_type = "binomial",
          folds = .mlr3superlearner_folds,
          newdata = list(valid[, c(..covar_cens, ..id)]),
          group = id
        )$preds[[1]]

      Cs[folded[[i]]$validation_set, 1] <-
        as.numeric(obs[folded[[i]]$validation_set]) / prob_observed[folded[[i]]$validation_set, 1]
    }
    # Propensity scores (g) and clever covariates (H) for each treatment level;
    # the "A" column of H corresponds to the observed treatment
    g <- matrix(nrow = nrow(data), ncol = length(lvls))
    Hs <- matrix(nrow = nrow(data), ncol = length(lvls) + 1)
    colnames(Hs) <- c("A", lvls)
    colnames(g) <- c(lvls)

    for (lvl in lvls[1:(length(lvls) - 1)]) {
      target <- paste0(trt, ".", lvl)
      g[folded[[i]]$validation_set, lvl] <-
        mlr3superlearner::mlr3superlearner(
          data = cbind(train[, c(..covar_trt, ..id)], train_trt_factor[, ..target]),
          target = target,
          library = g_learners,
          outcome_type = "binomial",
          folds = .mlr3superlearner_folds,
          newdata = list(valid[, c(..covar_trt, ..id)]),
          group = id
        )$preds[[1]]

      Hs[folded[[i]]$validation_set, lvl] <-
        trim(valid_trt_factor[[target]] / g[folded[[i]]$validation_set, lvl])
    }

    # Probabilities for the last level are one minus the sum of the others
    g[folded[[i]]$validation_set, lvls[length(lvls)]] <- 1 - rowSums(g, na.rm = TRUE)
    Hs[folded[[i]]$validation_set, lvls[length(lvls)]] <-
      trim(valid_trt_factor[[paste0(trt, ".", lvls[length(lvls)])]] /
             g[folded[[i]]$validation_set, lvls[length(lvls)]])

    # Clever covariate evaluated at each unit's observed treatment
    Hs[, "A"] <- purrr::imap_dbl(data[[trt]], function(x, i) Hs[i, as.character(x)])
    # Targeting step: fluctuate the initial outcome regressions with a
    # one-parameter logistic submodel using the censoring-weighted clever covariate
    Qeps <- matrix(nrow = nrow(data), ncol = length(lvls) + 1)
    colnames(Qeps) <- c("A", lvls)

    tmle_data <- data.frame(y = train[[outcome]],
                            Q_A = Qs[["A"]],
                            H_A = Hs[, "A"] * Cs[, 1])

    fluc <- glm(y ~ -1 + offset(qlogis(Q_A)) + H_A, data = tmle_data[obs, ], family = binomial)
    eps <- coef(fluc)

    for (lvl in c("A", lvls)) {
      Qeps[folded[[i]]$validation_set, lvl] <- plogis(qlogis(Qs[[lvl]]) + eps * Hs[, lvl] * Cs[, 1])
    }
  }
  # Plug-in estimates and efficient influence curves (missing outcomes get a
  # placeholder value that never enters the EIC because Cs = 0 for those rows)
  y <- ifelse(is.na(data[[outcome]]), -999, data[[outcome]])
  psis <- apply(Qeps, 2, mean)
  eics <- lapply(c("A", lvls), function(lvl) Hs[, lvl] * Cs[, 1] * (y - Qeps[, lvl]) + Qeps[, lvl] - psis[lvl])
  names(eics) <- c("A", lvls)

  # Map a continuous outcome back to its original scale: the point estimates
  # shift and rescale, but the influence curves only rescale by the range
  # (adding the lower bound back would incorrectly shift the returned ICs)
  if (match.arg(outcome_type) == "continuous") {
    rescale_y_continuous <- function(scaled) {
      (scaled * (bounds[2] - bounds[1])) + bounds[1]
    }
    eics <- lapply(eics, function(x) x * (bounds[2] - bounds[1]))
    psis <- sapply(psis, rescale_y_continuous)
  }
  if (is.null(id)) {
    ids <- 1:nrow(data)
  } else {
    ids <- data[[id]]
  }

  # Influence-curve-based (cluster-robust) standard errors
  ses <- lapply(c("A", lvls), function(a) {
    clusters <- split(eics[[a]], ids)
    j <- length(clusters)
    sqrt(var(vapply(clusters, function(x) mean(x), 1)) / j)
  })
  names(ses) <- c("A", lvls)

  list(psi = psis,
       std.error = ses,
       ic = eics,
       g = g,
       prob_observed = prob_observed)
}
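# ------------------------------------------------------------------------------
# Example usage (a hypothetical sketch, not part of the original gist): the data
# set `dat` and the column names "A", "W1", "W2", and "Y" are illustrative
# placeholders, and the learner libraries are left at their "glm" defaults.
#
# fit <- tmle_mlr3(
#   data = dat,
#   trt = "A",
#   covar_trt = c("W1", "W2"),
#   covar_outcome = c("W1", "W2"),
#   covar_cens = c("W1", "W2"),
#   outcome = "Y",
#   outcome_type = "binomial"
# )
#
# fit$psi        # TMLE estimates for the observed treatment ("A") and each level of "A"
# fit$std.error  # influence-curve-based standard errors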