Last active
November 26, 2023 17:18
-
-
Save halflearned/bea4e5137c0c81fd18a75f682da466c8 to your computer and use it in GitHub Desktop.
Finding subgroups with different treatment effects via grf (K-fold cross-validation version --- time-consuming!)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Note this is mostly for didactic purposes.
# See https://bookdown.org/halflearned/ml-ci-tutorial/hte-i-binary-treatment.html.

# Packages used below: grf (forests), lmtest/sandwich (robust SEs).
library(grf)
library(lmtest)
library(sandwich)

# Read in data
data <- read.csv("https://docs.google.com/uc?id=1kSxrVci_EUcSr_Lg1JKk1l7Xd5I9zfRC&export=download")

# Treatment: does the gov't spend too much on "welfare" (1) or "assistance to the poor" (0)
treatment <- "w"

# Outcome: 1 for 'yes', 0 for 'no'
outcome <- "y"

# Additional covariates
covariates <- c("age", "polviews", "income", "educ", "marital", "sex")

# Keep a small subsample so the example runs quickly (didactic only).
# Guard with min() so this also works if the file has fewer than 1000 rows.
data <- data[seq_len(min(1000L, nrow(data))), ]
n <- nrow(data)

# Valid randomized data and observational data with unconfoundedness+overlap.
# Note: read the comments below carefully.
# In randomized settings, do not estimate forest.e and e.hat; use known assignment probs.
fmla <- formula(paste0("~ 0 + ", paste0(covariates, collapse = "+")))
X <- model.matrix(fmla, data)
W <- data[, treatment]
Y <- data[, outcome]

# Assign each observation to one of n.folds roughly equal, contiguous folds.
n.folds <- 5
indices <- split(seq(n), sort(seq(n) %% n.folds))

# Number of CATE quantile groups ("rankings") per fold.
num.rankings <- 5
# Cross-fitting: for each fold, fit nuisance models on the other folds and
# score/rank only the held-out observations.
res <- lapply(indices, function(idx) {
  # Fit the outcome model E[Y|X] on the training folds,
  # predict on the held-out fold.
  forest.m <- regression_forest(X[-idx, ], Y[-idx])
  m.hat <- predict(forest.m, X[idx, ])$predictions

  # COMMENT / UNCOMMENT AS NEEDED:
  # # If assignment probabilities are unknown and need to be estimated,
  # # e.g., observational setting with unconfoundedness+overlap:
  # forest.e <- regression_forest(X[-idx, ], W[-idx], num.trees = 1000)
  # e.hat <- predict(forest.e, X[idx, ])$predictions
  # # If probabilities are known,
  # # e.g., randomized setting with fixed assignment probabilities.
  e.hat <- rep(0.5, length(idx))

  # Estimate the CATE tau(X) for observations in the held-out fold.
  forest.tau <- causal_forest(X[-idx, ], Y[-idx], W[-idx], num.trees = 100)
  tau.hat <- predict(forest.tau, X[idx, ])$predictions

  # Estimate mu.hat(X, 1) and mu.hat(X, 0) for held-out observations.
  # Note: to understand this, read equations 6-8 in this vignette
  # https://grf-labs.github.io/grf/articles/muhats.html
  mu.hat.0 <- m.hat - e.hat * tau.hat        # E[Y|X,W=0] = E[Y|X] - e(X)*tau(X)
  mu.hat.1 <- m.hat + (1 - e.hat) * tau.hat  # E[Y|X,W=1] = E[Y|X] + (1 - e(X))*tau(X)

  # Doubly-robust AIPW scores for the treatment effect.
  aipw.scores <- (tau.hat
                  + W[idx] / e.hat * (Y[idx] - mu.hat.1)
                  - (1 - W[idx]) / (1 - e.hat) * (Y[idx] - mu.hat.0))

  # Rank observations on the held-out sample based on estimated CATE.
  # Subtle but important: we must rank the observations based only on the hold-out sample,
  # as opposed to computing all the predictions first and computing an overall ranking.
  # NOTE(review): cut() errors if quantile breaks are not unique (can happen with
  # many tied tau.hat values) — confirm this cannot occur for your data.
  tau.hat.quantiles <- quantile(tau.hat, probs = seq(0, 1, length.out = num.rankings + 1))
  ranking <- cut(tau.hat, tau.hat.quantiles, include.lowest = TRUE,
                 labels = paste0("Q", seq(num.rankings)))

  # Store per-fold results as one data.frame row set.
  data.frame(aipw.scores, tau.hat, ranking = factor(ranking),
             outcome = Y[idx], treatment = W[idx])
})
# Stack the per-fold results into a single data.frame.
res <- do.call(rbind, res)
# Average AIPW scores for the treatment effect within each ranking.
# Valid in randomized and observational settings with unconfoundedness+overlap.
# lm() with `~ 0 + ranking` gives one coefficient per ranking group
# (the within-group mean of the AIPW scores).
forest.ate <- lm(aipw.scores ~ 0 + ranking, data = res)
# Heteroskedasticity-robust (HC2) standard errors for the group means.
forest.ate <- coeftest(forest.ate, vcov = vcovHC(forest.ate, type = "HC2"))
# Keep only estimate and std. error columns; label the rows by method/ranking.
forest.ate <- data.frame("aipw", paste0("Q", seq(num.rankings)), forest.ate[, 1:2])
colnames(forest.ate) <- c("method", "ranking", "estimate", "std.err")
rownames(forest.ate) <- NULL
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment