@schaunwheeler
Created April 13, 2019 12:18
Create a pure-Python function that computes predictions from a fitted scikit-learn RandomForestRegressor
from sklearn.tree import _tree

# Source template for one decision tree. Each placeholder is filled with the
# literal repr of one of the fitted tree's flattened node arrays.
tree_template = '''
def tree{i}(inputs):
    tree_undefined = {tree_undefined}
    features = {features}
    thresholds = {thresholds}
    children_left = {children_left}
    children_right = {children_right}
    values = {values}
    node = 0
    # Walk down from the root until a leaf (feature == tree_undefined) is reached.
    while features[node] != tree_undefined:
        feat = features[node]
        threshold = thresholds[node]
        if inputs[feat] <= threshold:
            node = children_left[node]
        else:
            node = children_right[node]
    output = values[node]
    return output
'''

# Source template for the ensemble: the forest prediction is the mean of the tree predictions.
template_footer = '''
def forest(inputs):
    return ({combined_trees}) / {n}
'''
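
# `rfr` is assumed to be an already-fitted, single-output RandomForestRegressor.
template_final = ''
n = len(rfr.estimators_)
for i, model in enumerate(rfr.estimators_):
    template_final += tree_template.format(
        i=i,
        tree_undefined=_tree.TREE_UNDEFINED,
        features=repr(model.tree_.feature.tolist()),
        thresholds=repr(model.tree_.threshold.tolist()),
        children_left=repr(model.tree_.children_left.tolist()),
        children_right=repr(model.tree_.children_right.tolist()),
        # tree_.value has shape (n_nodes, n_outputs, 1) for regressors; keep the
        # first output. float() keeps the repr a plain literal, since NumPy 2
        # scalars repr as `np.float64(...)`, which the generated code cannot evaluate.
        values=repr([float(val[0][0]) for val in model.tree_.value]),
    )
template_final += template_footer.format(
    combined_trees=' + '.join(['tree{i}(inputs)'.format(i=i) for i in range(n)]),
    n=float(n)
)
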
# execute the constructed code to load the function `forest` in the environment
exec(template_final)
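
To make the flattened-array layout the generated functions rely on easier to follow, here is a minimal sketch of the same traversal on two hand-built stumps, with no scikit-learn required. The names `predict_tree`, `predict_forest`, `stump_a`, and `stump_b` are hypothetical and not part of the gist. The arrays are illustrative values chosen to mimic how scikit-learn marks leaves: `-2` in the feature array (the value of `_tree.TREE_UNDEFINED`) and `-1` for the children.

```python
TREE_UNDEFINED = -2  # value of sklearn.tree._tree.TREE_UNDEFINED; marks a leaf node


def predict_tree(tree, inputs):
    """Follow split decisions from the root (node 0) down to a leaf."""
    node = 0
    while tree["features"][node] != TREE_UNDEFINED:
        feat = tree["features"][node]
        if inputs[feat] <= tree["thresholds"][node]:
            node = tree["children_left"][node]
        else:
            node = tree["children_right"][node]
    return tree["values"][node]


def predict_forest(trees, inputs):
    """Average the per-tree predictions, as the generated `forest` does."""
    return sum(predict_tree(t, inputs) for t in trees) / float(len(trees))


# Hypothetical stump: split on feature 0 at 0.5; node 1 is the left leaf, node 2 the right.
stump_a = {
    "features": [0, TREE_UNDEFINED, TREE_UNDEFINED],
    "thresholds": [0.5, -2.0, -2.0],
    "children_left": [1, -1, -1],
    "children_right": [2, -1, -1],
    "values": [15.0, 10.0, 20.0],
}

# Hypothetical stump: split on feature 1 at 3.0.
stump_b = {
    "features": [1, TREE_UNDEFINED, TREE_UNDEFINED],
    "thresholds": [3.0, -2.0, -2.0],
    "children_left": [1, -1, -1],
    "children_right": [2, -1, -1],
    "values": [15.0, 0.0, 30.0],
}

print(predict_tree(stump_a, [0.2, 4.0]))               # left leaf of stump_a: 10.0
print(predict_forest([stump_a, stump_b], [0.2, 4.0]))  # (10.0 + 30.0) / 2 = 20.0
```

The gist's generated code does the same thing, except that each tree's arrays are baked into its own `tree{i}` function as literals, so the resulting `forest` needs neither scikit-learn nor NumPy at prediction time.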