# GitHub gist by @halflearned, created November 12, 2018 04:16
# (gist id 80744e49087266bf8406057821261d84).
# Experiment: varying causal_forest()'s `alpha` and `imbalance.penalty` does
# not appear to change the average out-of-bag debiased error.
library(grf)
# Simulated data ----
# n observations of p covariates drawn uniformly on (-1, 1); the treatment W
# is Bernoulli(0.1), drawn independently of X; the treatment effect TAU is 1
# when X[, 1] > 0 and 0 otherwise (Y(1) - Y(0) = TAU); outcome noise is
# Gaussian with standard deviation sigma.
#
# Seed R's RNG so the simulated data are identical across runs. The `seed`
# arguments passed to the forests later do not control these base-R draws.
set.seed(1234)
p <- 3
n <- 2000
sigma <- 0.1
X <- matrix(2 * runif(n * p) - 1, n, p)
W <- rbinom(n, 1, 0.1)
TAU <- as.numeric(X[, 1] > 0)  # explicit 0/1 effect instead of a logical
Y <- TAU * (W - 1 / 2) + sigma * rnorm(n)
# Nuisance estimates ----
# Fit one regression forest for W on X (propensity estimate W.hat) and one for
# Y on X (outcome-mean estimate Y.hat). Calling predict() without new data
# gives out-of-bag predictions for the training rows; W.hat and Y.hat are
# what the causal forests consume.
W.forest <- regression_forest(X, W, num.trees = 500, seed = 1234)
W.hat <- predict(W.forest)[["predictions"]]
W.resid <- W - W.hat  # centred treatment; computed but not used further here

Y.forest <- regression_forest(X, Y, num.trees = 500, seed = 1234)
Y.hat <- predict(Y.forest)[["predictions"]]
Y.resid <- Y - Y.hat

# Sanity check: the mean outcome residual should be close to zero.
print(mean(Y.resid, na.rm = TRUE))
# Tuning-parameter sweep ----
# Random search over causal_forest()'s `alpha` and `imbalance.penalty`. Every
# other setting, including the forest seed, is held fixed. For each draw the
# sweep records the average out-of-bag debiased error. If these two
# parameters mattered, the recorded averages would vary noticeably.
n.draws <- 100L
# NOTE(review): upper bound kept below 0.25, believed to be grf's maximum
# allowed alpha -- confirm for the installed grf version.
alpha.range <- c(0.01, 0.24)
penalty.range <- c(0.01, 1000)

# Seed R's RNG so the sampled hyperparameters are reproducible across runs.
set.seed(20181112)
alphas <- runif(n.draws, min = alpha.range[1], max = alpha.range[2])
# Sample imbalance.penalty log-uniformly. A plain uniform draw on
# [0.01, 1000] falls below 1 only about 0.1% of the time, so small penalties
# would almost never be tried and the sweep could not detect an effect there.
penalties <- exp(runif(n.draws,
                       min = log(penalty.range[1]),
                       max = log(penalty.range[2])))
avg.errors <- rep(NA_real_, n.draws)

for (i in seq_len(n.draws)) {
  alpha <- alphas[i]
  imbalance.penalty <- penalties[i]
  cf <- causal_forest(X, Y, W, Y.hat = Y.hat, W.hat = W.hat,
                      num.trees = 200, tune.parameters = FALSE,
                      sample.fraction = 0.5,
                      alpha = alpha,
                      imbalance.penalty = imbalance.penalty,
                      min.node.size = 1,
                      stabilize.splits = TRUE,
                      seed = 12345)
  # predict() with no new data returns out-of-bag estimates for the
  # training rows.
  deb_error <- predict(cf)[["debiased.error"]]
  if (is.null(deb_error)) {
    stop("predict() returned no `debiased.error` column; ",
         "the installed grf version may not provide it.", call. = FALSE)
  }
  avg_deb_error <- mean(deb_error, na.rm = TRUE)
  avg.errors[i] <- avg_deb_error

  print("ALPHA: ")
  print(alpha)
  print("IMBALANCE PENALTY: ")
  print(imbalance.penalty)
  print("AVG DEBIASED ERROR:")
  print(avg_deb_error)
  cat("\n")
}

# Keep the whole sweep for later inspection, then summarise how much the
# average debiased error actually moved across the draws.
sweep.results <- data.frame(alpha = alphas,
                            imbalance.penalty = penalties,
                            avg.debiased.error = avg.errors)
print(summary(sweep.results$avg.debiased.error))
print(sd(sweep.results$avg.debiased.error, na.rm = TRUE))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment