Skip to content

Instantly share code, notes, and snippets.

@otaviomguerra
Last active April 8, 2020 14:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save otaviomguerra/a793db5ed97c1c63900fb9b8e19b84d7 to your computer and use it in GitHub Desktop.
Plots the confusion matrix as a seaborn heatmap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
def plot_cm(y_true, y_pred, figsize=(10, 10),
            class_labels=None,
            suffix_label=None, ax=None):
    """Plot a confusion matrix as an annotated seaborn heatmap.

    Each cell shows its count as a percentage of the row (actual class) and
    the raw count. Diagonal cells show ``correct/total`` for that class, and
    off-diagonal cells with a zero count are left blank.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels.
    y_pred : array-like
        Predicted labels.
    figsize : tuple of float, default (10, 10)
        Size in inches of the new figure created when ``ax`` is None.
    class_labels : sequence of str, optional
        Tick labels for the classes, in sorted label order. They are forwarded
        to seaborn's ``xticklabels``/``yticklabels``. The default is
        ``['class1', 'class2', ...]`` with one entry per class, which matches
        the previous ``['class1', 'class2']`` default for binary problems.
    suffix_label : str, optional
        If given, the title becomes ``"Confusion Matrix - <suffix_label>"``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. If None, a new figure of size ``figsize`` is
        created.

    Returns
    -------
    matplotlib.axes.Axes
        The axes containing the heatmap.
    """
    # Use every label seen in either array. Restricting to y_true makes
    # confusion_matrix drop predictions of unseen classes, which understates
    # the per-row totals.
    labels = np.union1d(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    row_sums = cm.sum(axis=1)

    # Row-normalised percentages. A row with no true samples (a class that only
    # appears in y_pred) yields 0% instead of NaN from a division by zero.
    cm_perc = np.divide(cm * 100.0, row_sums[:, None],
                        out=np.zeros(cm.shape, dtype=float),
                        where=row_sums[:, None] != 0)

    # Object dtype: astype(str) on an int array gives a fixed-width string
    # dtype that silently truncates long annotations (large counts).
    annot = np.empty(cm.shape, dtype=object)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                annot[i, j] = '%.1f%%\n%d/%d' % (p, c, row_sums[i])
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = '%.1f%%\n%d' % (p, c)

    if class_labels is None:
        class_labels = ['class%d' % k for k in range(1, len(labels) + 1)]

    if ax is None:
        # Honour figsize; previously the parameter was accepted but ignored.
        _, ax = plt.subplots(figsize=figsize)

    cm_df = pd.DataFrame(cm, index=labels, columns=labels)
    cm_df.index.name = 'Actual'
    cm_df.columns.name = 'Predicted'
    sns.heatmap(cm_df, cmap=sns.light_palette("navy", 12),
                annot=annot, fmt='', cbar=False,
                xticklabels=class_labels,
                yticklabels=class_labels,
                annot_kws={"fontsize": 14},
                ax=ax)

    title = "Confusion Matrix"
    if suffix_label:
        title = f"{title} - {suffix_label}"
    ax.set_title(title, size=20, weight='bold')
    ax.set_ylabel(cm_df.index.name, size=16, weight='bold')
    ax.set_xlabel(cm_df.columns.name, size=16, weight='bold')
    ax.tick_params(axis='both', labelsize=14)
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment