# Created December 17, 2021 15:33
# Gist: erikcs/77d9150eea3af0c4fcda75ee8b7a3380
# mobVSgrf
# (GitHub notice: this file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that reveals hidden
# Unicode characters.)
# GRF with bivariate splits (done in multi_arm_causal_forest) in two specs (separate installs):
# https://github.com/erikcs/grf/commits/mobWY
# a) Using what we inferred to be MOB's "~equivalent" formulation of the GRF CART criterion
# Install with
# devtools::install_github("erikcs/grf", subdir = "r-package/grf", ref = "3ea1ddf") # "mobWY" branch @3ea1ddf Ap="MOB"
# b) Using GRF's criterion
# Install with
# devtools::install_github("erikcs/grf", subdir = "r-package/grf", ref = "c54605a") # "mobWY" branch @c54605a Ap="GRF"
# ***
# Recap of GRF's criterion for causal forest with an added intercept, (20) in https://arxiv.org/pdf/1610.01271.pdf
# Toy data: n observations, p uniform covariates, a Bernoulli(0.5) treatment w,
# and an outcome whose treatment effect is x[, 2].
n = 20
p = 2
x = matrix(runif(n * p), n, p)
w = rbinom(n, 1, 0.5)
y = x[, 1] + w * x[, 2] + rnorm(n)
# Define the estimated conditional means to be:
# (`[, ]` drops the predictions down to a plain vector)
Y.hat = grf::regression_forest(x, y, num.trees = 500)$predictions[, ]
W.hat = 0.5
# Then this is what happens in a parent node with splits on [intercept, tau]
# (set Y.hat/W.hat to 0 to ignore centering)
W = cbind(1, w - W.hat)
Y = y - Y.hat
# Least-squares fit of Y on W without an extra intercept, i.e. coef(lm(Y ~ W - 1)).
# Solving the normal equations directly avoids forming the explicit inverse of t(W) %*% W.
beta = solve(crossprod(W), crossprod(W, Y))
resid = Y - W %*% beta
# n X 2 estimating function matrix (row i of W scaled by residual i):
psi = W * c(resid)
# GRF weighting matrix Ap (7) in https://arxiv.org/pdf/1610.01271.pdf
Ap = crossprod(W) # = t(W) %*% W
# Pseudo outcomes rho for CART splitting, psi scaled by Ap, (8) in https://arxiv.org/pdf/1610.01271.pdf
# (in standard causal forest \xi would pick out only the tau entry)
rho = psi %*% solve(Ap)
# CART splitting on this rho is what's implemented in "b)".
# For MOB "equivalent" splitting in GRF, what I could infer from the Strasser-Weber statistic
# after Susanne's help was that it seems to have the same argmax as the GRF criterion,
# but with the weighting matrix Ap replaced by:
Ap.mob = crossprod(psi) # = t(psi) %*% psi
# and the rho_i's would be scaled by this instead:
rho.mob = psi %*% solve(Ap.mob)
# CART splitting on this rho is what's implemented in "a)"
# ****
# 1) Y = mu + (W - 0.5)*tau + rnorm(n)
# Monte Carlo comparison over 50 replications. Each replication draws a training
# set and a test set, fits three forests, and returns the test-set MSE of each
# forest's treatment effect predictions against the true tau.
# Result: res1 is a 3 x 50 matrix (one column per replication), summarised per method below.
#
# No rm(list = ls()) here: every variable used inside the replicate() body is
# assigned there before it is read, so earlier workspace contents cannot leak in,
# and wiping the user's workspace is an unnecessary side effect.
# grf functions are namespace-qualified so the script does not rely on grf being attached;
# loading the namespace also registers the predict() methods used below.
set.seed(123)
res1 = replicate(50, {
  n = 800
  p = 10
  n.test = 1000
  X = matrix(runif(n * p), n, p)
  W = rbinom(n, 1, 0.5)
  mu = 2 * X[, 1] - 1
  tau = 2 * X[, 2] - 1
  Y = mu + (W - 0.5) * tau + rnorm(n) # 1)
  # Tau test: fresh covariates and the true effect at those points
  Xt = matrix(runif(n.test * p), n.test, p)
  taut = 2 * Xt[, 2] - 1
  # predict() without new data on the fitted regression forest gives the Y.hat used for centering
  Y.hat = predict(grf::regression_forest(X, Y, num.trees = 500))$predictions
  # Causal forest default
  cf = grf::causal_forest(X, Y, W, W.hat = 0.5, num.trees = 500, Y.hat = Y.hat)
  # GRF with "MOB splitting"
  cf.mob = grf::multi_arm_causal_forest(X, Y, as.factor(W), W.hat = cbind(0.5, 0.5),
                                        num.trees = 500, Y.hat = Y.hat)
  # Same, but with Y.hat = 0 (no outcome centering)
  cf.mob.nocenter = grf::multi_arm_causal_forest(X, Y, as.factor(W), W.hat = cbind(0.5, 0.5),
                                                 num.trees = 500, Y.hat = 0)
  # Test-set mean squared error of a forest's predictions against the true tau
  mse = function(forest) mean((taut - predict(forest, Xt)$predictions)^2)
  c(
    mse.cf = mse(cf),
    mse.cf.mob = mse(cf.mob),
    mse.cf.mob.nocenter = mse(cf.mob.nocenter)
  )
})
summary(t(res1))
# a) MOB's Ap
#     mse.cf          mse.cf.mob      mse.cf.mob.nocenter
# Min.   :0.01858   Min.   :0.01832   Min.   :0.02738
# 1st Qu.:0.03417   1st Qu.:0.03305   1st Qu.:0.04519
# Median :0.05237   Median :0.05134   Median :0.06893
# Mean   :0.05268   Mean   :0.05333   Mean   :0.07510
# 3rd Qu.:0.06574   3rd Qu.:0.06732   3rd Qu.:0.09787
# Max.   :0.11175   Max.   :0.10857   Max.   :0.15808
# b) GRF's Ap
#     mse.cf          mse.cf.mob      mse.cf.mob.nocenter
# Min.   :0.01858   Min.   :0.01682   Min.   :0.02622
# 1st Qu.:0.03417   1st Qu.:0.03320   1st Qu.:0.04381
# Median :0.05237   Median :0.05005   Median :0.06970
# Mean   :0.05268   Mean   :0.05284   Mean   :0.07540
# 3rd Qu.:0.06574   3rd Qu.:0.06687   3rd Qu.:0.09592
# Max.   :0.11175   Max.   :0.10369   Max.   :0.15959
# 2) Y = mu + W*tau + rnorm(n)
# Monte Carlo comparison over 50 replications. Each replication draws a training
# set and a test set, fits three forests, and returns the test-set MSE of each
# forest's treatment effect predictions against the true tau.
# Result: res2 is a 3 x 50 matrix (one column per replication), summarised per method below.
#
# No rm(list = ls()) here: it would delete any earlier results (e.g. from design 1)
# and is unnecessary, because every variable used inside the replicate() body is
# assigned there before it is read.
# grf functions are namespace-qualified so the script does not rely on grf being attached;
# loading the namespace also registers the predict() methods used below.
set.seed(123)
res2 = replicate(50, {
  n = 800
  p = 10
  n.test = 1000
  X = matrix(runif(n * p), n, p)
  W = rbinom(n, 1, 0.5)
  mu = 2 * X[, 1] - 1
  tau = 2 * X[, 2] - 1
  Y = mu + W * tau + rnorm(n) # 2)
  # Tau test: fresh covariates and the true effect at those points
  Xt = matrix(runif(n.test * p), n.test, p)
  taut = 2 * Xt[, 2] - 1
  # predict() without new data on the fitted regression forest gives the Y.hat used for centering
  Y.hat = predict(grf::regression_forest(X, Y, num.trees = 500))$predictions
  # Causal forest default
  cf = grf::causal_forest(X, Y, W, W.hat = 0.5, num.trees = 500, Y.hat = Y.hat)
  # GRF with "MOB splitting"
  cf.mob = grf::multi_arm_causal_forest(X, Y, as.factor(W), W.hat = cbind(0.5, 0.5),
                                        num.trees = 500, Y.hat = Y.hat)
  # Same, but with Y.hat = 0 (no outcome centering)
  cf.mob.nocenter = grf::multi_arm_causal_forest(X, Y, as.factor(W), W.hat = cbind(0.5, 0.5),
                                                 num.trees = 500, Y.hat = 0)
  # Test-set mean squared error of a forest's predictions against the true tau
  mse = function(forest) mean((taut - predict(forest, Xt)$predictions)^2)
  c(
    mse.cf = mse(cf),
    mse.cf.mob = mse(cf.mob),
    mse.cf.mob.nocenter = mse(cf.mob.nocenter)
  )
})
summary(t(res2))
# a) MOB's Ap
#     mse.cf          mse.cf.mob      mse.cf.mob.nocenter
# Min.   :0.01937   Min.   :0.01740   Min.   :0.02481
# 1st Qu.:0.03282   1st Qu.:0.03355   1st Qu.:0.05239
# Median :0.05259   Median :0.05176   Median :0.07469
# Mean   :0.05319   Mean   :0.05353   Mean   :0.08175
# 3rd Qu.:0.06684   3rd Qu.:0.06497   3rd Qu.:0.10764
# Max.   :0.11283   Max.   :0.12930   Max.   :0.18221
#
# b) GRF's Ap
#     mse.cf          mse.cf.mob      mse.cf.mob.nocenter
# Min.   :0.01937   Min.   :0.01659   Min.   :0.01887
# 1st Qu.:0.03282   1st Qu.:0.03273   1st Qu.:0.04115
# Median :0.05259   Median :0.05122   Median :0.06266
# Mean   :0.05319   Mean   :0.05202   Mean   :0.06519
# 3rd Qu.:0.06684   3rd Qu.:0.06406   3rd Qu.:0.07670
# Max.   :0.11283   Max.   :0.10874   Max.   :0.14016
# (GitHub page footer: "Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment")