Created
August 14, 2021 04:42
-
-
Save erikcs/ad75f9ef2c4196f7e61127904ee2f54d to your computer and use it in GitHub Desktop.
Hybrid_R_splitCompare2.R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Compare hybrid_policy_tree() with another way of building a depth 3 tree.
library(policytree)

# set.seed(20)
n <- 1000
p <- 2
d <- 3
# Covariates are rounded to 2 decimals; rewards are pure noise.
X <- round(matrix(rnorm(n * p), n, p), 2)
Y <- matrix(rnorm(n * d), n, d)

# 1) Suggested approach.
htree <- hybrid_policy_tree(X, Y, depth = 3, search.depth = 2)

# 2) Make a depth 2 tree to form the first split. Make a depth 2 tree in each
#    of the two immediate child nodes, then make a depth 2 tree in each of the
#    4 resulting nodes, and keep only the root split of every fitted tree,
#    i.e. "prune" the result back to depth 3.

# Fit a depth 2 policy tree on the observations `idx` (row indices into X and
# Y) and partition `idx` by the root split of that tree.
# Returns a list with integer vectors `left` (split variable <= split value)
# and `right` (the remaining observations).
first_split <- function(idx) {
  tree <- policy_tree(X[idx, , drop = FALSE], Y[idx, , drop = FALSE], depth = 2)
  root <- tree$nodes[[1]]
  # Compare against the split *value*; the split variable only selects the
  # column of X.
  goes_left <- X[idx, root$split_variable] <= root$split_value
  list(left = idx[goes_left], right = idx[!goes_left])
}

# Column index of Y with the largest total reward over the observations `idx`.
# drop = FALSE keeps Y a matrix when `idx` has a single element, which
# colSums() requires.
best_action <- function(idx) {
  which.max(colSums(Y[idx, , drop = FALSE]))
}

# Action assigned to every observation by approach 2).
pp2 <- rep(NA_integer_, n)
for (half in first_split(seq_len(n))) {
  for (quarter in first_split(half)) {
    leaves <- first_split(quarter)
    pp2[leaves$left] <- best_action(leaves$left)
    pp2[leaves$right] <- best_action(leaves$right)
  }
}

# Mean reward obtained by each policy on the training sample.
mean(Y[cbind(seq_len(n), pp2)]) # 2)
mean(Y[cbind(seq_len(n), predict(htree, X))]) # This is better

# The ratio
#   mean(Y[cbind(1:n, predict(htree, X))]) / mean(Y[cbind(1:n, pp2)])
# from 500 repetitions, as originally recorded:
#    Min. 1st Qu.  Median    Mean 3rd Qu.    Max.
#   1.098   1.530   1.787   1.951   2.151  17.655
# NOTE(review): these numbers were recorded when the root partition of 2)
# compared X[, svar] against the split variable index instead of the split
# value; re-run the 500 repetitions to confirm them.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment