@j-adamczyk, created February 16, 2021 09:13
Plotting decision trees with Graphviz in RAM
from typing import List

import cv2
import graphviz
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz


def plot_in_memory(clf: DecisionTreeClassifier,
                   feature_names: List[str],
                   class_names: List[str]) -> np.ndarray:
    """Render a fitted decision tree to an RGB image array without writing files to disk."""
    # export the tree as DOT source; out_file=None returns a string instead of writing a file
    dot_tree = export_graphviz(decision_tree=clf, out_file=None,
                               feature_names=feature_names,
                               class_names=class_names,
                               label="all", filled=True, impurity=False,
                               proportion=True, rounded=True, precision=2)
    # render the DOT source to PNG bytes in memory (needs the Graphviz "dot" executable on PATH)
    image_bytes = graphviz.Source(dot_tree).pipe(format="png")
    # decode the PNG bytes into a NumPy array; OpenCV returns channels in BGR order,
    # so convert to RGB so that colors display correctly (e.g. in matplotlib)
    image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image