@algal
Created September 9, 2020 20:08
Plot a confusion matrix in scikit-learn from data, not from an estimator
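# scikit-learn's plot_confusion_matrix helper requires you to pass in a fitted
# estimator, but sometimes you just want to plot a raw CM from the labels you'd
# use to calculate it. This computes the matrix with confusion_matrix() and
# hands it straight to ConfusionMatrixDisplay.
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix


def plot_cm(y_true, y_pred, labels):
    # These mirror the defaults of plot_confusion_matrix
    sample_weight = None
    normalize = None               # or 'true', 'pred', 'all'
    include_values = True          # write the value in each cell
    cmap = 'viridis'
    ax = None                      # None: create a new figure and axes
    xticks_rotation = 'horizontal'
    values_format = None           # None: let scikit-learn pick the format

    # Rows are true labels, columns are predicted labels, both in `labels` order
    cm = confusion_matrix(y_true, y_pred, sample_weight=sample_weight,
                          labels=labels, normalize=normalize)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    # plot() draws the matrix and returns the display object
    return disp.plot(include_values=include_values, cmap=cmap, ax=ax,
                     xticks_rotation=xticks_rotation,
                     values_format=values_format)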