Last active
July 9, 2024 18:16
-
-
Save nt-williams/702427d62cfe76f2efb69d240fdb3b2d to your computer and use it in GitHub Desktop.
tmle_mlr3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' TMLE of treatment-specific outcome means using mlr3superlearner
#'
#' For every level of a categorical treatment this fits an outcome regression,
#' treatment-probability models and (when outcomes are missing) a model for the
#' probability that the outcome is observed, all via
#' `mlr3superlearner::mlr3superlearner()`. The outcome predictions are then
#' fluctuated with a logistic regression whose covariate is the truncated
#' inverse-probability weight, and the column means of the fluctuated
#' predictions are returned as estimates together with influence-curve based
#' standard errors.
#'
#' Side effect: attaches the mlr3superlearner package to the search path.
#'
#' @param data Data frame holding all columns named below.
#' @param trt Name of the treatment column. It is converted with `factor()`;
#'   one estimate is produced per level.
#' @param covar_trt Character vector of columns used as predictors in the
#'   treatment models.
#' @param covar_outcome Character vector of columns used (together with the
#'   treatment indicators) as predictors in the outcome regression.
#' @param covar_cens Character vector of columns used as predictors in the
#'   model for whether the outcome is observed.
#' @param outcome Name of the outcome column. `NA` values are treated as
#'   unobserved outcomes.
#' @param id Optional name of a cluster identifier column. It is used when
#'   building folds, passed as `group` to `mlr3superlearner()`, and used to
#'   average the influence curve within clusters for the standard errors.
#'   When `NULL`, every row is its own cluster.
#' @param g_learners,Q_learners,c_learners Learner libraries passed as
#'   `library` to `mlr3superlearner()` for the treatment, outcome and
#'   outcome-observed models respectively.
#' @param outcome_type Either `"binomial"` or `"continuous"`. Continuous
#'   outcomes are rescaled to the unit interval using their observed minimum
#'   and maximum before fitting, and results are mapped back afterwards.
#' @param .mlr3superlearner_folds Passed as `folds` to `mlr3superlearner()`.
#' @param .trim Quantile (in `[0, 1]`) at which the inverse-probability
#'   weights are truncated from above.
#'
#' @return A list with elements
#'   * `psi`: named numeric vector of estimates; `"A"` is the column built from
#'     the observed treatment, followed by one entry per treatment level;
#'   * `std.error`: named list of standard errors, same names as `psi`;
#'   * `ic`: named list of influence-curve vectors, same names as `psi`;
#'   * `g`: matrix of estimated treatment probabilities, one column per level;
#'   * `prob_observed`: one-column matrix with the estimated probability that
#'     the outcome is observed (all ones when no outcome is missing).
tmle_mlr3 <- function(data,
                      trt,
                      covar_trt,
                      covar_outcome,
                      covar_cens,
                      outcome,
                      id = NULL,
                      g_learners = "glm",
                      Q_learners = "glm",
                      c_learners = "glm",
                      outcome_type = c("binomial", "continuous"),
                      .mlr3superlearner_folds = 10,
                      .trim = 0.99) {
  # library() instead of require(): a missing package must fail here, loudly,
  # rather than return FALSE and fail later with a less helpful message.
  library("mlr3superlearner")

  outcome_type <- match.arg(outcome_type)

  # Cap weights at their empirical `.trim` quantile.
  trim <- function(x, trim = .trim) pmin(x, quantile(x, trim))

  n <- nrow(data)
  ids <- if (is.null(id)) seq_len(n) else data[[id]]

  # NOTE(review): `folds` is fixed at 1. The per-fold matrices below are
  # re-created on every loop iteration and several expressions index with
  # full-data vectors (e.g. `obs`), so the body only supports a single fold in
  # which the training and validation sets are both the full data.
  folds <- 1
  folded <- origami::make_folds(data, V = folds, cluster_ids = ids)
  if (folds == 1) {
    folded[[1]]$training_set <- folded[[1]]$validation_set
  }

  obs <- !is.na(data[[outcome]])

  if (outcome_type == "continuous") {
    # Map the outcome onto [0, 1] so the logistic fluctuation below applies.
    bounds <- range(data[[outcome]], na.rm = TRUE)
    data[[outcome]] <- (data[[outcome]] - bounds[1]) / (bounds[2] - bounds[1])
  }

  # One indicator column per treatment level, named "<trt>.<level>".
  trt_levels <- factor(data[[trt]])
  lvls <- levels(trt_levels)
  trt_factor <- data.table::as.data.table(model.matrix(~ trt_levels - 1))
  names(trt_factor) <- paste0(trt, ".", lvls)

  for (i in seq_len(folds)) {
    train_rows <- folded[[i]]$training_set
    valid_rows <- folded[[i]]$validation_set

    train <- data.table::as.data.table(data[train_rows, ])
    train_trt_factor <- data.table::as.data.table(trt_factor[train_rows, ])
    valid <- data.table::as.data.table(data[valid_rows, ])
    valid_trt_factor <- data.table::as.data.table(trt_factor[valid_rows, ])

    # Prediction sets: "A" keeps the observed treatment indicators; each level
    # gets a copy in which everyone is set to that level.
    valids <- vector("list", length(lvls) + 1)
    names(valids) <- c("A", lvls)
    valids[["A"]] <- valid_trt_factor
    for (lvl in lvls) {
      current <- paste0(trt, ".", lvl)
      other <- setdiff(paste0(trt, ".", lvls), current)
      valids[[lvl]] <- data.table::copy(valid_trt_factor)
      valids[[lvl]][[current]] <- 1
      valids[[lvl]][, (other) := lapply(.SD, function(x) rep(0, length(x))), .SDcols = other]
    }

    # Outcome regression, fitted on rows with an observed outcome only.
    # `$preds` is used below as a list keyed like `newdata` ("A" and levels).
    Qs <- mlr3superlearner::mlr3superlearner(
      data = cbind(train[, c(..covar_outcome, ..outcome, ..id)], train_trt_factor)[obs, ],
      target = outcome,
      library = Q_learners,
      outcome_type = outcome_type,
      folds = .mlr3superlearner_folds,
      newdata = lapply(valids, function(x) cbind(valid[, c(..covar_outcome, ..id)], x)),
      group = id
    )$preds

    # Outcome-observed weights: 1 for everyone unless some outcome is missing,
    # otherwise (observed indicator) / (estimated probability of observation).
    Cs <- matrix(nrow = n, ncol = 1, data = 1)
    prob_observed <- matrix(nrow = n, ncol = 1, data = 1)
    if (!all(obs)) {
      prob_observed[valid_rows, 1] <-
        mlr3superlearner::mlr3superlearner(
          data = cbind(train[, c(..covar_cens, ..id)],
                       data.table::data.table(tmp_cens = as.numeric(obs))[train_rows, ]),
          target = "tmp_cens",
          library = c_learners,
          outcome_type = "binomial",
          folds = .mlr3superlearner_folds,
          newdata = list(valid[, c(..covar_cens, ..id)]),
          group = id
        )$preds[[1]]
      Cs[valid_rows, 1] <- as.numeric(obs[valid_rows]) / prob_observed[valid_rows, 1]
    }

    # Treatment probabilities `g` and truncated inverse-probability weights
    # `Hs`. A binomial model is fitted for every level but the last; the last
    # level's probability is one minus the sum of the others.
    g <- matrix(nrow = n, ncol = length(lvls))
    Hs <- matrix(nrow = n, ncol = length(lvls) + 1)
    colnames(Hs) <- c("A", lvls)
    colnames(g) <- c(lvls)

    # `lvls[-length(lvls)]` drops the last level and is empty for one level,
    # unlike `lvls[1:(length(lvls) - 1)]`, which would select the first level.
    for (lvl in lvls[-length(lvls)]) {
      target <- paste0(trt, ".", lvl)
      g[valid_rows, lvl] <-
        mlr3superlearner::mlr3superlearner(
          data = cbind(train[, c(..covar_trt, ..id)], train_trt_factor[, ..target]),
          target = target,
          library = g_learners,
          outcome_type = "binomial",
          folds = .mlr3superlearner_folds,
          newdata = list(valid[, c(..covar_trt, ..id)]),
          group = id
        )$preds[[1]]
      Hs[valid_rows, lvl] <- trim(valid_trt_factor[[target]] / g[valid_rows, lvl])
    }

    last_lvl <- lvls[length(lvls)]
    g[valid_rows, last_lvl] <- 1 - rowSums(g, na.rm = TRUE)
    Hs[valid_rows, last_lvl] <-
      trim(valid_trt_factor[[paste0(trt, ".", last_lvl)]] / g[valid_rows, last_lvl])

    # Column "A": each row's weight for the treatment level it actually
    # received, looked up with a (row, column) index matrix.
    observed_col <- match(as.character(data[[trt]]), colnames(Hs))
    Hs[, "A"] <- Hs[cbind(seq_len(n), observed_col)]

    # calculate tmle
    Qeps <- matrix(nrow = n, ncol = length(lvls) + 1)
    colnames(Qeps) <- c("A", lvls)
    tmle_data <- data.frame(y = train[[outcome]],
                            Q_A = Qs[["A"]],
                            H_A = Hs[, "A"] * Cs[, 1])
    # Intercept-free logistic fluctuation with the initial fit as offset.
    fluc <- glm(y ~ -1 + offset(qlogis(Q_A)) + H_A, data = tmle_data[obs, ], family = binomial)
    eps <- coef(fluc)
    # NOTE(review): `Hs[, lvl]` is zero for rows whose observed treatment is
    # not `lvl` (indicator / g), so those rows' predictions are left
    # un-fluctuated here -- confirm this is intended.
    for (lvl in c("A", lvls)) {
      Qeps[valid_rows, lvl] <- plogis(qlogis(Qs[[lvl]]) + eps * Hs[, lvl] * Cs[, 1])
    }
  }

  # Placeholder for missing outcomes; those rows have Cs == 0, so the value
  # never contributes to the influence curve.
  y <- ifelse(is.na(data[[outcome]]), -999, data[[outcome]])
  psis <- apply(Qeps, 2, mean)
  eics <- lapply(c("A", lvls), function(lvl) {
    Hs[, lvl] * Cs[, 1] * (y - Qeps[, lvl]) + Qeps[, lvl] - psis[[lvl]]
  })
  names(eics) <- c("A", lvls)

  if (outcome_type == "continuous") {
    outcome_range <- bounds[2] - bounds[1]
    # The influence curve is already centred at the estimate, so moving back
    # to the original scale multiplies it by the range only; adding
    # `bounds[1]` (as the estimate needs) would shift it away from mean zero.
    eics <- lapply(eics, function(eic) eic * outcome_range)
    psis <- psis * outcome_range + bounds[1]
  }

  # Standard errors from cluster-level means of the influence curve.
  ses <- lapply(c("A", lvls), function(a) {
    clusters <- split(eics[[a]], ids)
    j <- length(clusters)
    sqrt(var(vapply(clusters, mean, numeric(1))) / j)
  })
  names(ses) <- c("A", lvls)

  list(psi = psis,
       std.error = ses,
       ic = eics,
       g = g,
       prob_observed = prob_observed)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment