Created December 2, 2021 00:24
-
-
Save erikcs/715a26f72763fd702e18bf7a5396568f to your computer and use it in GitHub Desktop.
mutau
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simulation: does splitting on (mu, tau) jointly improve CATE estimates?
#
# DGP with perfect overlap (randomized treatment with known propensity e)
# and a baseline effect mu(X) that varies with X1. causal_forest is compared
# with multi_arm_causal_forest on a binary treatment factor (which, per the
# original note, splits on (mu, tau)). Each forest is fit twice: with Y.hat
# estimated (Y.hat = NULL) and with Y.hat = 0, i.e. without centering the
# outcome ("nocent").
library(grf)

set.seed(123)

n.reps <- 50      # Monte Carlo replications
n <- 1600         # sample size per replication
p <- 10           # number of covariates
e <- 0.5          # known treatment probability P(W = 1)
num.trees <- 500  # trees per forest

# Mean squared error of a forest's out-of-bag predictions against the true tau.
# as.vector() flattens the array returned by predict() for a
# multi_arm_causal_forest, so both forest types are scored the same way.
oob.mse <- function(forest, tau) {
  tau.hat <- as.vector(predict(forest)$predictions)
  mean((tau - tau.hat)^2)
}

r <- replicate(n.reps, {
  X <- matrix(runif(n * p), n, p)
  W <- rbinom(n, 1, e)
  mu <- 2 * X[, 1] - 1
  # Logistic step in each of X1 and X2; the effect is their product.
  TAU <- (1 + exp(-20 * (X[, 1:2] - 1/3)))^-1
  tau <- TAU[, 1] * TAU[, 2]
  Y <- mu + W * tau + rnorm(n)

  # Fit order is kept fixed: each forest draws its seed from R's RNG.
  cf.default <- causal_forest(X, Y, W, W.hat = e,
                              num.trees = num.trees, Y.hat = NULL)
  cf.default.nocent <- causal_forest(X, Y, W, W.hat = e,
                                     num.trees = num.trees, Y.hat = 0)
  # Multi-arm W.hat lists probabilities for the factor levels "0", "1".
  cf.mutau <- multi_arm_causal_forest(X, Y, as.factor(W), W.hat = c(1 - e, e),
                                      num.trees = num.trees, Y.hat = NULL)
  cf.mutau.nocent <- multi_arm_causal_forest(X, Y, as.factor(W), W.hat = c(1 - e, e),
                                             num.trees = num.trees, Y.hat = 0)

  c(
    mse.default        = oob.mse(cf.default, tau),
    mse.default.nocent = oob.mse(cf.default.nocent, tau),
    mse.mutau          = oob.mse(cf.mutau, tau),
    mse.mutau.nocent   = oob.mse(cf.mutau.nocent, tau)
  )
})

# r is a 4 x n.reps matrix; transpose so summary() reports per estimator.
summary(t(r))
# Output recorded with the original gist (exact values depend on the grf version):
#  mse.default       mse.default.nocent   mse.mutau         mse.mutau.nocent
#  Min.   :0.01435   Min.   :0.02310      Min.   :0.01299   Min.   :0.008669
#  1st Qu.:0.03075   1st Qu.:0.05713      1st Qu.:0.02894   1st Qu.:0.022302
#  Median :0.03845   Median :0.06136      Median :0.03639   Median :0.032301
#  Mean   :0.04137   Mean   :0.06452      Mean   :0.04036   Mean   :0.033769
#  3rd Qu.:0.05347   3rd Qu.:0.07518      3rd Qu.:0.05105   3rd Qu.:0.041076
#  Max.   :0.09690   Max.   :0.10560      Max.   :0.09473   Max.   :0.083380
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.