causal_survival_forest.custom <- function(
  X, Y, W, D,
  W.hat = NULL,
  target = c("RMST", "survival.probability"),
  horizon = NULL,
  failure.times = NULL,
  num.trees = 2000,
  sample.weights = NULL,
  clusters = NULL,
  equalize.cluster.weights = FALSE,
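The `target = "RMST"` option refers to the restricted mean survival time: the area under the survival curve S(t) from 0 to `horizon`. Per grf's documentation, `causal_survival_forest` with this target estimates the conditional difference E[min(T(1), horizon) - min(T(0), horizon) | X = x]. As a minimal base-R sketch of the underlying definition only (the helper `rmst_km` is hypothetical, not part of grf, and ignores covariates and treatment), the Kaplan-Meier curve can be integrated as a step function:

```r
# Hypothetical helper (not part of grf): restricted mean survival time up to
# `horizon`, computed by integrating the Kaplan-Meier survival curve.
# obs.time: observed times; event: 1 = failure observed, 0 = censored.
rmst_km <- function(obs.time, event, horizon) {
  event.times <- sort(unique(obs.time[event == 1 & obs.time <= horizon]))
  surv <- 1
  S <- numeric(length(event.times))
  for (k in seq_along(event.times)) {
    tk <- event.times[k]
    at.risk <- sum(obs.time >= tk)                # still under observation at tk
    failures <- sum(obs.time == tk & event == 1)  # failures exactly at tk
    surv <- surv * (1 - failures / at.risk)
    S[k] <- surv
  }
  # S(t) is a step function: 1 on [0, t1), S[1] on [t1, t2), ..., S[m] up to horizon.
  grid <- c(0, event.times, horizon)
  heights <- c(1, S)
  sum(heights * diff(grid))
}

# Usage: with no censoring, the RMST equals mean(pmin(obs.time, horizon)).
set.seed(1)
obs.time <- rexp(200)
event <- rep(1, 200)
rmst <- rmst_km(obs.time, event, horizon = 1)
print(c(km = rmst, direct = mean(pmin(obs.time, 1))))
```

The no-censoring check is a convenient sanity test: in that case the Kaplan-Meier curve reduces to the empirical survival function, whose integral up to the horizon is the sample mean of the truncated times.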

# To use this file, include
# source("best_linear_projection_overlap.R")
# after
# library(grf)
best_linear_projection <- function(forest,
                                   A = NULL,
                                   subset = NULL,
                                   debiasing.weights = NULL,
                                   compliance.score = NULL,
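The usage comment relies on R's search path: a function defined in the global environment (for example by `source()`) is found before an attached package's export of the same name, while the package's own version stays reachable with `pkg::name` (here, `grf::best_linear_projection`). A small base-R illustration of that lookup, using `stats::median` as a stand-in for the grf function:

```r
# A global definition masks the attached stats::median for top-level calls.
median <- function(x, ...) "global version"

masked.result <- median(c(1, 2, 3))           # finds the global function first
original.result <- stats::median(c(1, 2, 3))  # explicit namespace lookup

print(masked.result)    # "global version"
print(original.result)  # 2

rm(median)  # remove the global definition; median() is stats::median again
restored.result <- median(c(1, 2, 3))
```

Note that code inside a package's own namespace still calls the package's version, so the override only affects calls made from your script.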

rm(list = ls())
library(grf)
coeffs = replicate(250, {
  # setup from
  # https://tompepinsky.com/2022/06/07/illustrating-contamination-bias-with-implications-for-intersectional-description/
  effect1 = 1
  effect2 = 2
  effect3 = 3

true = mean(pmax(0.5 * rnorm(1e5), 0))
r = replicate(100, {
  n <- 2000
  p <- 5
  X <- matrix(rnorm(n * p), n, p)
  Z <- runif(n, -4, 4)
  cutoff <- 0
  W <- as.numeric(Z >= cutoff)
  tau <- pmax(0.5 * X[, 1], 0)
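The Monte Carlo constant `true` has a closed form. If X1 ~ N(0, 1), then E[max(0.5 X1, 0)] = 0.5 E[X1 1{X1 > 0}] = 0.5 φ(0) = 0.5 / sqrt(2π) ≈ 0.1995, because the integral of x φ(x) from 0 to infinity equals φ(0). Since `X` and `Z` are drawn independently in this DGP, the effect at the cutoff, E[tau | Z = cutoff], equals this same unconditional mean. A quick numerical check:

```r
# For X ~ N(0, 1): E[max(0.5 * X, 0)] = 0.5 * integral_0^Inf x dnorm(x) dx = 0.5 * dnorm(0).
true.closed.form <- 0.5 * dnorm(0)  # equals 0.5 / sqrt(2 * pi)

# Monte Carlo version, as in the gist, with a larger draw for precision.
set.seed(42)
true.mc <- mean(pmax(0.5 * rnorm(1e6), 0))

print(c(closed.form = true.closed.form, monte.carlo = true.mc))
```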

# GRF with bivariate splits (done in multi_arm_causal_forest) in two specs (separate installs):
# https://github.com/erikcs/grf/commits/mobWY
# a) Using what we inferred to be MOB's "~equivalent" formulation of the GRF CART criterion
# Install with
# devtools::install_github("erikcs/grf", subdir = "r-package/grf", ref = "3ea1ddf") # "mobWY" branch @3ea1ddf Ap="MOB"
# b) Using GRF's criterion
# Install with
# devtools::install_github("erikcs/grf", subdir = "r-package/grf", ref = "c54605a") # "mobWY" branch @c54605a Ap="GRF"

set.seed(123)
# DGP with overlay, multi_arm_causal_forest splits on (mu,tau)
r = replicate(50, {
  n=1600
  p=10
  X=matrix(runif(n*p),n,p)
  e = 0.5
  W = rbinom(n, 1, e)
  mu = 2*X[,1] - 1
  TAU = (1 + exp(-20*(X[, 1:2] - 1/3)))^-1
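`TAU` applies the logistic function elementwise to the n x 2 matrix `X[, 1:2]`, so it has one column per covariate (presumably one effect column per treatment arm; the snippet is truncated before `TAU` is used). Each column is a smooth step that equals exactly 0.5 at x = 1/3. A sketch showing the same expression written with base R's logistic CDF `plogis`:

```r
set.seed(123)
n <- 1600
p <- 10
X <- matrix(runif(n * p), n, p)

# Expression from the gist: elementwise logistic step centred at 1/3.
TAU <- (1 + exp(-20 * (X[, 1:2] - 1/3)))^-1

# Equivalent form using the logistic CDF, plogis(q) = 1 / (1 + exp(-q)).
TAU.plogis <- plogis(20 * (X[, 1:2] - 1/3))

print(dim(TAU))                       # 1600 rows, 2 columns
print(max(abs(TAU - TAU.plogis)))     # floating-point rounding differences only
print(plogis(20 * (c(0, 1/3, 1) - 1/3)))  # close to 0, exactly 0.5, close to 1
```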

xtabs(val ~ dgp + estimator + metric + target, DF.table.long, subset = n == 2000)
, , metric = CLF, target = RMST

       estimator
dgp             VT       SRC1       SRC2       IPCW        CSF
  type1 0.26015800 0.23058600 0.25583400 0.21411400 0.19654800
  type2 0.26030400 0.21983600 0.24268800 0.21657200 0.14402200
  type3 0.19990800 0.17275800 0.20015600 0.08673400 0.08790600
  type4 0.19212837 0.13838581 0.15726132 0.03206281 0.02763727
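The output above is the console print of a four-way `xtabs` array (dgp x estimator x metric x target) restricted to rows with `n == 2000`; each printed slice fixes one `metric` and `target` pair. `xtabs` sums `val` within each cell, so with one row per cell in the long data frame it simply places each value in its slot. A toy sketch of that layout, using a hypothetical data frame `DF.toy` whose values are placeholders, not results:

```r
# Hypothetical long-format data: one row per (dgp, estimator, metric, target, n).
DF.toy <- expand.grid(
  dgp = c("type1", "type2"),
  estimator = c("IPCW", "CSF"),
  metric = "CLF",
  target = "RMST",
  n = c(500, 2000),
  stringsAsFactors = FALSE
)
DF.toy$val <- seq_len(nrow(DF.toy)) / 10  # placeholder values, not results

# Sum `val` over each dgp x estimator x metric x target cell, keeping only n == 2000.
tab <- xtabs(val ~ dgp + estimator + metric + target, DF.toy, subset = n == 2000)
print(tab)          # one dgp x estimator slice per (metric, target) pair
print(ftable(tab))  # the same array flattened into a single 2-D table
```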

### compare with other ways of building a depth 3 tree
library(policytree)
rm(list = ls())
# set.seed(20)
n = 1000
p = 2
d = 3
X = round(matrix(rnorm(n*p), n, p), 2)
Y = matrix(rnorm(n*d), n, d)
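The simplest alternative way of building a policy tree is a brute-force depth-1 search: try every variable and every observed split value, give each leaf the action with the largest total reward, and keep the best split. A minimal base-R sketch, assuming `Y` plays the role of the n x d reward matrix; the helper `best_depth1_tree` is hypothetical and not part of policytree. Rounding `X` to two decimals, as the setup above does, also limits the number of candidate split values:

```r
# Hypothetical helper (not part of policytree): exhaustive depth-1 policy tree.
# X: n x p covariates; Gamma: n x d matrix of rewards, one column per action.
# Observations with X[, j] <= split.value go to the left leaf.
best_depth1_tree <- function(X, Gamma) {
  best <- list(reward = -Inf)
  for (j in seq_len(ncol(X))) {
    for (s in sort(unique(X[, j]))) {
      left <- X[, j] <= s
      left.rewards <- colSums(Gamma[left, , drop = FALSE])
      right.rewards <- colSums(Gamma[!left, , drop = FALSE])
      # Each leaf takes its best action; an empty right leaf contributes 0.
      reward <- max(left.rewards) + max(right.rewards)
      if (reward > best$reward) {
        best <- list(reward = reward, variable = j, split.value = s,
                     left.action = which.max(left.rewards),
                     right.action = which.max(right.rewards))
      }
    }
  }
  best
}

# Usage with the same simulated data as above.
set.seed(20)
n <- 1000
p <- 2
d <- 3
X <- round(matrix(rnorm(n * p), n, p), 2)
Y <- matrix(rnorm(n * d), n, d)
tree <- best_depth1_tree(X, Y)
str(tree)
```

Because the largest split value sends every observation left, the search always includes the "no split" policy, so the returned reward is never below that of the best single action.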

# Master
          oob               test              stdratio
Y.error   "0.047 +/- 0.001" "0.047 +/- 0.001" "2.471 +/- 0.087"
tau.error "0.009 +/- 0.001" "0.009 +/- 0.001" "1.251 +/- 0.099"
csf.error "0.105 +/- 0.011" "0.104 +/- 0.011" "1.600 +/- 0.201"
mcf.error "0.047 +/- 0.002" "0.048 +/- 0.003" "2.735 +/- 0.143"
            user.self          sys.self          elapsed
Y.time      "18.138 +/- 0.057" "0.309 +/- 0.010" "18.496 +/- 0.059"
Y.time.pred "3.225 +/- 0.039"  "0.107 +/- 0.005" "3.340 +/- 0.042"
Y.time.ci   "5.135 +/- 0.074"  "0.200 +/- 0.009" "5.348 +/- 0.074"
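The benchmark cells have the form "mean +/- spread" over repeated runs, and the timing columns `user.self`, `sys.self` and `elapsed` are the names of components returned by base R's `system.time()`. The output does not say what the +/- term is; the sketch below assumes it is the standard error over replications, and both the helper `summarize_runs` and the timed workload are hypothetical:

```r
# Hypothetical helper: format replications as "mean +/- standard error".
summarize_runs <- function(x, digits = 3) {
  se <- sd(x) / sqrt(length(x))
  fmt <- paste0("%.", digits, "f")
  sprintf(paste(fmt, "+/-", fmt), mean(x), se)
}

# Time a placeholder workload five times, keeping three system.time() components.
times <- t(replicate(5, {
  tm <- system.time(sort(runif(1e5)))
  tm[c("user.self", "sys.self", "elapsed")]
}))

timing.summary <- apply(times, 2, summarize_runs)
print(timing.summary)
```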