Created
August 12, 2022 02:33
-
-
Save apoorvalal/e01a820f29afb9bf434a42ede57cebf1 to your computer and use it in GitHub Desktop.
Omnibus tests of treatment-effect heterogeneity using linear and nonlinear heterogeneous-effect (hetfx) estimators.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Dependencies. libreq() is a load-or-install helper (presumably from the
# author's utility package, e.g. LalRUtils -- TODO confirm it is on the
# search path before this script runs).
# NOTE(review): dropped `rm(list = ls())` -- wiping the global environment
# from inside a script is an anti-pattern; start a fresh R session instead.
libreq(data.table, estimatr,
  grf,
  DoubleML, mlr3, mlr3learners, dmlUtils)
# %% linear effect heterogeneity
# Omnibus chi-square test for linear treatment-effect heterogeneity
# (Ding-Feller-Miratrix style).
#
# Fits separate OLS outcome models in the treated and control arms; the
# coefficient difference is the projection of the CATE onto X. A Wald test
# then asks whether the non-intercept part of that difference is jointly zero.
#
# y: numeric outcome vector.
# w: binary treatment indicator (0/1), same length as y.
# X: design matrix whose FIRST column is the constant (intercept).
# Returns c(chi-square statistic, p-value) with K - 1 degrees of freedom.
dfm_omnibus <- function(y, w, X) {
  K <- ncol(X)
  # arm-specific data; drop = FALSE keeps matrices even if an arm has 1 row
  X1 <- X[w == 1, , drop = FALSE]
  X0 <- X[w == 0, , drop = FALSE]
  n1 <- nrow(X1)
  n0 <- nrow(X0)
  # separate outcome models by arm
  m1 <- lm.fit(X1, y[w == 1])
  m0 <- lm.fit(X0, y[w == 0])
  # residual-weighted regressors: score contributions for the sandwich "meat"
  E1 <- m1$residuals * X1
  E0 <- m0$residuals * X0
  # "bread": inverse scaled Gram matrices
  Sxx1Inv <- solve(crossprod(X1) / n1)
  Sxx0Inv <- solve(crossprod(X0) / n0)
  # projection of effect heterogeneity on covariates
  betaHat <- m1$coefficients - m0$coefficients
  covBeta <- Sxx1Inv %*% (cov(E1) / n1) %*% Sxx1Inv +
    Sxx0Inv %*% (cov(E0) / n0) %*% Sxx0Inv
  # joint Wald test on the non-intercept subvector
  beta1Hat <- betaHat[2:K]
  covBeta1 <- covBeta[2:K, 2:K]
  # drop() collapses the 1x1 matrix product to a scalar
  chisq.stat <- drop(t(beta1Hat) %*% solve(covBeta1, beta1Hat))
  chisq.pv <- pchisq(chisq.stat, df = K - 1, lower.tail = FALSE)
  c(chisq.stat, chisq.pv)
}
# Omnibus chi-square heterogeneity test from per-unit doubly-robust scores
# (e.g. grf::get_scores output or DoubleML's psi_b component).
#
# Regresses the scores on X with heteroskedasticity-robust SEs (lm_robust)
# and Wald-tests whether the non-intercept slopes are jointly zero.
#
# tau: numeric vector of doubly-robust CATE scores.
# X:   design matrix whose FIRST column is the constant -- hence the explicit
#      -1 in the formula so lm_robust does not add a second intercept.
# Returns c(chi-square statistic, p-value) with K - 1 degrees of freedom.
dfm_omnibus2 <- function(tau, X) {
  K <- ncol(X)
  m <- lm_robust(tau ~ X - 1)
  # joint Wald test on the non-intercept subvector
  beta1Hat <- coef(m)[2:K]
  covBeta1 <- vcov(m)[2:K, 2:K]
  # drop() collapses the 1x1 matrix product to a scalar
  chisq.stat <- drop(t(beta1Hat) %*% solve(covBeta1, beta1Hat))
  chisq.pv <- pchisq(chisq.stat, df = K - 1, lower.tail = FALSE)
  c(chisq.stat, chisq.pv)
}
# %% dgp fn
# Simulate a dataset with user-supplied nuisance functions, each mapping a
# single covariate row x to a scalar:
#   τf  -- conditional treatment effect (CATE)
#   y0f -- baseline (untreated) outcome
#   πf  -- propensity index, pushed through plogis()
# Returns list(y, w, X) where X has a leading column of ones.
dgp <- \(n = 1e4, p = 6,
  τf = \(x) 1/3 ,
  y0f = \(x) 4*pmax(x[1] + x[2], 0) + sin(x[5]) * pmax(x[6], 0.5),
  πf = \(x) x[1] - 0.5 * x[3] + 3*pmax(x[4], 0)
) {
  covariates <- matrix(runif(n*p, -2, 2), n, p)
  # propensity, then treatment draw, then heterogeneity and baseline
  pscore <- plogis(apply(covariates, 1, πf))
  treat <- rbinom(n, 1, pscore)
  cate <- apply(covariates, 1, τf)
  baseline <- apply(covariates, 1, y0f)
  outcome <- baseline + treat * cate + rnorm(n)
  list(y = outcome, w = treat, X = cbind(1, covariates))
}
# %%
# Quiet mlr3 logging, then build cross-validated lasso learners for the two
# nuisances: a regression learner for the outcome model and a classification
# learner for the propensity model. set_threads() enables parallel fitting.
lgr::get_logger("mlr3")$set_threshold("warn")
lasso <- lrn("regr.cv_glmnet", nfolds = 5, s = "lambda.min")
set_threads(lasso)
lasso_class <- lrn("classif.cv_glmnet", nfolds = 5, s = "lambda.min")
set_threads(lasso_class)
# %%
# One Monte Carlo replication: simulate via dgp(...), then run the omnibus
# heterogeneity test three ways -- (i) arm-specific OLS, (ii) causal-forest
# doubly-robust scores, (iii) DML-IRM scores with lasso nuisances.
# All arguments are forwarded to dgp(). Returns the three p-values.
oneRep <- \(...) {
  sim <- dgp(...)
  # nonparametric CATE scores from a causal forest
  forest <- with(sim, causal_forest(X, y, w, num.threads = 1))
  scores_rf <- get_scores(forest)
  # double ML (interactive regression model) with lasso nuisances;
  # NOTE(review): psi_b presumably holds the doubly-robust score component
  # used as a pseudo-outcome here -- confirm against DoubleML docs
  ml_data <- with(sim, double_ml_data_from_matrix(X, y, w))
  irm <- DoubleMLIRM$new(ml_data, ml_g = lasso, ml_m = lasso_class, score = 'ATE')
  irm$fit(store_predictions=TRUE)
  scores_dml <- irm$psi_b
  c(
    lin_pval = with(sim, dfm_omnibus(y, w, X)[2]),
    rf_pval = dfm_omnibus2(scores_rf, sim$X)[2],
    dml_pval = dfm_omnibus2(scores_dml, sim$X)[2]
  )
}
# Smoke test: quadratic heterogeneity in x2, randomized treatment.
oneRep(τf = \(x) x[2]^2, πf = \(x) 1/3, y0f = \(x) 3 * x[3] * sin(x[1]))
# %%
# Power check: 100 replications with genuine (quadratic) heterogeneity --
# the tests should reject often.
hetFX <- mcReplicate(100, {
  oneRep(τf = \(x) x[2]^2,
    πf = \(x) 1/3,
    y0f = \(x) 3 * x[3] * sin(x[1]))
}, mc.cores = 6)
# %%
# Size check: 100 replications with a constant effect (no heterogeneity) --
# rejection rates should be near the nominal level.
homFX <- mcReplicate(100, {
  oneRep(τf = \(x) 0.1,
    πf = \(x) 1/3,
    y0f = \(x) 3 * x[3] * sin(x[1]))
}, mc.cores = 6)
# %%
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment