Skip to content

Instantly share code, notes, and snippets.

@erikcs
Last active January 25, 2023 01:42
Show Gist options
  • Save erikcs/424b7ac84e7360d755760ca597d82d6e to your computer and use it in GitHub Desktop.
Save erikcs/424b7ac84e7360d755760ca597d82d6e to your computer and use it in GitHub Desktop.
# To use this file, add
#   source("best_linear_projection_overlap.R")
# to your script after
#   library(grf)
# so that the `best_linear_projection` defined below is the one that gets called.
# Best linear projection of the conditional average treatment effect, with an
# optional overlap-weighted target sample.
#
# Regresses the scores returned by `grf:::get_scores` on the covariates `A`
# (intercept only when `A` is NULL) with a weighted `lm`, then reports the
# coefficients via `lmtest::coeftest` using a `sandwich::vcovCL` covariance
# clustered on the forest's clusters (each unit is its own cluster when the
# forest has none).
#
# Arguments:
#   forest: a `causal_forest`, `causal_survival_forest` or
#     `instrumental_forest`; any other class raises an error.
#   A: optional vector or matrix of covariates to project onto. May have n rows
#     (n = NROW(forest$Y.orig)), one row per unit of the subset actually used,
#     or one row per unit of the requested `subset`.
#   subset: optional unit selection, normalised by `grf:::validate_subset`.
#   debiasing.weights: optional vector with the same length rules as the rows
#     of `A`; forwarded to `grf:::get_scores`.
#   compliance.score, num.trees.for.weights: forwarded unchanged to
#     `grf:::get_scores`.
#   vcov.type: `type` argument handed to `sandwich::vcovCL` (default "HC3").
#   target.sample: "all" weights units by the forest's observation weights;
#     "overlap" additionally multiplies by W.hat * (1 - W.hat) and drops units
#     whose overlap weight is not above machine epsilon. "overlap" is only
#     available for `causal_forest` and `causal_survival_forest`.
#
# Returns: the `lmtest::coeftest` result, with a "method" attribute describing
#   the estimate.
#
# Errors: unsupported forest type for the chosen option, a subset spanning at
#   most one cluster, or `A` / `debiasing.weights` of an unrecognised length.
best_linear_projection <- function(forest,
                                   A = NULL,
                                   subset = NULL,
                                   debiasing.weights = NULL,
                                   compliance.score = NULL,
                                   num.trees.for.weights = 500,
                                   vcov.type = "HC3",
                                   target.sample = c("all", "overlap")) {
  target.sample <- match.arg(target.sample)
  num.samples <- NROW(forest$Y.orig)

  clusters <- if (length(forest$clusters) > 0) {
    forest$clusters
  } else {
    seq_len(num.samples)
  }

  observation.weight <- grf:::observation_weights(forest)
  subset <- grf:::validate_subset(forest, subset)
  # Keep the subset as requested by the caller: the overlap option below may
  # drop units, and subset-length inputs must then be re-aligned.
  requested.subset <- subset
  subset.weights <- observation.weight[subset]

  if (target.sample == "overlap") {
    if (inherits(forest, c("causal_forest", "causal_survival_forest"))) {
      overlap.weights <- forest$W.hat * (1 - forest$W.hat)
      # Some overlap weights might be exactly zero; these are currently not
      # handled correctly in `sandwich`'s SE calculation, so drop these units.
      subset <- intersect(subset, which(overlap.weights > .Machine$double.eps))
      subset.weights <- observation.weight[subset] * overlap.weights[subset]
    } else {
      stop("option `target.sample=overlap` is not supported for this forest type.")
    }
  }
  # Position, within the requested subset, of every unit that is still in use.
  kept <- match(subset, requested.subset)

  subset.clusters <- clusters[subset]
  grf:::validate_sandwich(subset.weights)
  if (length(unique(subset.clusters)) <= 1) {
    stop("The specified subset must contain units from more than one cluster.")
  }

  if (!is.null(debiasing.weights)) {
    num.weights <- length(debiasing.weights)
    if (num.weights == num.samples) {
      debiasing.weights <- debiasing.weights[subset]
    } else if (num.weights != length(subset)) {
      if (num.weights == length(requested.subset)) {
        # Supplied for the requested subset, part of which was dropped above.
        debiasing.weights <- debiasing.weights[kept]
      } else {
        stop("If specified, debiasing.weights must be a vector of length n or the subset length.")
      }
    }
  }

  binary.W <- all(forest$W.orig %in% c(0, 1))
  if (binary.W && target.sample != "overlap") {
    if (min(forest$W.hat[subset]) <= 0.01 || max(forest$W.hat[subset]) >= 0.99) {
      rng <- range(forest$W.hat[subset])
      warning(paste0(
        "Estimated treatment propensities take values between ",
        round(rng[1], 3), " and ", round(rng[2], 3),
        " and in particular get very close to 0 or 1.",
        " (using `target.sample=overlap`, or `subset` to filter data as in",
        " Crump, Hotz, Imbens, and Mitnik (Biometrika, 2009) may be helpful)"
      ), immediate. = TRUE)
    }
  }

  if (inherits(forest, c("causal_forest", "causal_survival_forest", "instrumental_forest"))) {
    DR.scores <- grf:::get_scores(forest,
      subset = subset, debiasing.weights = debiasing.weights,
      compliance.score = compliance.score, num.trees.for.weights = num.trees.for.weights
    )
  } else {
    stop(paste0(
      "`best_linear_projection` is only implemented for ",
      "`causal_forest`, `causal_survival_forest`, and `instrumental_forest`"
    ))
  }

  if (!is.null(A)) {
    if (is.vector(A)) {
      dim(A) <- c(length(A), 1L)
    }
    if (nrow(A) == num.samples) {
      A.subset <- A[subset, , drop = FALSE]
    } else if (nrow(A) == length(subset)) {
      A.subset <- A
    } else if (nrow(A) == length(requested.subset)) {
      # Supplied for the requested subset, part of which was dropped above.
      A.subset <- A[kept, , drop = FALSE]
    } else {
      stop("The number of rows of A does not match the number of training examples.")
    }
    if (is.null(colnames(A.subset))) {
      colnames(A.subset) <- paste0("A", seq_len(ncol(A.subset)))
    }
    DF <- data.frame(target = DR.scores, A.subset)
  } else {
    DF <- data.frame(target = DR.scores)
  }

  blp.ols <- lm(target ~ ., weights = subset.weights, data = DF)
  blp.summary <- lmtest::coeftest(blp.ols,
    vcov = sandwich::vcovCL,
    type = vcov.type,
    cluster = subset.clusters
  )
  attr(blp.summary, "method") <-
    paste0(
      "Best linear projection of the conditional average treatment effect.\n",
      "Confidence intervals are cluster- and heteroskedasticity-robust ",
      "(", vcov.type, ")"
    )

  blp.summary
}
if (FALSE) {
  # Usage example. The `if (FALSE)` guard keeps it from running when this file
  # is sourced; run the lines by hand to try it.
  library(grf)

  num.obs <- 1000
  num.covariates <- 5
  X <- matrix(rnorm(num.obs * num.covariates), nrow = num.obs, ncol = num.covariates)

  # Treatment probability is a steep logistic function of the first covariate.
  eX <- 1 / (1 + exp(-3 * X[, 1]))
  W <- rbinom(num.obs, 1, eX)

  # The simulated treatment effect is the second covariate.
  TAU <- X[, 2]
  Y <- X[, 1] + (W - 0.5) * TAU + rnorm(num.obs)

  # Regression of the simulated effect on X, weighted by eX * (1 - eX), kept as
  # a reference for the estimate printed below.
  true.blp <- lm(TAU ~ X, weights = eX * (1 - eX))

  fitted.forest <- causal_forest(X, Y, W)
  best_linear_projection(fitted.forest, X, target.sample = "overlap")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment