Skip to content

Instantly share code, notes, and snippets.

@erikcs
Created August 14, 2021 04:42
Show Gist options
  • Save erikcs/ad75f9ef2c4196f7e61127904ee2f54d to your computer and use it in GitHub Desktop.
Save erikcs/ad75f9ef2c4196f7e61127904ee2f54d to your computer and use it in GitHub Desktop.
Hybrid_R_splitCompare2.R
### compare with other ways of building a depth 3 tree
library(policytree)
# Simulated data: n observations, p covariates (rounded to 2 decimals, so
# covariate values repeat) and a reward matrix with one column per action
# (d actions). Y is pure rnorm() noise, so every policy has the same true
# value (0); the comparison below is therefore of in-sample fit.
# Rows are not seeded by default (each run differs); uncomment for a
# reproducible run.
# set.seed(20)
n <- 1000
p <- 2
d <- 3
X <- round(matrix(rnorm(n * p), n, p), 2)
Y <- matrix(rnorm(n * d), n, d)

# Suggested approach: depth-3 tree from hybrid tree search with depth-2
# lookahead (see ?hybrid_policy_tree for the search.depth semantics).
htree <- hybrid_policy_tree(X, Y, depth = 3, search.depth = 2)
# 2) Alternative depth-3 tree assembled from depth-2 lookahead trees:
#    a) fit a depth-2 tree on all rows and keep only its root split;
#    b) in each of the 2 resulting nodes, fit a depth-2 tree and keep only its
#       root split, giving 4 nodes at depth 2;
#    c) in each of those 4 nodes, fit a depth-2 tree, keep only its root split
#       and "prune" to depth 3 by giving each child the action with the
#       largest summed reward among that child's rows.
# Result: `pp2`, the action index chosen for every row of X.

# Partition row indices `idx` (rows of the full X) by the root split of
# `tree`: rows with X[, split_variable] <= split_value go left, the rest right
# (same rule used to form the child subsets everywhere in this script).
# Defensive: if the fitted root is a leaf, all rows stay in one group.
split_by_root <- function(tree, X, idx) {
  node <- tree$nodes[[1]]
  if (isTRUE(node$is_leaf)) {
    return(list(left = idx, right = integer(0)))
  }
  goes_left <- X[idx, node$split_variable] <= node$split_value
  list(left = idx[goes_left], right = idx[!goes_left])
}

# Fit a depth-2 policy tree on the rows `idx` only (drop = FALSE keeps
# single-row subsets as matrices).
fit_depth2 <- function(X, Y, idx) {
  policy_tree(X[idx, , drop = FALSE], Y[idx, , drop = FALSE], depth = 2)
}

all_rows <- seq_len(nrow(X))
pp2 <- rep(NA_integer_, nrow(X))
root <- fit_depth2(X, Y, all_rows)
# The threshold is the root's split_value (not the split-variable index).
for (S in split_by_root(root, X, all_rows)) {
  # A split can leave a child empty; policy_tree() cannot be fit on 0 rows.
  if (length(S) == 0) next
  subroot <- fit_depth2(X, Y, S)
  for (s in split_by_root(subroot, X, S)) {
    if (length(s) == 0) next
    pt <- fit_depth2(X, Y, s)
    for (leaf in split_by_root(pt, X, s)) {
      # drop = FALSE so a 1-row leaf is still a matrix for colSums();
      # an empty leaf assigns nothing.
      pp2[leaf] <- which.max(colSums(Y[leaf, , drop = FALSE]))
    }
  }
}
stopifnot(!anyNA(pp2))  # every row must have been assigned an action
# In-sample average reward of each policy: Y[cbind(i, a_i)] picks, for every
# row i, the reward of the action a_i that the policy assigns to that row.
rows <- seq_len(n)
value_pp2 <- mean(Y[cbind(rows, pp2)])                  # approach 2)
value_htree <- mean(Y[cbind(rows, predict(htree, X))])  # hybrid tree (better)
print(c(
  approach_2 = value_pp2,
  hybrid = value_htree,
  ratio = value_htree / value_pp2
))
# Recorded ratio value_htree / value_pp2 over 500 repetitions:
# Min. 1st Qu. Median Mean 3rd Qu. Max.
# 1.098 1.530 1.787 1.951 2.151 17.655
# NOTE(review): Y is pure rnorm() noise, so both values measure in-sample fit
# to noise and are small; their ratio is unstable (see Max). Also confirm the
# figures were produced with approach 2's first split thresholded on
# `split_value` rather than on the split-variable index; rerun if not.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment