Last active: July 8, 2022 07:28
-
-
Save spider-man-tm/ab82733b0a1ae3f2f122a02c4234b4c8 to your computer and use it in GitHub Desktop.
Confusion Matrix を美しくプロットする関数
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
from sklearn.metrics import confusion_matrix | |
from sklearn.utils.multiclass import unique_labels | |
def plot_confusion_matrix(y_true, y_pred,
                          save_path,
                          normalize=True,
                          title=None,
                          cmap=plt.cm.YlOrRd):
    """Plot a confusion matrix as an annotated heatmap and save it to disk.

    Rows correspond to true labels and columns to predicted labels. Only the
    labels that occur in ``y_true`` or ``y_pred`` are shown, in the order
    returned by ``sklearn.utils.multiclass.unique_labels``.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels.
    y_pred : array-like
        Predicted labels, aligned with ``y_true``.
    save_path : str or path-like
        Destination passed to ``Figure.savefig``.
    normalize : bool, default True
        If True, divide each row by its total so every cell shows the fraction
        of that true class. Rows with no true samples (labels that only occur
        in ``y_pred``) are shown as 0 instead of NaN.
    title : str, optional
        Plot title. When empty or None, a title describing whether the matrix
        is normalized is used.
    cmap : matplotlib colormap, default ``plt.cm.YlOrRd``
        Colormap for the heatmap.
    """
    if not title:
        title = ('Normalized confusion matrix' if normalize
                 else 'Confusion matrix, without normalization')

    # Compute the label set once and hand it to confusion_matrix, so the tick
    # labels are guaranteed to line up with the matrix rows/columns.
    classes = unique_labels(y_true, y_pred).tolist()
    cm = confusion_matrix(y_true, y_pred, labels=classes)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A label present only in y_pred has an all-zero row; plain division
        # would produce NaN (and poison cm.max() below), so keep such rows 0.
        cm = np.divide(cm.astype(float), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    fig, ax = plt.subplots(figsize=(7, 7))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)

    # Put the colorbar in its own axis (5% of the width) right of the heatmap,
    # so it matches the heatmap's height.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size="5%", pad=0.15)
    cbar = fig.colorbar(im, cax=cax)
    cbar.ax.tick_params(labelsize=10)

    # Show one tick per class on both axes, labelled with the class names.
    ticks = np.arange(len(classes))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(classes, fontsize=12)
    ax.set_yticklabels(classes, fontsize=12)
    ax.set_xlabel('Predicted label', fontsize=12)
    ax.set_ylabel('True label', fontsize=12)
    ax.set_title(title, fontsize=15)

    # Rotate the x tick labels so long class names do not overlap. Anchoring
    # the rotation at the label's right edge keeps it pointing at its tick.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Annotate every cell with its value. Use white text on cells in the upper
    # half of the value range (dark colours) and black text elsewhere.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in np.ndindex(cm.shape):
        ax.text(j, i, format(cm[i, j], fmt),
                fontsize=10,
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    fig.savefig(save_path)
    # Close this specific figure (not just the "current" one) to free memory
    # when the function is called in a loop.
    plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.