Skip to content

Instantly share code, notes, and snippets.

@halflearned
Last active May 23, 2018 20:08
Show Gist options
  • Save halflearned/b727be706395d8c9dec027b946cf4bc1 to your computer and use it in GitHub Desktop.
Save halflearned/b727be706395d8c9dec027b946cf4bc1 to your computer and use it in GitHub Desktop.
# Reference: https://github.com/swager/grf/issues/247
# Here I shut off the nonlinearity in the true model (the outcome is linear-Gaussian instead of logistic/binary), just for diagnostics
library(tidyverse)
library(grf)
# Simulate a toy randomized experiment whose true treatment effect is known.
#
# Covariates x1..xp are drawn from a standard multivariate normal
# (p = length(beta)). The subgroup indicator s and the treatment t are
# independent Bernoulli(0.5) draws. The outcome is linear-Gaussian:
#   y = x %*% beta + t * te + s * t * hete + rnorm(nsamp)
# so the true treatment effect for a unit is te + s * hete.
#
# Args:
#   nsamp: number of rows to simulate.
#   beta:  coefficients of the covariate main effects; its length sets the
#          number of covariate columns.
#   te:    treatment effect shared by every unit.
#   hete:  additional treatment effect for units with s == 1 (0 makes the
#          effect homogeneous).
#   seed:  passed to set.seed(); note this resets the global RNG stream.
#
# Returns:
#   list(data = <tibble>) with columns x1..xp, s (double 0/1), t (factor of
#   the 0/1 treatment draws), y, p_true, p_ctl, p_tmt and te_true.
#   The p_* columns are plogis() of the noiseless linear predictors (observed,
#   all-control, all-treated). NOTE(review): they look like leftovers from the
#   commented-out logistic-outcome variant and do not describe the Gaussian y.
simple_oracle_data <- function(nsamp,
                               beta = c(2, -1, 0, 0, 0),
                               te = 1,
                               hete = 1,
                               seed = 3) {
  set.seed(seed)
  x <- mvtnorm::rmvnorm(nsamp, mean = rep(0, length(beta)))
  # Name the columns before converting: as_tibble() on an unnamed matrix
  # emits a name-repair deprecation warning and falls back to V1, V2, ...
  colnames(x) <- paste0("x", seq_along(beta))
  s <- rbinom(n = nsamp, size = 1, prob = 0.5)
  trt <- rbinom(n = nsamp, size = 1, prob = 0.5)

  # Covariate main effects as a plain vector (drops the n x 1 matrix shape).
  main <- drop(x %*% beta)
  # x drives main effects; s drives heterogeneity of the treatment effect.
  z <- main + trt * te + s * trt * hete
  # Linear predictor every unit would have under control (t = 0) ...
  z_ctl <- main
  # ... and under treatment (t = 1).
  z_tmt <- main + s * hete + te

  df <- tibble::as_tibble(x)
  df$s <- as.double(s)
  df$t <- as.factor(trt)
  # Gaussian outcome; a binary alternative would be
  # as.double(rbinom(n = nsamp, size = 1, prob = plogis(z))).
  df$y <- as.double(z + rnorm(nsamp))
  df$p_true <- plogis(z)
  df$p_ctl <- plogis(z_ctl)
  df$p_tmt <- plogis(z_tmt)
  # True effect on the linear-predictor scale (equals te + s * hete).
  df$te_true <- as.double(z_tmt - z_ctl)
  list(data = df)
}
# Fit a causal forest on simulated training data and predict treatment
# effects on an independently simulated test set.
#
# Args:
#   ntrain:   number of training rows.
#   ntest:    number of test rows.
#   add_hete: if TRUE the simulated treatment effect is heterogeneous
#             (hete = 1, the simulator's default); if FALSE every unit has
#             the same effect (hete = 0).
#   seed:     seed for the training data and for causal_forest(); the test
#             data uses seed + 1.
#
# Returns:
#   A data.frame with one row per test unit: ytrue (observed outcome),
#   tetrue (true treatment effect), ypred (the forest's predicted treatment
#   effect, not a predicted outcome) and s (subgroup indicator).
oracle_test <- function(ntrain = 1000,
                        ntest = 10000,
                        add_hete = TRUE,
                        seed = 7) {
  # add_hete was previously accepted but never used; wire it to the simulator.
  hete <- if (add_hete) 1 else 0

  # Forest covariates: x1..xp plus the subgroup indicator s, as a numeric
  # matrix without the intercept column.
  covariates <- function(data) {
    mm <- model.matrix(~ .,
                       data = select(data, starts_with("x"),
                                     starts_with("s")))
    mm[, -1, drop = FALSE]
  }

  train <- simple_oracle_data(nsamp = ntrain, hete = hete, seed = seed)$data
  xtrain <- covariates(train)
  # t is stored as a factor, and as.numeric() on a factor returns the level
  # codes (1/2). Go through the labels to get the 0/1 treatment indicator.
  wtrain <- as.numeric(as.character(train$t))
  ytrain <- as.numeric(train$y)

  cf <- causal_forest(X = xtrain, Y = ytrain, W = wtrain, seed = seed)

  test <- simple_oracle_data(nsamp = ntest, hete = hete, seed = seed + 1)$data
  xtest <- covariates(test)
  tau_hat <- predict(cf, newdata = xtest)

  cbind(
    tibble(ytrue = test$y, tetrue = test$te_true,
           ypred = as.numeric(tau_hat$predictions)),
    select(test, starts_with("s"))
  )
}
# Diagnostic run: do the forest's effect estimates track the true effects?
ot <- oracle_test(seed = 1)
corr <- with(ot, cor(tetrue, ypred))

print("Correlation")
# The original note reports roughly 0.96; the exact value may vary with
# package versions.
print(corr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment