@halflearned
Last active November 26, 2023 17:18
Finding subgroups with different treatment effects via grf (K-fold cross-validation version --- time-consuming!)
# Note this is mostly for didactic purposes.
# See https://bookdown.org/halflearned/ml-ci-tutorial/hte-i-binary-treatment.html.
library(grf)       # regression_forest, causal_forest
library(lmtest)    # coeftest
library(sandwich)  # vcovHC
# Read in data
data <- read.csv("https://docs.google.com/uc?id=1kSxrVci_EUcSr_Lg1JKk1l7Xd5I9zfRC&export=download")
# Treatment: does the gov't spend too much on "welfare" (1) or "assistance to the poor" (0)?
treatment <- "w"
# Outcome: 1 for 'yes', 0 for 'no'
outcome <- "y"
# Additional covariates
covariates <- c("age", "polviews", "income", "educ", "marital", "sex")
data <- data[1:1000,]
n <- nrow(data)
# Valid for randomized data and for observational data under unconfoundedness + overlap.
# Note: read the comments below carefully.
# In randomized settings, do not estimate forest.e and e.hat; use known assignment probs.
fmla <- formula(paste0("~ 0 + ", paste0(covariates, collapse="+")))
X <- model.matrix(fmla, data)
W <- data[,treatment]
Y <- data[,outcome]
n.folds <- 5
indices <- split(seq(n), sort(seq(n) %% n.folds))
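# Sanity check (illustrative, not part of the original gist): the split
# above should partition the indices 1..n into n.folds roughly equal folds.
stopifnot(length(indices) == n.folds)
stopifnot(all(sort(unname(unlist(indices))) == seq(n)))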
num.rankings <- 5
res <- lapply(indices, function(idx) {
# Fit the outcome model on training subset,
# predict on held-out fold.
forest.m <- regression_forest(X[-idx,], Y[-idx])
m.hat <- predict(forest.m, X[idx,])$predictions
# COMMENT / UNCOMMENT AS NEEDED:
# # If assignment probabilities are unknown and need to be estimated,
# # e.g., observational setting with unconfoundedness+overlap:
# forest.e <- regression_forest(X[-idx,], W[-idx], num.trees=1000)
# e.hat <- predict(forest.e, X[idx,])$predictions
# # If probabilities are known,
# # e.g., randomized setting with fixed assignment probabilities.
e.hat <- rep(0.5, length(idx))
# Estimating CATE tau(X) for observations in held-out sample
forest.tau <- causal_forest(X[-idx,], Y[-idx], W[-idx], num.trees = 100)
tau.hat <- predict(forest.tau, X[idx,])$predictions
# Estimating mu.hat(X, 1) and mu.hat(X, 0) for obs in held-out sample
# Note: to understand this, read equations 6-8 in this vignette
# https://grf-labs.github.io/grf/articles/muhats.html
mu.hat.0 <- m.hat - e.hat * tau.hat # E[Y|X,W=0] = E[Y|X] - e(X)*tau(X)
mu.hat.1 <- m.hat + (1 - e.hat) * tau.hat # E[Y|X,W=1] = E[Y|X] + (1 - e(X))*tau(X)
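# Sanity check (illustrative): by construction, the two conditional means
# average back to the marginal mean, m(x) = e(x)*mu(x,1) + (1-e(x))*mu(x,0).
stopifnot(isTRUE(all.equal(e.hat * mu.hat.1 + (1 - e.hat) * mu.hat.0, m.hat)))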
# AIPW scores
aipw.scores <- (tau.hat
+ W[idx] / e.hat * (Y[idx] - mu.hat.1)
- (1 - W[idx]) / (1 - e.hat) * (Y[idx] - mu.hat.0))
# Rank observations in the held-out sample based on estimated CATE.
# Subtle but important: observations must be ranked within each hold-out fold,
# rather than pooling all predictions first and computing a single overall ranking.
tau.hat.quantiles <- quantile(tau.hat, probs = seq(0, 1, length.out = num.rankings+1))
ranking <- cut(tau.hat, tau.hat.quantiles, include.lowest = TRUE, labels = paste0("Q", seq(num.rankings)))
# Store results
data.frame(aipw.scores, tau.hat, ranking=factor(ranking), outcome=Y[idx], treatment=W[idx])
})
res <- do.call(rbind, res)
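# Sanity check (illustrative): rbind-ing the folds should recover all n
# observations, with rankings spread roughly evenly across the quintiles.
stopifnot(nrow(res) == n)
print(table(res$ranking))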
# Average AIPW scores for the treatment effect within each ranking
# Valid in randomized and observational settings with unconfoundedness+overlap.
forest.ate <- lm(aipw.scores ~ 0 + ranking, data=res)
forest.ate <- coeftest(forest.ate, vcov=vcovHC(forest.ate, type='HC2'))
forest.ate <- data.frame("aipw", paste0("Q", seq(num.rankings)), forest.ate[,1:2])
colnames(forest.ate) <- c("method", "ranking", "estimate", "std.err")
rownames(forest.ate) <- NULL
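# Example usage (illustrative; the plotting choices below are not part of the
# original gist): inspect the per-ranking AIPW estimates and plot them with
# approximate 95% confidence intervals using base R graphics.
print(forest.ate)
with(forest.ate, {
  plot(seq_along(estimate), estimate, pch = 19, xaxt = "n",
       xlab = "CATE ranking", ylab = "AIPW estimate of ATE",
       ylim = range(estimate - 1.96 * std.err, estimate + 1.96 * std.err))
  axis(1, at = seq_along(estimate), labels = ranking)
  arrows(seq_along(estimate), estimate - 1.96 * std.err,
         seq_along(estimate), estimate + 1.96 * std.err,
         angle = 90, code = 3, length = 0.05)
})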