Skip to content

Instantly share code, notes, and snippets.

@erikcs
Last active August 14, 2021 05:10
Show Gist options
  • Save erikcs/e5cb80f657d960df4d3279efa2291d0d to your computer and use it in GitHub Desktop.
Save erikcs/e5cb80f657d960df4d3279efa2291d0d to your computer and use it in GitHub Desktop.
Hybrid_R_splitCompare1.R
### Compare the hybrid search with other ways of building a depth-3 tree.
# Simulated setup: n samples, p covariates, d candidate actions. The reward
# matrix Y is drawn independently of X, so there is no true signal to learn;
# differences between the strategies below reflect how well each search fits
# the sample it is evaluated on.
library(policytree)
# NOTE: no `rm(list = ls())` -- wiping the caller's workspace is a side effect
# a script should not have, and every object used below is (re)defined here.
# set.seed(20)  # uncomment for a reproducible run
n <- 1000
p <- 2
d <- 3
# Covariates are rounded to 2 decimals, presumably to limit the number of
# distinct candidate split points the exact search has to scan.
X <- round(matrix(rnorm(n * p), n, p), 2)
Y <- matrix(rnorm(n * d), n, d)
# Suggested approach: a depth-3 hybrid tree grown with depth-2 exact look-ahead.
htree <- hybrid_policy_tree(X, Y, depth = 3, search.depth = 2)
# 1): Make a depth-2 tree, then fit a new depth-2 tree in each of its 4 leaf
# nodes and "prune" that back up to level 3. ("Prune": only the first split of
# each leaf-level tree is kept, which takes us to level 3; each of the two
# resulting children gets the action with the largest summed reward in it.)
root <- policy_tree(X, Y, depth = 2)
leaf <- predict(root, X, type = "node.id")
sample.by.leaf <- split(seq_len(n), leaf)

# Action (column of Y) with the largest total reward over the given rows.
best.action <- function(rows) {
  which.max(colSums(Y[rows, , drop = FALSE]))
}

pp1 <- rep(NA_integer_, n)
for (s in sample.by.leaf) {
  pt <- policy_tree(X[s, , drop = FALSE], Y[s, , drop = FALSE], depth = 2)
  first <- pt$nodes[[1]]
  svar <- first[["split_variable"]]
  sval <- first[["split_value"]]
  if (is.null(svar) || is.null(sval)) {
    # The leaf-level tree did not split: give the whole leaf its single best
    # action rather than indexing X with a missing split variable (which would
    # silently send every sample to the "right" branch with action 1).
    pp1[s] <- best.action(s)
    next
  }
  # Same convention as the original code: X <= split value goes left.
  goes.left <- X[s, svar] <= sval
  left <- s[goes.left]
  right <- s[!goes.left]
  # Vectorized assignment replaces the per-sample `%in%` scan, which was
  # quadratic in the leaf size.
  pp1[left] <- best.action(left)
  pp1[right] <- best.action(right)
}
# Mean reward of a policy on the same sample the trees were fit on: entry i of
# `actions` selects column actions[i] of row i of Y.
policy.value <- function(actions) {
  mean(Y[cbind(seq_len(n), actions)])
}
# 1) depth-2 tree + first split of a depth-2 tree inside each of its leaves.
value.1 <- policy.value(pp1)
# The plain depth-2 root tree; 1) may yield only a very small improvement on it.
value.root <- policy.value(predict(root, X))
# The hybrid tree; it was never worse than 1) in the simulations noted below.
value.hybrid <- policy.value(predict(htree, X))
# print() so the results are shown under source() as well as interactively.
print(c(approach.1 = value.1, root = value.root, hybrid = value.hybrid))
print(value.hybrid / value.1)
# Repeating the above 500 times, the ratio value.hybrid / value.1 was:
#  Min. 1st Qu. Median  Mean 3rd Qu.  Max.
# 1.114   1.279  1.348 1.372   1.446 2.324
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment