Skip to content

Instantly share code, notes, and snippets.

@zafercavdar
Last active August 17, 2021 11:20
Show Gist options
  • Save zafercavdar/c8be4ce3758fdca4fb60a1672fb4f451 to your computer and use it in GitHub Desktop.
Save zafercavdar/c8be4ce3758fdca4fb60a1672fb4f451 to your computer and use it in GitHub Desktop.
from sklearn.tree import export_graphviz
# Dump the fitted tree to a Graphviz DOT file so it can be post-processed as
# plain text. `best_model`, `feature_pd` and `le` are not defined in this
# snippet; they must already exist in the session.
export_graphviz(
    best_model,
    out_file="tree.dot",
    feature_names=feature_pd.columns,
    class_names=le.classes_,
    rounded=True,
    proportion=False,
    precision=2,
    filled=True,
)

# Read the DOT source back in. The encoding is pinned to UTF-8 instead of the
# platform default so non-ASCII feature/class names are decoded consistently
# on every OS (scikit-learn is expected to write the file as UTF-8).
with open("tree.dot", "r", encoding="utf-8") as f:
    data = f.read()
import re
from copy import deepcopy
# Captures the class name from a node label: the text between the literal
# two-character sequence "\n" + "class = " and the closing double quote.
pattern = re.compile(r"\\nclass = ([\s\w]+)\"")
# Matches the opening quote of a node label (group 1), its first label line,
# i.e. the split condition (group 2), and the impurity keyword that starts the
# next label line (group 3). Substituting "\1\3" therefore drops the split
# condition. The keyword accepts every classifier criterion name rather than
# only "entropy", so trees trained with gini or log_loss are handled as well.
pattern2 = re.compile(r"(\")([\w\d\s\-<=|.]+\\n)(entropy|gini|log_loss)")
class Node:
    """A node of the parsed tree: its id, class label, DOT text and children."""

    def __init__(self, _id: str, _class: str, def_str):
        self._id = _id
        self._class = _class
        self.def_str = def_str
        self.children = []

    def add_child(self, subnode):
        """Append ``subnode`` to this node's list of children."""
        self.children.append(subnode)

    @property
    def is_leaf(self):
        """True when the node currently has no children."""
        return not self.children

    def prune_if_needed(self):
        """Turn this node into a leaf when its children are redundant.

        The node is pruned only if every child is a leaf and all children
        share a single class label. Pruning detaches the children and rewrites
        ``def_str`` with ``pattern2`` so the first label line (the split
        condition) is removed.

        Returns:
            Deep copies of the detached children, or an empty list when
            nothing was pruned.
        """
        labels = {kid._class for kid in self.children}
        # No children gives zero labels, mixed children give several; either
        # way there is nothing to collapse.
        if len(labels) != 1:
            return []
        if not all(kid.is_leaf for kid in self.children):
            return []

        print(f"Pruning {self._id}")
        detached = deepcopy(self.children)
        self.children = []
        self.def_str = pattern2.sub("\\1\\3", self.def_str)
        return detached
def _edge_endpoints(line):
    """Return the ``(source_id, target_id)`` pair of a DOT edge line."""
    source, _, remainder = line.partition(" -> ")
    # The target id is the first space-separated token after the arrow; it is
    # followed either by an attribute list or by the terminating " ;".
    return source, remainder.split(" ")[0]


def prune(graph_data):
    """Remove redundant leaves from a decision tree given as DOT text.

    The text is parsed line by line into ``Node`` objects, then
    ``Node.prune_if_needed`` is applied repeatedly until no node changes any
    more, so chains of collapsible nodes are pruned whatever their depth.

    Args:
        graph_data: DOT source with one statement per line, where node lines
            carry a ``class = ...`` entry in their label.

    Returns:
        The DOT text without the pruned nodes and without the edges pointing
        to them. Nodes that were collapsed into leaves are emitted with their
        rewritten definition. Lines that are not recognised as class-labelled
        nodes (headers, blank lines, other statements) pass through unchanged.
    """
    lines = graph_data.split("\n")
    header_prefixes = ("digraph", "node", "edge", "}")

    # Pass 1: create the nodes and link each parent to its children.
    nodes = {}
    for line in lines:
        if " -> " in line:
            source, target = _edge_endpoints(line)
            # Only link endpoints that were parsed as class-labelled nodes.
            if source in nodes and target in nodes:
                nodes[source].add_child(nodes[target])
            continue
        if line.startswith(header_prefixes):
            continue
        parts = line.split()
        if not parts:
            continue
        match = pattern.search(" ".join(parts[1:]))
        if match is None:
            # Not a class-labelled node; leave it alone.
            continue
        nodes[parts[0]] = Node(parts[0], match.group(1), line)

    # Pass 2: prune until a full sweep removes nothing. A single sweep only
    # collapses nodes whose children are already leaves, so several sweeps may
    # be needed; every sweep that continues detaches at least one node, which
    # guarantees termination.
    deleted_ids = set()
    while True:
        removed = [
            gone
            for node in nodes.values()
            for gone in node.prune_if_needed()
        ]
        if not removed:
            break
        deleted_ids.update(gone._id for gone in removed)

    # Pass 3: re-emit the graph without deleted nodes and edges into them.
    kept = []
    for line in lines:
        if " -> " in line:
            _, target = _edge_endpoints(line)
            if target not in deleted_ids:
                kept.append(line)
            continue
        parts = line.split()
        node = nodes.get(parts[0]) if parts else None
        if node is None:
            kept.append(line)
        elif node._id not in deleted_ids:
            # def_str already is the complete statement, including the id.
            kept.append(node.def_str)
    return "\n".join(kept)
# Optional cosmetic relabelling: reword the "<= 0.5" split text and swap the
# True/False edge labels (presumably because the new wording inverts the
# meaning of the branch; confirm). The swap is done in one regex pass so no
# placeholder string is needed.
updated_data = re.sub(
    r"True|False",
    lambda m: "False" if m.group(0) == "True" else "True",
    data.replace("<= 0.5", "is AOM-associated"),
)
# Prune before opening the output file so a failure in prune() does not leave
# a truncated, empty tree2.dot behind.
pruned_data = prune(updated_data)
with open("tree2.dot", "w", encoding="utf-8") as f:
    f.write(pruned_data)
# Notebook-only cell: a leading "!" hands the line to the shell (IPython
# syntax, not plain Python). Render the original tree.dot to tree.png.
!dot -Tpng tree.dot -o tree.png -Gdpi=300
from IPython.display import Image
# Display the rendered image inline in the notebook.
Image(filename = 'tree.png')
# Render the pruned tree2.dot.
# NOTE(review): the output file is tree.png again, so this overwrites the image
# rendered above; confirm whether tree2.png was intended.
!dot -Tpng tree2.dot -o tree.png -Gdpi=300
from IPython.display import Image
Image(filename = 'tree.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment