Skip to content

Instantly share code, notes, and snippets.

@jcboyd
Last active September 19, 2018 15:02
Show Gist options
Save jcboyd/2d4427b2b5ffa464da2d599d217d0dd9 to your computer and use it in GitHub Desktop.
confusion matrix
import numpy as np
import matplotlib.pyplot as plt
def plot_confusion_matrix(
        ax, matrix, labels, title='Confusion matrix', fontsize=9, show=True):
    """Draw a row-normalised confusion-matrix heat map on ``ax``.

    Each cell is shaded by its count divided by its row total, and every
    non-zero cell is annotated with its raw count. The first row of
    ``matrix`` is drawn at the top of the plot.

    Args:
        ax: Matplotlib ``Axes`` to draw on.
        matrix: Square 2-D array-like of counts (nested lists or a NumPy
            array). Rows are plotted against the "actual" axis and columns
            against the "prediction" axis.
        labels: Class names, one per row/column of ``matrix``; used for
            both axes.
        title: Title for the axes.
        fontsize: Font size for tick labels, counts, axis labels and title.
        show: When True (the default, matching the previous behaviour),
            call ``plt.show()`` after drawing. Pass False to keep drawing
            on the figure or to save it instead.
    """
    labels = list(labels)
    # np.asarray lets plain nested lists work; `1. * row` on a list row
    # raised TypeError in the previous version.
    counts = np.asarray(matrix)
    boundaries = list(range(len(labels)))
    centres = [b + 0.5 for b in boundaries]

    # Major ticks sit on the cell boundaries so the grid outlines each cell.
    ax.set_xticks(boundaries)
    ax.set_yticks(boundaries)
    # Place labels on minor ticks at the cell centres.
    ax.set_xticks(centres, minor=True)
    # rotation must be numeric: recent matplotlib rejects the string '90'.
    ax.set_xticklabels(labels, rotation=90, fontsize=fontsize, minor=True)
    ax.set_yticks(centres, minor=True)
    # Rows are flipped before plotting (see below), so reverse the y labels.
    ax.set_yticklabels(labels[::-1], fontsize=fontsize, minor=True)
    # Hide major tick labels. These must be booleans: a string such as
    # 'off' is truthy and would leave the labels visible.
    ax.tick_params(which='major', labelbottom=False, labelleft=False)
    # Finally, hide minor tick marks (their labels remain).
    ax.tick_params(which='minor', width=0)

    # Heat map of per-row proportions. A row with no samples would divide
    # by zero and produce NaN; such rows are shaded as 0 instead.
    row_totals = counts.sum(axis=1, keepdims=True)
    proportions = np.divide(counts, row_totals,
                            out=np.zeros(counts.shape, dtype=float),
                            where=row_totals != 0)
    # pcolor draws row 0 at the bottom; flip so the first class is on top.
    ax.pcolor(proportions[::-1], cmap=plt.cm.Reds)

    # Write each non-zero count at the centre of its (flipped) cell.
    flipped_counts = counts[::-1]
    for y, row_counts in enumerate(flipped_counts):
        for x, count in enumerate(row_counts):
            if count != 0:
                ax.text(x + 0.5, y + 0.5, int(count),
                        fontsize=fontsize,
                        horizontalalignment='center',
                        verticalalignment='center')

    # Add finishing touches.
    ax.grid(True, linestyle=':')
    ax.set_title(title, fontsize=fontsize)
    ax.set_xlabel('prediction', fontsize=fontsize)
    ax.set_ylabel('actual', fontsize=fontsize)
    if show:
        plt.show()
if __name__ == '__main__':
    # Demo: a 10x10 matrix of random counts in [0, 8] with ten example
    # class names, drawn on a single 6x6-inch axes.
    demo_counts = np.random.randint(0, 9, size=(10, 10))
    class_names = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]
    _, axes = plt.subplots(figsize=(6, 6))
    plot_confusion_matrix(axes, demo_counts, class_names, fontsize=10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment