@erikcs
Created December 2, 2021 00:24
mutau
library(grf)

set.seed(123)
# DGP with overlap; multi_arm_causal_forest splits on (mu, tau)
r <- replicate(50, {
  n <- 1600
  p <- 10
  X <- matrix(runif(n * p), n, p)
  e <- 0.5                      # constant propensity score
  W <- rbinom(n, 1, e)
  mu <- 2 * X[, 1] - 1          # baseline effect
  TAU <- (1 + exp(-20 * (X[, 1:2] - 1/3)))^-1
  tau <- TAU[, 1] * TAU[, 2]    # true treatment effect
  Y <- mu + W * tau + rnorm(n)

  # Y.hat = NULL: Y.hat is estimated and used to center Y.
  # Y.hat = 0: no centering of Y (the ".nocent" variants).
  cf.default <- causal_forest(X, Y, W, W.hat = 0.5, num.trees = 500, Y.hat = NULL)
  cf.default.nocent <- causal_forest(X, Y, W, W.hat = 0.5, num.trees = 500, Y.hat = 0)
  cf.mutau <- multi_arm_causal_forest(X, Y, as.factor(W), W.hat = c(0.5, 0.5),
                                      num.trees = 500, Y.hat = NULL)
  cf.mutau.nocent <- multi_arm_causal_forest(X, Y, as.factor(W), W.hat = c(0.5, 0.5),
                                             num.trees = 500, Y.hat = 0)

  # MSE of the out-of-bag treatment effect estimates against the true tau
  c(
    mse.default = mean((tau - predict(cf.default)$predictions)^2),
    mse.default.nocent = mean((tau - predict(cf.default.nocent)$predictions)^2),
    mse.mutau = mean((tau - predict(cf.mutau)$predictions)^2),
    mse.mutau.nocent = mean((tau - predict(cf.mutau.nocent)$predictions)^2)
  )
})

# One row per replication, one column per forest variant
summary(t(r))
#  mse.default         mse.default.nocent  mse.mutau           mse.mutau.nocent
#  Min.   :0.01435     Min.   :0.02310     Min.   :0.01299     Min.   :0.008669
#  1st Qu.:0.03075     1st Qu.:0.05713     1st Qu.:0.02894     1st Qu.:0.022302
#  Median :0.03845     Median :0.06136     Median :0.03639     Median :0.032301
#  Mean   :0.04137     Mean   :0.06452     Mean   :0.04036     Mean   :0.033769
#  3rd Qu.:0.05347     3rd Qu.:0.07518     3rd Qu.:0.05105     3rd Qu.:0.041076
#  Max.   :0.09690     Max.   :0.10560     Max.   :0.09473     Max.   :0.083380