@halflearned
Created September 5, 2018 04:25
Indexing bug
library(grf)
# Fit a small regression forest on random data
n <- 1000
k <- 3
X <- matrix(runif(n*k), nrow=n, ncol=k)
Y <- matrix(runif(n), nrow=n, ncol=1)
forest <- grf::regression_forest(X, Y)
# Extract the first tree of the forest
tree <- get_tree(forest, 1)
leaf_nodes <- Filter(f = function(x) x$is_leaf, tree$nodes)
# drawn_samples should contain every in-bag observation (both the split and the estimation samples)
estimation_and_split_sample <- tree$drawn_samples
# The samples stored in the leaves form the estimation sample, so they should be a subset of drawn_samples
estimation_sample <- unlist(Map(f=function(x) x$samples, leaf_nodes))
# This shouldn't contain anything...
should_be_empty <- setdiff(estimation_sample, estimation_and_split_sample)
# ...but it does!
cat(c("Length (bad):", length(should_be_empty), "\n"))
# The problem is that drawn_samples are 0-indexed (C++ convention), not 1-indexed like R vectors
now_this_is_empty <- setdiff(estimation_sample, estimation_and_split_sample+1)
# Now we're good!
cat(c("Length (good):", length(now_this_is_empty), "\n"))
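A minimal base-R sketch of the same off-by-one, with no grf dependency. The index vectors are hypothetical toy values: the in-bag set is stored 0-indexed, the way drawn_samples was in the snippet above, while the leaf rows are stored 1-indexed as R expects.

```r
# Hypothetical toy data, not produced by grf.
# Rows drawn for a tree, in R's 1-based indexing.
in_bag <- c(2L, 5L, 7L, 9L)
# The same rows as a 0-based (C++-style) index vector would report them.
in_bag_zero_based <- in_bag - 1L
# Rows that ended up in leaves: a subset of the in-bag rows, 1-based.
leaf_rows <- c(5L, 9L)

# Comparing across mixed conventions flags rows that really are in-bag.
bad <- setdiff(leaf_rows, in_bag_zero_based)
# Shifting the 0-based vector by one restores the subset relation.
good <- setdiff(leaf_rows, in_bag_zero_based + 1L)

print(bad)   # [1] 5 9
print(good)  # integer(0)
```

This is the same `+ 1` shift the snippet applies to `estimation_and_split_sample` before calling `setdiff`.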