# Omnibus tests of treatment effect heterogeneity with linear and nonlinear
# hetfx estimators.
# @apoorvalal, created August 12, 2022
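#
# Idea, as implemented below: project effect heterogeneity onto the covariates,
# either as the difference of arm-wise OLS coefficients (dfm_omnibus) or by
# regressing doubly-robust pseudo-outcomes on X (dfm_omnibus2), then run a
# joint Wald test that all non-intercept coefficients are zero (chi-square,
# K-1 df).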
rm(list = ls())
# libreq() (a load-or-install wrapper) and mcReplicate() used below are
# presumably from the author's utility packages (e.g. dmlUtils), not base R
libreq(data.table, estimatr, grf, DoubleML, mlr3, mlr3learners, dmlUtils)
# %% linear effect heterogeneity: separate OLS outcome models by arm; the
# coefficient difference is the linear projection of the CATE on X
dfm_omnibus = function(y, w, X) {
  n1 = sum(w); n0 = sum(1 - w); K = ncol(X)
  # separate outcome models by treatment arm
  m1 = lm.fit(X[w == 1, ], y[w == 1]); m0 = lm.fit(X[w == 0, ], y[w == 0])
  E1 = m1$residuals * X[w == 1, ]; E0 = m0$residuals * X[w == 0, ]
  # sandwich vcov building blocks
  Sxx1Inv = solve(crossprod(X[w == 1, ]) / n1)
  Sxx0Inv = solve(crossprod(X[w == 0, ]) / n0)
  # projection of effect heterogeneity on covariates; the two arms are
  # independent, so their variances add
  betaHat = m1$coefficients - m0$coefficients
  covBeta = Sxx1Inv %*% (cov(E1) / n1) %*% Sxx1Inv +
            Sxx0Inv %*% (cov(E0) / n0) %*% Sxx0Inv
  # joint Wald test on the non-intercept subvector
  beta1Hat = betaHat[2:K]; covBeta1 = covBeta[2:K, 2:K]
  chisq.stat = t(beta1Hat) %*% solve(covBeta1, beta1Hat)
  chisq.pv = pchisq(chisq.stat, df = K - 1, lower.tail = FALSE)
  c(chisq.stat, chisq.pv)
}
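# %% added sketch (not in the original gist): sanity-check dfm_omnibus on a
# toy draw with a constant effect; the test should reject only at roughly the
# nominal rate
set.seed(42)
ns = 1e3; Xs = cbind(1, matrix(rnorm(ns * 3), ns, 3))
ws = rbinom(ns, 1, 0.5)
ys = Xs[, 2] + 0.5 * ws + rnorm(ns)
dfm_omnibus(ys, ws, Xs)  # c(chi-square stat, p-value); p typically large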
# takes doubly-robust scores (pseudo-outcomes) and runs the identical test
dfm_omnibus2 = function(tau, X) {
  K = ncol(X)
  m = lm_robust(tau ~ X - 1)  # -1: X already carries an intercept column
  # joint Wald test on the non-intercept subvector
  beta1Hat = coef(m)[2:K]
  covBeta1 = vcov(m)[2:K, 2:K]
  chisq.stat = t(beta1Hat) %*% solve(covBeta1, beta1Hat)
  chisq.pv = pchisq(chisq.stat, df = K - 1, lower.tail = FALSE)
  c(chisq.stat, chisq.pv)
}
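# %% added sketch: the same test on noisy oracle pseudo-outcomes (reusing the
# toy draw above) whose heterogeneity is linear in the first covariate; the
# p-value should be tiny
taus = 1 + Xs[, 2] + rnorm(ns)
dfm_omnibus2(taus, Xs)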
# %% dgp fn: X ~ U(-2, 2)^p; W is logistic in πf(X); Y adds baseline y0f(X),
# effect τf(X) for the treated, and N(0, 1) noise
dgp = \(n = 1e4, p = 6,
        τf  = \(x) 1/3,
        y0f = \(x) 4 * pmax(x[1] + x[2], 0) + sin(x[5]) * pmax(x[6], 0.5),
        πf  = \(x) x[1] - 0.5 * x[3] + 3 * pmax(x[4], 0)) {
  X = matrix(runif(n * p, -2, 2), n, p)
  # generate treatment, heterogeneity, baseline
  W  = rbinom(n, 1, plogis(apply(X, 1, πf)))
  τ  = apply(X, 1, τf)
  Y0 = apply(X, 1, y0f)
  Y  = Y0 + W * τ + rnorm(n)
  list(y = Y, w = W, X = cbind(1, X))  # prepend intercept column
}
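# %% added sketch: inspect one draw from the dgp
d_chk = dgp(n = 2000)
dim(d_chk$X)   # 2000 x 7: intercept plus p = 6 covariates
mean(d_chk$w)  # treated share implied by the logistic propensity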
# %% CV-lasso learners for the DML nuisance models
lgr::get_logger("mlr3")$set_threshold("warn")
lasso = lrn("regr.cv_glmnet", nfolds = 5, s = "lambda.min"); set_threads(lasso)
lasso_class = lrn("classif.cv_glmnet", nfolds = 5, s = "lambda.min"); set_threads(lasso_class)
# %% one simulation draw: p-values from the three omnibus tests
oneRep = \(...) {
  df2 = dgp(...)
  # grf: doubly-robust scores from a causal forest
  cf = with(df2, causal_forest(X, y, w, num.threads = 1))
  tau_1 = get_scores(cf)
  # dml with lasso nuisances: psi_b holds the AIPW score contributions under
  # the ATE score
  data_ml = with(df2, double_ml_data_from_matrix(X, y, w))
  dml_irm = DoubleMLIRM$new(data_ml, ml_g = lasso, ml_m = lasso_class, score = "ATE")
  dml_irm$fit(store_predictions = TRUE)
  tau3 = dml_irm$psi_b
  c(
    lin_pval = with(df2, dfm_omnibus(y, w, X)[2]),
    rf_pval  = dfm_omnibus2(tau_1, df2$X)[2],
    dml_pval = dfm_omnibus2(tau3, df2$X)[2]
  )
}
# single run: effect heterogeneous in x2, constant propensity, nonlinear baseline
oneRep(τf = \(x) x[2]^2, πf = \(x) 1/3, y0f = \(x) 3 * x[3] * sin(x[1]))
# %% 100 draws with a heterogeneous effect (τ = x2^2)
hetFX = mcReplicate(100, {
  oneRep(τf = \(x) x[2]^2,
         πf = \(x) 1/3,
         y0f = \(x) 3 * x[3] * sin(x[1]))
}, mc.cores = 6)
# %% 100 draws with a constant effect (the no-heterogeneity null holds)
homFX = mcReplicate(100, {
  oneRep(τf = \(x) 0.1,
         πf = \(x) 1/3,
         y0f = \(x) 3 * x[3] * sin(x[1]))
}, mc.cores = 6)
# %%
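# added sketch: rejection rates at the 5% level, assuming mcReplicate()
# returns replicate()-style output (one named row per test, one column per draw)
rowMeans(hetFX < 0.05)  # rejection rate under heterogeneous effects (power)
rowMeans(homFX < 0.05)  # rejection rate under constant effects (size, ~0.05)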