@herbps10
Created May 6, 2020 17:48
Optimal treatment effects with resource constraints
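#
# Overview: this script simulates a randomized-trial-like data generating process and
# estimates the value of an individualized treatment rule when at most a fraction kappa
# of the population can be treated. Three estimators of the rule's value are compared
# (plug-in, one-step, and TMLE), looking at absolute bias and 95% confidence interval
# coverage across several sample sizes.
#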
library(tidyverse)
library(SuperLearner)
#
# Data generating process
#
generate_data <- function(seed, N) {
  set.seed(seed)
  tibble(
    W1 = rnorm(N),
    W2 = rnorm(N),
    W3 = rnorm(N),
    W4 = rnorm(N),
    A = rbinom(N, 1, 0.5),
    H = rbinom(N, 1, 0.5),
    Y = ifelse(H == 0,
      rbinom(N, 1, boot::inv.logit(1 - W1^2 + 3*W2 + A * (5 * W3^2 - 4.45))),
      rbinom(N, 1, boot::inv.logit(-0.5 - W3 + 2 * W1 * W2 + A * (3 * abs(W2) - 1.5)))
    )
  )
}
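# Note: H acts as a latent stratum. It is simulated above but never passed to the
# estimators below, which only see W1-W4 and A.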
#
# Plugin, one-step, and TMLE estimators
#
estimators <- function(data, kappa, SL.library) {
  # Outcome regression E[Y | A, W], plus its predictions under A = 1 and A = 0
  mu_model <- SuperLearner(data$Y, select(data, A, W1, W2, W3, W4), SL.library = SL.library, family = "binomial")
  mu_n_a_w <- predict(mu_model)$pred
  mu_n_1 <- predict(mu_model, newdata = select(data, W1, W2, W3, W4) %>% mutate(A = 1))$pred
  mu_n_0 <- predict(mu_model, newdata = select(data, W1, W2, W3, W4) %>% mutate(A = 0))$pred
  # Pseudo-outcome regression used to rank individuals by estimated treatment benefit
  # (treatment is randomized with P(A = 1) = 0.5)
  Y_tilde <- (2*data$A - 1) / (0.5) * (data$Y - mean(data$Y)) + mean(data$Y)
  qbar_model <- SuperLearner(Y_tilde, select(data, W1, W2, W3, W4), SL.library = SL.library, family = "gaussian")
  qbar_n <- predict(qbar_model)$pred
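  # Resource-constrained rule: treat only when the estimated benefit exceeds tau_n, the
  # (1 - kappa) empirical quantile floored at zero, so at most a kappa fraction is treated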
  tau_n <- max(0, quantile(qbar_n, 1 - kappa))
  d_n <- as.numeric(qbar_n > tau_n)
  #
  # Plug-in estimator
  #
  newdata <- select(data, starts_with("W")) %>%
    mutate(A = d_n)
  psi_plugin <- mean(predict(mu_model, newdata)$pred)
  #
  # One-step estimator
  #
  # Estimated influence function of the rule's value, using the known propensity 0.5
  if_n <- as.numeric(data$A == d_n) / 0.5 * (data$Y - mu_n_a_w) +
    mu_n_a_w - tau_n * (mean(d_n) - kappa) - psi_plugin
  # One-step correction: add the empirical mean of the influence function to the plug-in
  # estimator, with a Wald-type 95% CI based on the influence function's standard error
  psi_onestep <- psi_plugin + mean(if_n)
  psi_onestep_ci <- psi_onestep + c(-1, 1) * qnorm(0.975) * sd(if_n) / sqrt(length(if_n))
  #
  # TMLE
  #
  # Clever covariates: indicator of following the rule, divided by the known treatment
  # probability 0.5, under the observed treatment and under A = 0 and A = 1
  H_a_w <- as.numeric(data$A == d_n) / 0.5
  H_0 <- as.numeric(0 == d_n) / 0.5
  H_1 <- as.numeric(1 == d_n) / 0.5

  # Targeting step: logistic fluctuation of the outcome regression along the clever covariate
  offset <- boot::logit(mu_n_a_w)
  logistic_regression <- glm(data$Y ~ -1 + H_a_w + offset(offset), family = binomial(link = "logit"))
  epsilon_n <- coef(logistic_regression)[1]
  mu_n_a_w_epsilon <- boot::inv.logit(boot::logit(mu_n_a_w) + epsilon_n * H_a_w)
  mu_n_0_epsilon <- boot::inv.logit(boot::logit(mu_n_0) + epsilon_n * H_0)
  mu_n_1_epsilon <- boot::inv.logit(boot::logit(mu_n_1) + epsilon_n * H_1)

  # TMLE point estimate: mean of the targeted predictions under the estimated rule d_n
  psi_tmle <- mean(mu_n_0_epsilon * as.numeric(d_n == 0) + mu_n_1_epsilon * as.numeric(d_n == 1))

  # Influence function at the targeted fit and the corresponding Wald-type 95% CI
  if_n_epsilon <- as.numeric(data$A == d_n) / 0.5 * (data$Y - mu_n_a_w_epsilon) +
    mu_n_a_w_epsilon - tau_n * (mean(d_n) - kappa) - psi_tmle
  psi_tmle_ci <- psi_tmle + c(-1, 1) * qnorm(0.975) * sd(if_n_epsilon) / sqrt(length(if_n_epsilon))
  tribble(
    ~method, ~psi, ~ci,
    "plugin", psi_plugin, NULL,
    "onestep", psi_onestep, psi_onestep_ci,
    "tmle", psi_tmle, psi_tmle_ci
  )
}
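# Illustrative single run (not part of the original simulation study; the budget kappa and
# the small library below are arbitrary choices):
# estimators(generate_data(seed = 1, N = 1000), kappa = 0.1, SL.library = c("SL.glm", "SL.mean"))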
#
# Execute simulation study
#
SL.library <- c("SL.glm", "SL.glm.interaction", "SL.step", "SL.mean", "SL.step.interaction", "SL.step.forward")
simulations <- 250
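# NOTE: kappa (the maximum fraction of the population that may be treated) is not defined
# anywhere in this script; a value is assumed here so that it runs. psi0 below is treated
# as the true value of the constrained rule and should correspond to the kappa used.
kappa <- 0.1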
psi0 <- 0.49
Ns <- c(50, 500, 1000, 5000)
pb <- progress::progress_bar$new(total = simulations * length(Ns))
simulation_study <- expand_grid(
  seed = 1:simulations,
  N = Ns
) %>%
  mutate(
    data = map2(seed, N, generate_data),
    psi = map(data, function(data) {
      pb$tick()
      estimators(data, kappa, SL.library)
    })
  )
# Calculate coverage
# Was the true value inside the CI? (the plug-in estimator has no CI, so return NA)
covered <- function(ci, x) if (is.null(ci)) NA else ci[1] <= x && ci[2] >= x
simulation_study_coverage <- simulation_study %>%
  select(-data) %>%
  unnest(psi) %>%
  mutate(covered = map_lgl(ci, covered, psi0)) %>%
  filter(!is.na(covered)) %>%
  group_by(N, method) %>%
  summarize(coverage = mean(covered))
# Coverage Plot (Figure 1)
ggplot(simulation_study_coverage, aes(x = factor(N), y = coverage, color = method)) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0.95, lty = 2) +
  labs(x = "N", y = "95% CI coverage") +
  cowplot::theme_cowplot()
# Mean absolute bias and standard errors (Table 1)
simulation_study %>%
  select(-data) %>%
  unnest(psi) %>%
  group_by(N, method) %>%
  summarize(abs_bias = mean(abs(psi0 - psi)),
            se = sd(psi)) %>%
  mutate_at(vars(abs_bias, se), signif, 2) %>%
  knitr::kable(format = "latex")