Skip to content

Instantly share code, notes, and snippets.

@apoorvalal
Last active November 2, 2023 22:53
Show Gist options
  • Save apoorvalal/23565bd35e98c8292b4d26237c817fd9 to your computer and use it in GitHub Desktop.
Save apoorvalal/23565bd35e98c8292b4d26237c817fd9 to your computer and use it in GitHub Desktop.
# Compare the bias and variance properties of regression-adjustment strategies.
# Reference: https://www.degruyter.com/document/doi/10.1515/ijb-2021-0072/html
# %%
# Load the packages used throughout this script. MASS and LalRUtils are not
# loaded here; they are called below with `::`, so they must be installed too.
pacman::p_load(knitr, tidyverse, DeclareDesign, glmnet)
# Seed the RNG before any of the random draws made further down.
set.seed(42)
# %% estimator functions
#' Keep the most "significant" of several covariate-adjusted estimates
#'
#' Fits four `lm_robust()` regressions of `Y` on `Z`, each adjusting for a
#' different subset of `X1`-`X9`, and keeps the `Z` coefficient with the
#' smallest p-value. This mimics specification search.
#'
#' @param data Data frame containing `Y`, `Z` and `X1`-`X9`.
#' @return A one-row tidy data frame: the `term == "Z"` row of the
#'   specification with the lowest `p.value`. Ties keep the earliest
#'   specification in the list below.
p_hacker <- function(data) {
  # Candidate specifications, in the order the original list combined them.
  specifications <- list(
    Y ~ Z + X1,
    Y ~ Z + X2 + X3 + X4,
    Y ~ Z + X1 + X2,
    Y ~ Z + X3 + X4 + X5 + X6 + X7 + X8 + X9
  )
  # Return the pipeline result directly: ending the body on an assignment
  # would hand the value back invisibly and leave an unused local behind.
  specifications |>
    map(\(spec) lm_robust(spec, data = data)) |>
    map_df(tidy) |>
    filter(term == "Z") |>
    arrange(p.value) |>
    slice(1)
}
#' Treatment-effect estimate adjusted for a lasso prognostic score
#'
#' Fits a cross-validated lasso for the outcome on control units only
#' (`Z == 0`), scores every unit with it, and then regresses `Y` on `Z`
#' interacted with the standardized score.
#'
#' @param data Data frame containing `Y`, `Z` (0/1) and `X1`-`X9`.
#' @return A one-row tidy data frame: the `term == "Z"` row of the
#'   `lm_robust()` fit.
prognostic_adjust <- function(data) {
  # Design matrix without intercept: X1..X9 plus all pairwise interactions.
  X <- model.matrix(
    Y ~ -1 + (X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9)^2,
    data
  )
  is_control <- data$Z == 0
  # fit a predictive model of the untreated potential outcome; any good
  # predictive model will do - lasso with all pairwise interactions here.
  # keep = TRUE is set so the cross-validation fits stay on the object.
  mod <- cv.glmnet(X[is_control, ], data$Y[is_control], keep = TRUE)
  # predict.cv.glmnet() chooses the penalty through `s`. A `lambda =` argument
  # is not used for prediction, which silently left the default "lambda.1se"
  # in effect. as.numeric() flattens the one-column prediction matrix.
  data$prognostic_score <- as.numeric(predict(mod, X, s = "lambda.min"))
  # Controls were used to fit `mod`, so overwrite their in-sample predictions.
  # NOTE(review): fitGet() presumably returns the held-out-fold predictions
  # stored by keep = TRUE - confirm against LalRUtils.
  data$prognostic_score[is_control] <- LalRUtils::fitGet(mod) # cross-fit predictions
  # run lin regression with prognostic score as only covariate
  lm_robust(Y ~ Z * scale(prognostic_score), data) %>%
    tidy() %>%
    filter(term == "Z")
}
# %%
# Simulation parameters.
N <- 500
effect_size <- 0.1
r <- 0.9
# 10 x 10 covariance matrix: entry (i, j) is r^(|i - j| + 1), so every
# variance equals r and covariances shrink with distance between indices.
Σ <- r^toeplitz(1:10)

# Data-generating process: an extra variable U plus covariates X1..X9, a
# constant unit-level effect, and potential outcomes driven by U, X1 and the
# X8:X4 interaction.
model <- declare_model(
  N = N,
  draw_multivariate( # covariates have decaying covariance with U controlled by r
    c(U, X1, X2, X3, X4, X5, X6, X7, X8, X9) ~
      MASS::mvrnorm(N, mu = rep(0, 10), Sigma = Σ)
  ),
  # tau = rnorm(N, (effect_size * X2) / 10), # effect heterogeneity
  tau = effect_size,
  potential_outcomes(Y ~ tau * Z + U + X1 * 0.2 + X8 * X4)
)
# %%
# Estimand: mean difference between treated and untreated potential outcomes.
inquiry <- declare_inquiry(ATE = mean(Y_Z_1 - Y_Z_0))

# Assign exactly half of the units to treatment, then reveal the observed Y.
data_strategy <-
  declare_assignment(Z = complete_ra(N, m = N/2)) +
  declare_measurement(Y = reveal_outcomes(Y ~ Z))

# Five estimators, all targeting the ATE and reporting the coefficient on Z.
answer_strategy <-
  # unadjusted difference in means
  declare_estimator(
    Y~Z,
    .method = lm_robust, .summary = tidy,
    term = "Z", inquiry = "ATE", label = "DiM"
  ) +
  # lowest p-value across several covariate specifications
  declare_estimator(
    handler = label_estimator(p_hacker),
    inquiry = "ATE", label = "P-hacking"
  ) +
  # adjustment for a prognostic score
  declare_estimator(
    handler = label_estimator(prognostic_adjust),
    inquiry = "ATE", label = "prognostic_adj"
  ) +
  # additive adjustment for all nine covariates
  declare_estimator(
    Y~Z + X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9,
    .method = lm_robust, .summary = tidy,
    term = "Z", inquiry = "ATE", label = "lm_basic"
  ) +
  # lm_lin with all nine covariates
  declare_estimator(
    Y~Z,
    covariates=~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9,
    .method = lm_lin, .summary = tidy,
    term = "Z", inquiry = "ATE", label = "lm_lin"
  )

# Combine the pieces into one design and simulate its diagnosands.
design <- model + inquiry + data_strategy + answer_strategy
diagnosis <- diagnose_design(design)
# %%
# Print selected diagnosands for every estimator as a markdown table.
diagnosis |>
  tidy() |>
  filter(
    diagnosand %in% c("bias", "coverage", "rmse", "sd_estimate", "power")
  ) |>
  # Keep std.error so the printed columns match the results table recorded
  # in the comments at the end of this script.
  select(estimator, diagnosand, estimate, std.error) |>
  arrange(diagnosand) |>
  kable()
# %%
# |estimator |diagnosand | estimate| std.error|
# |:--------------|:-----------|--------:|---------:|
# |DiM |bias | 0.0166| 0.0065|
# |lm_basic |bias | 0.0138| 0.0049|
# |lm_lin |bias | 0.0141| 0.0049|
# |P-hacking |bias | 0.0298| 0.0057|
# |prognostic_adj |bias | 0.0009| 0.0018|
# |DiM |coverage | 0.9640| 0.0084|
# |lm_basic |coverage | 0.9420| 0.0098|
# |lm_lin |coverage | 0.9420| 0.0098|
# |P-hacking |coverage | 0.9220| 0.0116|
# |prognostic_adj |coverage | 0.9580| 0.0086|
# |DiM |power | 0.1360| 0.0158|
# |lm_basic |power | 0.2000| 0.0205|
# |lm_lin |power | 0.1980| 0.0201|
# |P-hacking |power | 0.2600| 0.0202|
# |prognostic_adj |power | 0.7500| 0.0206|
# |DiM |rmse | 0.1375| 0.0042|
# |lm_basic |rmse | 0.1051| 0.0031|
# |lm_lin |rmse | 0.1054| 0.0032|
# |P-hacking |rmse | 0.1244| 0.0038|
# |prognostic_adj |rmse | 0.0376| 0.0011|
# |DiM |sd_estimate | 0.1366| 0.0040|
# |lm_basic |sd_estimate | 0.1043| 0.0031|
# |lm_lin |sd_estimate | 0.1046| 0.0031|
# |P-hacking |sd_estimate | 0.1209| 0.0040|
# |prognostic_adj |sd_estimate | 0.0376| 0.0011|
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment