Skip to content

Instantly share code, notes, and snippets.

@barnybug
Created November 12, 2015 15:17
Show Gist options
  • Star (0) — you must be signed in to star a gist.
  • Fork (0) — you must be signed in to fork a gist.
  • Save barnybug/836eae1b46fca7087e80 to your computer and use it in GitHub Desktop.
def get_code(tree, feature_names, class_names):
    """Print a fitted decision tree as nested if/else pseudocode.

    Walks ``tree.tree_`` depth-first from the root (node 0). It prints an
    ``if <feature> <= <threshold>:`` / ``else:`` pair for every split node
    and a ``return <class>`` line for every leaf. Each nesting level adds
    one space of indentation.

    Args:
        tree: fitted estimator exposing a ``tree_`` attribute with the arrays
            ``children_left``, ``children_right``, ``threshold``, ``feature``
            and ``value`` (e.g. a scikit-learn ``DecisionTreeClassifier``).
        feature_names: sequence indexed by feature id; names the feature
            tested at each split.
        class_names: sequence indexed by class position; a leaf prints the
            entry at ``argmax`` of that leaf's ``value``.

    Returns:
        None. All output goes to stdout.
    """
    # The original relied on an ``np`` defined somewhere outside this function.
    import numpy as np

    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # Leaf nodes carry a negative sentinel in ``feature``. Indexing
    # feature_names with it would silently pick the wrong name, or raise
    # IndexError for short lists. Leaves never print a feature, so store None.
    features = [feature_names[i] if i >= 0 else None
                for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(node, indent):
        # A leaf has identical (-1) left and right children. Testing
        # ``threshold != -2`` instead would mistake a real split whose
        # threshold is exactly -2.0 for a leaf.
        if left[node] != right[node]:
            print("%sif %s <= %s:" % (indent, features[node], threshold[node]))
            recurse(left[node], indent + ' ')
            print('%selse:' % indent)
            recurse(right[node], indent + ' ')
        else:
            cls = class_names[np.argmax(value[node])]
            print('%sreturn %s' % (indent, cls))

    recurse(0, '')
# Example usage: ``clf`` is a fitted binary decision tree classifier and
# ``features`` lists its feature names; both must already exist in scope.
get_code(tree=clf, feature_names=features, class_names=['False', 'True'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.