Skip to content

Instantly share code, notes, and snippets.

@christian-rauch
Last active March 8, 2017 12:38
Show Gist options
  • Save christian-rauch/7ce87819347acfd9168ef08822a04559 to your computer and use it in GitHub Desktop.
Save christian-rauch/7ce87819347acfd9168ef08822a04559 to your computer and use it in GitHub Desktop.
Prune (cut) scikit's RandomForestClassifier at a given depth
# This inserts leaf nodes at a given depth of each tree, i.e. every decision path ends at that depth at the latest.
# It does not actually remove any nodes from the lists in 'children_left' and 'children_right',
# i.e. the nodes below the cut stay in the arrays but are no longer reached by any decision path.
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree.tree import Tree
def prune_node(tree, node_id, parent_depth, prune_depth):
    """Turn every split node deeper than ``prune_depth`` into a leaf, in place.

    Walks the subtree rooted at ``node_id`` in pre-order (left before right).
    The depth of ``node_id`` itself is ``parent_depth + 1``, so a root passed
    with ``parent_depth=0`` sits at depth 1. A split node whose depth exceeds
    ``prune_depth`` is rewritten as a leaf: both child links are set to -1 and
    its feature and threshold to -2. Nodes below such a cut are not visited
    and keep their entries in the arrays; existing leaves are left alone.

    The walk uses an explicit stack instead of recursion, so a very deep tree
    combined with a large ``prune_depth`` cannot exhaust the interpreter's
    recursion limit.

    Args:
        tree: Object exposing index-assignable sequences ``children_left``,
            ``children_right``, ``feature`` and ``threshold``, all indexed by
            node id (presumably the ``tree_`` attribute of a fitted
            scikit-learn tree estimator -- confirm against the caller).
        node_id: Index of the node at which the walk starts.
        parent_depth: Depth of the parent of ``node_id`` (0 for the root).
        prune_depth: Deepest level at which split nodes are kept.

    Returns:
        None. ``tree`` is modified in place.
    """
    leaf_child = -1   # child link value that marks "no child"
    undefined = -2    # feature/threshold value stored on leaf nodes

    # Each entry is (node index, depth of that node).
    pending = [(node_id, parent_depth + 1)]
    while pending:
        current, depth = pending.pop()

        is_leaf = (
            tree.children_left[current] == leaf_child
            and tree.children_right[current] == leaf_child
            and tree.feature[current] == undefined
            and tree.threshold[current] == undefined
        )
        if is_leaf:
            continue

        if depth > prune_depth:
            # Cut here: make this split node look like a leaf. Its former
            # descendants stay in the arrays but become unreachable.
            tree.children_left[current] = leaf_child
            tree.children_right[current] = leaf_child
            tree.feature[current] = undefined
            tree.threshold[current] = undefined
            continue

        # Push right first so that the left subtree is popped (and therefore
        # processed) before the right one.
        pending.append((tree.children_right[current], depth + 1))
        pending.append((tree.children_left[current], depth + 1))
# prune trees, i.e. insert leaf nodes at depth prune_depth+1
# the 'value' (class histogram) stored at each node stays the same
def prune_tree(tree, prune_depth):
    """Prune a single tree in place, starting from its root (node index 0).

    The depth above the root is passed as 0, so ``prune_node`` counts the
    root as depth 1 and converts every split node deeper than
    ``prune_depth`` into a leaf.

    Args:
        tree: Tree structure handed through to ``prune_node`` unchanged.
        prune_depth: Deepest level at which split nodes are kept.
    """
    root_id = 0
    depth_above_root = 0
    prune_node(tree, root_id, depth_above_root, prune_depth)
# cut each tree of forest after 'prune_depth'
# this will directly change the provided forest
def prune_forest(forest, prune_depth):
    """Prune every tree of ``forest`` at ``prune_depth``, modifying it in place.

    Args:
        forest: Object whose ``estimators_`` attribute is an iterable of
            estimators, each exposing its tree structure as ``tree_``.
        prune_depth: Deepest level at which split nodes are kept in each tree.
    """
    for member in forest.estimators_:
        member_tree = member.tree_
        prune_tree(member_tree, prune_depth)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment