Skip to content

Instantly share code, notes, and snippets.

@erikcs
Created December 17, 2021 15:33
Show Gist options
  • Save erikcs/77d9150eea3af0c4fcda75ee8b7a3380 to your computer and use it in GitHub Desktop.
Save erikcs/77d9150eea3af0c4fcda75ee8b7a3380 to your computer and use it in GitHub Desktop.
mobVSgrf
# GRF with bivariate splits (done in multi_arm_causal_forest) in two specs (separate installs):
# https://github.com/erikcs/grf/commits/mobWY
# a) Using what we inferred to be MOB's "~equivalent" formulation of the GRF CART criterion
# Install with
# devtools::install_github("erikcs/grf", subdir = "r-package/grf", ref = "3ea1ddf") # "mobWY" branch @3ea1ddf Ap="MOB"
# b) Using GRF's criterion
# Install with
# devtools::install_github("erikcs/grf", subdir = "r-package/grf", ref = "c54605a") # "mobWY" branch @c54605a Ap="GRF"
# ***
# Recap of GRF's criterion for causal forest with an added intercept, (20) in https://arxiv.org/pdf/1610.01271.pdf
# Toy data: n observations with p uniform covariates, a binary treatment w,
# and an outcome in which the coefficient on w is x[, 2].
n <- 20
p <- 2
x <- matrix(runif(n * p), n, p)
w <- rbinom(n, 1, 0.5)
y <- x[, 1] + w * x[, 2] + rnorm(n)
# Define the estimated conditional means to be:
# (the trailing `[, ]` drops any singleton dimension of the predictions)
Y.hat <- grf::regression_forest(x, y, num.trees = 500)$predictions[, ]
W.hat <- 0.5
# Then this is what happens in a parent node with splits on [intercept, tau]
# (set Y.hat/W.hat to 0 to ignore centering)
W <- cbind(1, w - W.hat)
Y <- y - Y.hat
# Least-squares fit, i.e. lm(Y ~ W - 1). The normal equations are solved
# directly instead of forming the explicit inverse of t(W) %*% W.
beta <- solve(crossprod(W), crossprod(W, Y))
resid <- Y - W %*% beta
# n x 2 estimating function matrix (each row of W scaled by its residual):
psi <- W * c(resid)
# GRF weighting matrix Ap, (7) in https://arxiv.org/pdf/1610.01271.pdf
Ap <- crossprod(W)
# Pseudo outcomes rho for CART splitting, psi scaled by Ap,
# (8) in https://arxiv.org/pdf/1610.01271.pdf
# (in standard causal forest \xi would pick out only the tau entry)
rho <- psi %*% solve(Ap)
# CART splitting on this rho is what's implemented in "b)".
# For MOB "equivalent" splitting in GRF, what I could infer from the
# Strasser-Weber statistic after Susanne's help was that it seems to have the
# same argmax as the GRF criterion, but with the weighting matrix Ap replaced by:
Ap.mob <- crossprod(psi)
# and the rho_i's would be scaled by this instead:
rho.mob <- psi %*% solve(Ap.mob)
# CART splitting on this rho is what's implemented in "a)"
# ****
# 1) Y = mu + (W - 0.5)*tau + rnorm(n)
# Simulation 1: 50 replications comparing out-of-sample MSE of the estimated
# treatment effect for three forests. Each replication returns a named vector,
# so res1 is a 3 x 50 matrix and summary(t(res1)) summarises each method.
#
# No workspace wipe is needed: every variable used inside the replicate()
# expression is assigned there before it is read. The grf calls are
# namespace-qualified so the script does not rely on grf having been attached
# with library().
set.seed(123)
res1 <- replicate(50, {
  n <- 800
  p <- 10
  n.test <- 1000
  X <- matrix(runif(n * p), n, p)
  W <- rbinom(n, 1, 0.5)
  mu <- 2 * X[, 1] - 1
  tau <- 2 * X[, 2] - 1
  Y <- mu + (W - 0.5) * tau + rnorm(n) # 1)

  # Test points and the true treatment effect at those points
  Xt <- matrix(runif(n.test * p), n.test, p)
  taut <- 2 * Xt[, 2] - 1

  # Out-of-bag estimate of E[Y | X], used for centering below
  Y.hat <- predict(grf::regression_forest(X, Y, num.trees = 500))$predictions

  # Causal forest default
  cf <- grf::causal_forest(X, Y, W, W.hat = 0.5, num.trees = 500, Y.hat = Y.hat)

  # GRF with "MOB splitting", with and without centering of Y
  cf.mob <- grf::multi_arm_causal_forest(
    X, Y, as.factor(W),
    W.hat = cbind(0.5, 0.5), num.trees = 500, Y.hat = Y.hat
  )
  cf.mob.nocenter <- grf::multi_arm_causal_forest(
    X, Y, as.factor(W),
    W.hat = cbind(0.5, 0.5), num.trees = 500, Y.hat = 0
  )

  c(
    mse.cf = mean((taut - predict(cf, Xt)$predictions)^2),
    mse.cf.mob = mean((taut - predict(cf.mob, Xt)$predictions)^2),
    mse.cf.mob.nocenter = mean((taut - predict(cf.mob.nocenter, Xt)$predictions)^2)
  )
})
summary(t(res1))
# a) MOB's Ap
# mse.cf mse.cf.mob mse.cf.mob.nocenter
# Min. :0.01858 Min. :0.01832 Min. :0.02738
# 1st Qu.:0.03417 1st Qu.:0.03305 1st Qu.:0.04519
# Median :0.05237 Median :0.05134 Median :0.06893
# Mean :0.05268 Mean :0.05333 Mean :0.07510
# 3rd Qu.:0.06574 3rd Qu.:0.06732 3rd Qu.:0.09787
# Max. :0.11175 Max. :0.10857 Max. :0.15808
# b) GRF's Ap
# mse.cf mse.cf.mob mse.cf.mob.nocenter
# Min. :0.01858 Min. :0.01682 Min. :0.02622
# 1st Qu.:0.03417 1st Qu.:0.03320 1st Qu.:0.04381
# Median :0.05237 Median :0.05005 Median :0.06970
# Mean :0.05268 Mean :0.05284 Mean :0.07540
# 3rd Qu.:0.06574 3rd Qu.:0.06687 3rd Qu.:0.09592
# Max. :0.11175 Max. :0.10369 Max. :0.15959
# 2) Y = mu + W*tau + rnorm(n)
# Simulation 2: same design as simulation 1, except that the treatment enters
# the outcome as W * tau rather than (W - 0.5) * tau. Each replication returns
# a named vector, so res2 is a 3 x 50 matrix and summary(t(res2)) summarises
# each method.
#
# No workspace wipe is needed (it would also discard earlier results such as
# res1): every variable used inside the replicate() expression is assigned
# there before it is read. The grf calls are namespace-qualified so the script
# does not rely on grf having been attached with library().
set.seed(123)
res2 <- replicate(50, {
  n <- 800
  p <- 10
  n.test <- 1000
  X <- matrix(runif(n * p), n, p)
  W <- rbinom(n, 1, 0.5)
  mu <- 2 * X[, 1] - 1
  tau <- 2 * X[, 2] - 1
  Y <- mu + W * tau + rnorm(n) # 2)

  # Test points and the true treatment effect at those points
  Xt <- matrix(runif(n.test * p), n.test, p)
  taut <- 2 * Xt[, 2] - 1

  # Out-of-bag estimate of E[Y | X], used for centering below
  Y.hat <- predict(grf::regression_forest(X, Y, num.trees = 500))$predictions

  # Causal forest default
  cf <- grf::causal_forest(X, Y, W, W.hat = 0.5, num.trees = 500, Y.hat = Y.hat)

  # GRF with "MOB splitting", with and without centering of Y
  cf.mob <- grf::multi_arm_causal_forest(
    X, Y, as.factor(W),
    W.hat = cbind(0.5, 0.5), num.trees = 500, Y.hat = Y.hat
  )
  cf.mob.nocenter <- grf::multi_arm_causal_forest(
    X, Y, as.factor(W),
    W.hat = cbind(0.5, 0.5), num.trees = 500, Y.hat = 0
  )

  c(
    mse.cf = mean((taut - predict(cf, Xt)$predictions)^2),
    mse.cf.mob = mean((taut - predict(cf.mob, Xt)$predictions)^2),
    mse.cf.mob.nocenter = mean((taut - predict(cf.mob.nocenter, Xt)$predictions)^2)
  )
})
summary(t(res2))
# a) MOB's Ap
# mse.cf mse.cf.mob mse.cf.mob.nocenter
# Min. :0.01937 Min. :0.01740 Min. :0.02481
# 1st Qu.:0.03282 1st Qu.:0.03355 1st Qu.:0.05239
# Median :0.05259 Median :0.05176 Median :0.07469
# Mean :0.05319 Mean :0.05353 Mean :0.08175
# 3rd Qu.:0.06684 3rd Qu.:0.06497 3rd Qu.:0.10764
# Max. :0.11283 Max. :0.12930 Max. :0.18221
#
# b) GRF's Ap
# mse.cf mse.cf.mob mse.cf.mob.nocenter
# Min. :0.01937 Min. :0.01659 Min. :0.01887
# 1st Qu.:0.03282 1st Qu.:0.03273 1st Qu.:0.04115
# Median :0.05259 Median :0.05122 Median :0.06266
# Mean :0.05319 Mean :0.05202 Mean :0.06519
# 3rd Qu.:0.06684 3rd Qu.:0.06406 3rd Qu.:0.07670
# Max. :0.11283 Max. :0.10874 Max. :0.14016
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment