Skip to content

Instantly share code, notes, and snippets.

@randyphoa
Last active August 6, 2022 07:16
Show Gist options
  • Save randyphoa/978e29965a11402b21377f7d61ad5b54 to your computer and use it in GitHub Desktop.
Save randyphoa/978e29965a11402b21377f7d61ad5b54 to your computer and use it in GitHub Desktop.
Decision Tree Rules
import numpy as np
import sklearn
import sklearn.tree
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
def get_rules(tree, feature_names, class_names):
    """Turn a fitted decision tree into a list of readable if/then rules.

    One rule is produced per leaf. Each rule lists the split conditions on the
    way from the root to that leaf, the leaf's prediction, and the number of
    samples recorded for the leaf.

    Args:
        tree: Fitted estimator exposing a scikit-learn style ``tree_``
            attribute (``children_left``, ``children_right``, ``feature``,
            ``threshold``, ``value``, ``n_node_samples``).
        feature_names: Sequence mapping a feature index to its display name.
        class_names: Sequence mapping a class index to its display name, or
            ``None`` to report the leaf value as a plain ``response`` (the
            regression case).

    Returns:
        A list of rule strings, ordered from the leaf with the most samples
        to the leaf with the fewest.
    """
    tree_ = tree.tree_
    children_left = tree_.children_left
    children_right = tree_.children_right

    # Depth-first walk with an explicit stack, so very deep trees cannot hit
    # Python's recursion limit. Each entry is (node id, conditions so far).
    leaves = []
    pending = [(0, [])]
    while pending:
        node, conditions = pending.pop()
        left, right = children_left[node], children_right[node]

        # scikit-learn gives a leaf the same sentinel in both child slots, so
        # a node is a leaf exactly when its two children are equal. This
        # avoids reaching into the private ``sklearn.tree._tree`` module.
        if left == right:
            leaves.append(
                (conditions, tree_.value[node], tree_.n_node_samples[node])
            )
            continue

        name = feature_names[tree_.feature[node]]
        threshold = np.round(tree_.threshold[node], 6)
        # Push the right branch first so the left branch is popped (visited)
        # first, keeping leaves in left-to-right order.
        pending.append((right, conditions + [f"({name} > {threshold})"]))
        pending.append((left, conditions + [f"({name} <= {threshold})"]))

    # Sort by samples count, largest first (ascending argsort, then reversed).
    order = np.argsort([n_samples for _, _, n_samples in leaves])[::-1]

    rules = []
    for i in order:
        conditions, value, n_samples = leaves[i]
        antecedent = " and ".join(conditions)

        # Only the first row of the leaf value is used.
        # NOTE(review): assumes a single-output tree -- confirm for multi-output.
        if class_names is None:
            outcome = "response: " + str(np.round(value[0][0], 3))
        else:
            class_scores = value[0]
            best = np.argmax(class_scores)
            # Normalising by the sum works whether the scores are stored as
            # counts or as fractions.
            proba = np.round(100.0 * class_scores[best] / np.sum(class_scores), 2)
            outcome = f"class: {class_names[best]} (proba: {proba}%)"

        rules.append(
            f"if {antecedent} then {outcome} | based on {n_samples:,} samples"
        )
    return rules
# NOTE(review): `test` and `lgb` are not defined in this file -- presumably a
# DataFrame and `lightgbm` set up in an earlier notebook cell; confirm.

# Features and target used by both models below.
cols_train = [
    "HIGH_LOW_DIFF",
    "HIGH_LOW_DIFF_M1",
    "HIGH_LOW_DIFF_M2",
]
y = test["TARGET"]
X = test[cols_train]

# Fit a depth-2 decision tree and extract its rules.
# The returned list is not assigned to anything.
shallow_tree = sklearn.tree.DecisionTreeClassifier(random_state=12345, max_depth=2)
get_rules(
    tree=shallow_tree.fit(X=X, y=y),
    feature_names=X.columns.tolist(),
    class_names=[0, 1],
)

# One-hot encode the inputs, then fit a binary LightGBM classifier.
pipeline_steps = [
    ("ohe", OneHotEncoder(handle_unknown="ignore", drop="first")),
    ("classifier", lgb.LGBMClassifier(objective="binary")),
]
pipeline = Pipeline(steps=pipeline_steps)
pipeline.fit(X, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment