Skip to content

Instantly share code, notes, and snippets.

@spider-man-tm
Last active July 8, 2022 07:28
Show Gist options
  • Save spider-man-tm/ab82733b0a1ae3f2f122a02c4234b4c8 to your computer and use it in GitHub Desktop.
Save spider-man-tm/ab82733b0a1ae3f2f122a02c4234b4c8 to your computer and use it in GitHub Desktop.
Confusion Matrix を美しくプロットする関数
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
def plot_confusion_matrix(y_true, y_pred,
                          save_path,
                          normalize=True,
                          title=None,
                          cmap=plt.cm.YlOrRd):
    """Plot a confusion matrix as an annotated heat map and save it to disk.

    The matrix is computed with ``sklearn.metrics.confusion_matrix`` and the
    tick labels come from ``sklearn.utils.multiclass.unique_labels``, so only
    labels that occur in ``y_true`` or ``y_pred`` are shown.

    Args:
        y_true: Ground-truth labels, as accepted by ``confusion_matrix``.
        y_pred: Predicted labels, as accepted by ``confusion_matrix``.
        save_path: Destination handed to ``Figure.savefig``.
        normalize: If true, divide every row by its sum so that each cell is
            the fraction of that true class; otherwise show raw counts.
        title: Figure title. When falsy, a default is chosen that reflects
            ``normalize``.
        cmap: Matplotlib colormap used for the heat map.

    Returns:
        None. The figure is written to ``save_path`` and then closed.
    """
    if not title:
        title = ('Normalized confusion matrix' if normalize
                 else 'Confusion matrix, without normalization')

    # Compute the confusion matrix (rows: true labels, columns: predictions).
    cm = confusion_matrix(y_true, y_pred)
    # Only use the labels that appear in the data.
    classes = unique_labels(y_true, y_pred).tolist()

    if normalize:
        # A class that is predicted but never occurs in y_true has an all-zero
        # row. Dividing by its sum would yield NaN, which would also turn
        # cm.max() (and therefore the text-colour threshold below) into NaN.
        # Leave such rows at 0.0 instead.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype(float), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    fig, ax = plt.subplots(figsize=(7, 7))
    try:
        im = ax.imshow(cm, interpolation='nearest', cmap=cmap)

        ax.set_title(title, fontsize=15)
        ax.set_xlabel('Predicted label', fontsize=12)
        ax.set_ylabel('True label', fontsize=12)

        # Show one tick per class and label it with the class value.
        ax.set_xticks(np.arange(cm.shape[1]))
        ax.set_yticks(np.arange(cm.shape[0]))
        ax.set_xticklabels(classes, fontsize=12)
        ax.set_yticklabels(classes, fontsize=12)
        # Right-align the x tick labels on their anchor (no rotation applied).
        plt.setp(ax.get_xticklabels(), ha="right",
                 rotation_mode="anchor")

        # Put the colorbar in its own axes so it matches the image height.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size="5%", pad=0.15)
        cbar = fig.colorbar(im, cax=cax)
        cbar.ax.tick_params(labelsize=10)

        # Annotate every cell; use white text on the darker half of the
        # value range so the numbers stay readable.
        fmt = '.2f' if normalize else 'd'
        thresh = cm.max() / 2.
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, format(cm[i, j], fmt),
                        fontsize=10,
                        ha="center", va="center",
                        color="white" if cm[i, j] > thresh else "black")

        fig.tight_layout()
        # Save this specific figure rather than whatever pyplot considers
        # "current".
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment