grf::causal_forest with low overlap
library(grf)
n <- 100
# X: n/2 evenly spaced points on [-5, -2] and n/2 on [2, 5] (a fixed grid, not random draws)
X <- matrix(c(seq(-5, -2, length.out=n/2), seq(2, 5, length.out=n/2)), n, 1)
# P(W = 1 | X > 0) = 1 and P(W = 1 | X < 0) = 0, so treated and control units never overlap
W <- as.numeric(X[, 1] > 0)
# Noiseless outcome model: Y(0) = sin(0) = 0 and Y(1) = sin(X)
Y <- sin(X*W)
# Estimated propensity scores: what does W.hat look like?
rf_w <- grf::regression_forest(X = X, Y = W)
W.hat <- predict(rf_w)$predictions  # out-of-bag estimates of P(W = 1 | X)
plot(X, W.hat)
# Truth vs estimated CATE
cf <- grf::causal_forest(X, Y, W)  # fits without an error despite zero overlap
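# True CATE: tau(x) = E[Y(1) - Y(0) | X = x] = sin(x) - sin(0) = sin(x)
taux <- sin(X) - sin(0)
tauhatx <- predict(cf)$predictions  # out-of-bag CATE estimates
plot(taux, tauhatx, col = ifelse(X[, 1] > 0, "black", "red"))  # red: X < 0, no treated units
abline(0, 1, lty = 2)  # 45-degree line: perfect estimates would fall on it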
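The W.hat plot is a visual overlap check. A numeric version can be sketched in base R without grf: estimate the propensity with `knn_propensity`, a hypothetical k-nearest-neighbour helper used here as a stand-in for `regression_forest` (not grf's method), and report the share of units whose estimate falls outside [0.05, 0.95]. Those bounds are an illustrative rule of thumb assumed for this sketch, not a grf default.

```r
# Same design as the gist: X on [-5, -2] and [2, 5], W = 1 exactly when X > 0
n <- 100
X <- matrix(c(seq(-5, -2, length.out = n / 2), seq(2, 5, length.out = n / 2)), n, 1)
W <- as.numeric(X[, 1] > 0)

# Hypothetical helper: k-nearest-neighbour propensity estimate for a 1-D covariate
knn_propensity <- function(x, w, k = 10) {
  sapply(seq_along(x), function(i) {
    nn <- order(abs(x - x[i]))[seq_len(k)]  # k closest units, including unit i
    mean(w[nn])                             # share treated among the neighbours
  })
}

e_hat <- knn_propensity(X[, 1], W, k = 10)

# Share of units with estimated propensity outside [0.05, 0.95] (assumed bounds)
low_overlap_share <- mean(e_hat < 0.05 | e_hat > 0.95)
print(low_overlap_share)
```

With k = 10 every neighbourhood stays on one side of the gap between -2 and 2, so each estimate is exactly 0 or 1 and the share is 1: no unit has overlap.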
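With zero overlap, tau(x) for X < 0 depends on E[Y | X = x, W = 1], and no treated unit has X < 0, so any estimator, causal_forest included, has to extrapolate there. A base-R sketch makes this concrete with a deliberately crude 1-nearest-neighbour T-learner; this is a swapped-in illustration, not how grf estimates effects, and `nn_predict` and `rmse` are hypothetical helpers.

```r
# Same noiseless design as the gist, as plain vectors (no grf needed)
n <- 100
x <- c(seq(-5, -2, length.out = n / 2), seq(2, 5, length.out = n / 2))
w <- as.numeric(x > 0)
y <- sin(x * w)
tau_true <- sin(x)

# Hypothetical helper: 1-nearest-neighbour regression fitted on (xs, ys)
nn_predict <- function(x0, xs, ys) ys[which.min(abs(xs - x0))]

# T-learner: separate outcome models for the treated and the control units
mu1_hat <- sapply(x, nn_predict, xs = x[w == 1], ys = y[w == 1])
mu0_hat <- sapply(x, nn_predict, xs = x[w == 0], ys = y[w == 0])
tau_hat <- mu1_hat - mu0_hat

# Error where treated data exist (X > 0) vs where they must be extrapolated (X < 0)
rmse <- function(a, b) sqrt(mean((a - b)^2))
print(c(x_pos = rmse(tau_hat[x > 0], tau_true[x > 0]),
        x_neg = rmse(tau_hat[x < 0], tau_true[x < 0])))
```

In this noiseless design the sketch recovers tau exactly for X > 0 (the only control outcome it borrows is Y(0) = 0), while for every X < 0 it returns the constant sin(2), the outcome of the treated unit nearest the gap.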