Skip to content

Instantly share code, notes, and snippets.

causal_survival_forest.custom <- function(
X, Y, W, D,
W.hat = NULL,
target = c("RMST", "survival.probability"),
horizon = NULL,
failure.times = NULL,
num.trees = 2000,
sample.weights = NULL,
clusters = NULL,
equalize.cluster.weights = FALSE,
# To use this file include
# source("best_linear_projection_overlap.R")
# after
# library(grf)
best_linear_projection <- function(forest,
A = NULL,
subset = NULL,
debiasing.weights = NULL,
compliance.score = NULL,
@erikcs
erikcs / nonparam.R
Created June 14, 2022 22:29
nonparam.R
rm(list = ls())
library(grf)
coeffs = replicate(250, {
# setup from
# https://tompepinsky.com/2022/06/07/illustrating-contamination-bias-with-implications-for-intersectional-description/
effect1 = 1
effect2 = 2
effect3 = 3
true = mean(pmax(0.5 * rnorm(1e5), 0))
r = replicate(100, {
n <- 2000
p <- 5
X <- matrix(rnorm(n * p), n, p)
Z <- runif(n, -4, 4)
cutoff <- 0
W <- as.numeric(Z >= cutoff)
tau <- pmax(0.5 * X[, 1], 0)
@erikcs
erikcs / mobVSgrf.R
Created December 17, 2021 15:33
mobVSgrf
# GRF with bivariate splits (done in multi_arm_causal_forest) in two specs (separate installs):
# https://github.com/erikcs/grf/commits/mobWY
# a) Using what we inferred to be MOB's "~equivalent" formulation of the GRF CART criterion
# Install with
# devtools::install_github("erikcs/grf", subdir = "r-package/grf", ref = "3ea1ddf") # "mobWY" branch @3ea1ddf Ap="MOB"
# b) Using GRF's criterion
# Install with
# devtools::install_github("erikcs/grf", subdir = "r-package/grf", ref = "c54605a") # "mobWY" branch @c54605a Ap="GRF"
set.seed(123)
# DGP with overlay, multi_arm_causal_forest splits on (mu,tau)
r = replicate(50, {
n=1600
p=10
X=matrix(runif(n*p),n,p)
e = 0.5
W = rbinom(n, 1, e)
mu = 2*X[,1] - 1
TAU = (1 + exp(-20*(X[, 1:2] - 1/3)))^-1
@erikcs
erikcs / d163bf7.txt
Created November 20, 2021 18:56
d163bf7.txt
# Cross-tabulate `val` by dgp x estimator x metric x target, restricted to rows
# with n == 2000; the printed 4-way table follows below.
# NOTE(review): DF.table.long is built elsewhere (not shown here).
xtabs(val ~ dgp + estimator + metric + target, DF.table.long, subset = n == 2000)
, , metric = CLF, target = RMST
estimator
dgp VT SRC1 SRC2 IPCW CSF
type1 0.26015800 0.23058600 0.25583400 0.21411400 0.19654800
type2 0.26030400 0.21983600 0.24268800 0.21657200 0.14402200
type3 0.19990800 0.17275800 0.20015600 0.08673400 0.08790600
type4 0.19212837 0.13838581 0.15726132 0.03206281 0.02763727
@erikcs
erikcs / Hybrid_R_splitCompare2.R
Created August 14, 2021 04:42
Hybrid_R_splitCompare2.R
### compare with other ways of building a depth 3 tree
# Simulated inputs for the depth-3 tree comparison.
#
# Note: deliberately no `rm(list = ls())` here. Clearing the global
# environment destroys the caller's objects when this file is source()d,
# and it does not yield a clean session anyway (attached packages and
# options survive). Restart R instead when a fresh session is needed.
library(policytree)
# set.seed(20)
n <- 1000 # number of rows of X and Y
p <- 2    # number of columns of X
d <- 3    # number of columns of Y (presumably one per action - confirm)
# Standard-normal covariates rounded to 2 decimals, presumably to limit the
# number of distinct candidate split values - confirm with the author.
X <- round(matrix(rnorm(n * p), nrow = n, ncol = p), 2)
Y <- matrix(rnorm(n * d), nrow = n, ncol = d)
@erikcs
erikcs / Hybrid_R_splitCompare1.R
Last active August 14, 2021 05:10
Hybrid_R_splitCompare1.R
### compare with other ways of building a depth 3 tree
# Simulated inputs for the depth-3 tree comparison.
#
# Note: deliberately no `rm(list = ls())` here. Clearing the global
# environment destroys the caller's objects when this file is source()d,
# and it does not yield a clean session anyway (attached packages and
# options survive). Restart R instead when a fresh session is needed.
library(policytree)
# set.seed(20)
n <- 1000 # number of rows of X and Y
p <- 2    # number of columns of X
d <- 3    # number of columns of Y (presumably one per action - confirm)
# Standard-normal covariates rounded to 2 decimals, presumably to limit the
# number of distinct candidate split values - confirm with the author.
X <- round(matrix(rnorm(n * p), nrow = n, ncol = p), 2)
Y <- matrix(rnorm(n * d), nrow = n, ncol = d)
# Master
oob test stdratio
Y.error "0.047 +/- 0.001" "0.047 +/- 0.001" "2.471 +/- 0.087"
tau.error "0.009 +/- 0.001" "0.009 +/- 0.001" "1.251 +/- 0.099"
csf.error "0.105 +/- 0.011" "0.104 +/- 0.011" "1.600 +/- 0.201"
mcf.error "0.047 +/- 0.002" "0.048 +/- 0.003" "2.735 +/- 0.143"
user.self sys.self elapsed
Y.time "18.138 +/- 0.057" "0.309 +/- 0.010" "18.496 +/- 0.059"
Y.time.pred "3.225 +/- 0.039" "0.107 +/- 0.005" "3.340 +/- 0.042"
Y.time.ci "5.135 +/- 0.074" "0.200 +/- 0.009" "5.348 +/- 0.074"