Created February 16, 2021 09:13
-
-
Save j-adamczyk/6975bbb05fb0c986033b1f31742b3aff to your computer and use it in GitHub Desktop.
Plotting decision trees with Graphviz in RAM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
import cv2 | |
import graphviz | |
import numpy as np | |
from sklearn.tree import DecisionTreeClassifier, export_graphviz | |
def plot_in_memory(clf: DecisionTreeClassifier,
                   feature_names: List[str],
                   class_names: List[str]) -> np.ndarray:
    """Render a decision tree to an RGB image entirely in memory.

    The tree is exported to DOT source with ``export_graphviz``, rendered to
    PNG bytes by Graphviz via ``Source.pipe`` (no temporary files), and then
    decoded with OpenCV.

    Args:
        clf: A fitted decision tree classifier to plot.
        feature_names: Feature names, passed to ``export_graphviz`` for the
            split labels.
        class_names: Class names, passed to ``export_graphviz`` for the node
            labels.

    Returns:
        The rendered plot as an 8-bit, 3-channel array in RGB channel order.

    Raises:
        ValueError: If OpenCV cannot decode the PNG bytes produced by Graphviz.
    """
    dot_tree = export_graphviz(decision_tree=clf, out_file=None,
                               feature_names=feature_names,
                               class_names=class_names,
                               label="all", filled=True, impurity=False,
                               proportion=True, rounded=True, precision=2)

    # get tree plot in memory as bytes of a PNG image
    image_bytes = graphviz.Source(dot_tree).pipe(format="png")

    # IMREAD_COLOR always yields an 8-bit, 3-channel BGR image, so the BGR->RGB
    # conversion below is valid regardless of whether the PNG carries an alpha
    # channel or 16-bit depth (-1 / IMREAD_UNCHANGED would keep those as-is).
    image = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8),
                         cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals failure by returning None instead of raising
        raise ValueError("could not decode the PNG image rendered by Graphviz")

    # OpenCV uses BGR channel order; convert to RGB for correct display
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.