Last active April 8, 2020 14:21
Save otaviomguerra/a793db5ed97c1c63900fb9b8e19b84d7 to your computer and use it in GitHub Desktop.
Plots the confusion matrix as a seaborn heatmap
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from sklearn.metrics import confusion_matrix | |
def _cm_annotations(cm):
    """Return per-cell annotation strings for the square confusion matrix *cm*.

    Rows of *cm* are actual classes and columns are predicted classes.
    Diagonal cells read ``"<row %>\\n<correct>/<row total>"``. Non-zero
    off-diagonal cells read ``"<row %>\\n<count>"``. Zero off-diagonal cells
    are left blank.

    Parameters
    ----------
    cm : numpy.ndarray
        Square integer confusion matrix.

    Returns
    -------
    numpy.ndarray
        Object array of strings with the same shape as *cm*.
    """
    row_sums = cm.sum(axis=1)
    # A class that appears only in y_pred has an all-zero row; report 0%
    # for it instead of dividing by zero and printing "nan%".
    perc = np.divide(
        cm,
        row_sums[:, None],
        out=np.zeros(cm.shape, dtype=float),
        where=row_sums[:, None] != 0,
    ) * 100
    # dtype=object: a fixed-width unicode dtype (what empty_like(...).astype(str)
    # yields) silently truncates long strings such as "10.0%\n1000000/10000000".
    annot = np.empty(cm.shape, dtype=object)
    for (i, j), count in np.ndenumerate(cm):
        if i == j:
            annot[i, j] = '%.1f%%\n%d/%d' % (perc[i, j], count, row_sums[i])
        elif count == 0:
            annot[i, j] = ''
        else:
            annot[i, j] = '%.1f%%\n%d' % (perc[i, j], count)
    return annot


def plot_cm(y_true, y_pred, figsize=(10, 10),
            class_labels=('class1', 'class2'),
            suffix_label=None, ax=None):
    """Plot a confusion matrix as an annotated seaborn heatmap.

    The classes are the sorted union of the values in *y_true* and *y_pred*.
    Each cell shows the percentage of the row's actual samples. The diagonal
    also shows ``correct/total`` for the row, and non-empty off-diagonal cells
    also show their count.

    Parameters
    ----------
    y_true, y_pred : array-like
        Actual and predicted labels, passed to sklearn's ``confusion_matrix``.
    figsize : tuple of float, default (10, 10)
        Size in inches of the new figure created when *ax* is not given.
    class_labels : sequence or None, default ('class1', 'class2')
        Tick labels, one per class in sorted class order. ``None`` uses the
        class values themselves.
    suffix_label : str, optional
        When truthy, the title becomes ``"Confusion Matrix - <suffix_label>"``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When omitted, a new figure of size *figsize* is
        created.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on.

    Raises
    ------
    ValueError
        If *class_labels* does not have exactly one entry per class.
    """
    # Use the union of both label sets so predictions of classes absent from
    # y_true are not silently dropped from the matrix.
    labels = np.union1d(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    if class_labels is None:
        tick_labels = [str(label) for label in labels]
    else:
        tick_labels = list(class_labels)
        if len(tick_labels) != len(labels):
            raise ValueError(
                f"class_labels has {len(tick_labels)} entries but the "
                f"confusion matrix has {len(labels)} classes: {list(labels)}"
            )

    annot = _cm_annotations(cm)

    cm_df = pd.DataFrame(cm, index=labels, columns=labels)
    cm_df.index.name = 'Actual'
    cm_df.columns.name = 'Predicted'

    if ax is None:
        # Honour figsize; previously it was accepted but ignored.
        _, ax = plt.subplots(figsize=figsize)

    sns.heatmap(cm_df, cmap=sns.light_palette("navy", 12),
                annot=annot, fmt='', cbar=False,
                xticklabels=tick_labels,
                yticklabels=tick_labels,
                annot_kws={"fontsize": 14},
                ax=ax)

    title = "Confusion Matrix"
    if suffix_label:
        title = f"{title} - {suffix_label}"
    ax.set_title(title, size=20, weight='bold')
    ax.set_ylabel(cm_df.index.name, size=16, weight='bold')
    ax.set_xlabel(cm_df.columns.name, size=16, weight='bold')
    ax.tick_params(axis='both', labelsize=14)
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.