Last active: August 14, 2021 05:10
Save erikcs/e5cb80f657d960df4d3279efa2291d0d to your computer and use it in GitHub Desktop.
Hybrid_R_splitCompare1.R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Compare with other ways of building a depth-3 policy tree.
library(policytree)
# Note: no `rm(list = ls())` here -- wiping the caller's workspace is a side
# effect a script should not have; run in a fresh session instead.
# set.seed(20)  # uncomment for a reproducible run

n <- 1000  # number of samples
p <- 2     # number of covariates (columns of X)
d <- 3     # number of actions (columns of the reward matrix Y)

# Covariates rounded to 2 decimals -- presumably to limit the number of
# distinct split points the exhaustive tree search must consider.
X <- round(matrix(rnorm(n * p), n, p), 2)
# Rewards are drawn independently of X, so any gain a policy shows below is
# purely in-sample fit.
Y <- matrix(rnorm(n * d), n, d)

# Suggested approach: policytree's hybrid search (depth 3, search.depth 2).
htree <- hybrid_policy_tree(X, Y, depth = 3, search.depth = 2)
# Approach 1): fit a depth-2 tree, then fit a new depth-2 tree inside each of
# its 4 leaf nodes and "prune" that inner tree back to its first split, which
# takes us to level 3. Each of the two children of that split is assigned the
# single action with the largest total reward among its samples.
root <- policy_tree(X, Y, depth = 2)
leaf <- predict(root, X, type = "node.id")
sample.by.leaf <- split(seq_len(n), leaf)
pp1 <- rep(NA_integer_, n)  # action chosen by approach 1) for each sample
for (s in sample.by.leaf) {
  pt <- policy_tree(X[s, , drop = FALSE], Y[s, , drop = FALSE], depth = 2)
  # Only the first split of the inner tree is kept.
  svar <- pt$nodes[[1]]$split_variable
  sval <- pt$nodes[[1]]$split_value
  goes.left <- X[s, svar] <= sval
  left <- s[goes.left]
  right <- s[!goes.left]
  # Assign each child's best action to all its samples at once; this replaces
  # a per-sample `%in%` membership loop that was quadratic in the leaf size.
  # If a child is empty, the assignment is a no-op.
  pp1[left] <- which.max(colSums(Y[left, , drop = FALSE]))
  pp1[right] <- which.max(colSums(Y[right, , drop = FALSE]))
}
# In-sample mean reward of each policy: Y[cbind(i, a_i)] picks, for every
# row i, the reward of the action a_i that the policy assigns to it.
idx <- seq_len(n)
mean(Y[cbind(idx, pp1)])                # approach 1)
# Depth-2 root tree; approach 1) may improve on it only slightly.
mean(Y[cbind(idx, predict(root, X))])
# Hybrid tree; in the replications summarized below it was never worse than 1).
mean(Y[cbind(idx, predict(htree, X))])
# Repeating the above 500 times, the ratio
#   mean(Y[cbind(idx, predict(htree, X))]) / mean(Y[cbind(idx, pp1)])
# summarized over the 500 runs:
#  Min. 1st Qu. Median  Mean 3rd Qu.  Max.
# 1.114   1.279  1.348 1.372   1.446 2.324
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.