Skip to content

Instantly share code, notes, and snippets.

@tanyuan
Last active March 12, 2016 14:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tanyuan/1a7271991d3774995e19 to your computer and use it in GitHub Desktop.
Save tanyuan/1a7271991d3774995e19 to your computer and use it in GitHub Desktop.
Plot confusion matrix in grey scale.
"""
==============================
EdgeVib: Plot Confusion matrix
==============================
Author: tanyuan
"""
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.patches as patches
from sklearn.metrics import confusion_matrix
# Plot text settings used by plotMatrix: font sizes for the axis tick labels
# and for the per-cell value text, plus the rotation of the x-axis tick labels.
LABEL_SIZE, VALUE_SIZE = 14, 12
LABEL_ROTATION = 0
def plotMatrix(cm, labels, title, fname):
    """Plot a percentage confusion matrix in grey scale, save it, then show it.

    Cells are shaded with a discrete, reversed grey scale of ten 10-point
    bins (white for 0-10, black for 90-100).  Every cell whose rounded value
    is non-zero gets a light grey outline; cells whose rounded value is at
    least 10 are also labelled with that value, in white on dark cells
    (value > 50) and black otherwise.  Cells with rounded values 1-9 are
    outlined but deliberately left unlabelled.  Non-finite cells are skipped.

    Args:
        cm: 2-D array-like of values on a 0-100 scale; rows are plotted on the
            y axis ("True labels"), columns on the x axis ("Predicted labels").
        labels: Tick labels, used for both axes.
        title: Plot title.
        fname: Output image path passed to ``savefig``.
    """
    print("> Plot confusion matrix to", fname, "...")
    matrix = np.asarray(cm, dtype=float)

    # Create the figure and its single axes up front.  Calling
    # fig.add_subplot(1, 1, 1) after imshow creates a second, blank axes on
    # modern matplotlib, which hides the image.
    fig, ax = plt.subplots()

    # COLORS
    # ======
    # Ten fixed grey shades 0.1 .. 1.0, reversed so that 100 is black and 0 is white.
    shades = [(i * 0.1, i * 0.1, i * 0.1) for i in range(1, 11)]
    color_map = colors.ListedColormap(shades[::-1])
    # Bin edges [0, 10, ..., 100].  The norm must actually be handed to imshow,
    # otherwise the bins are ignored and the scale starts at the data minimum.
    # clip=True maps 100 (and anything above) to the darkest shade.
    norm = colors.BoundaryNorm(range(0, 110, 10), color_map.N, clip=True)
    image = ax.imshow(matrix, interpolation='nearest', cmap=color_map, norm=norm)
    ax.set_title(title)
    fig.colorbar(image, ax=ax)

    # LABELS
    # ======
    # The same labels are used for both axes.
    tick_marks = np.arange(len(labels))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(labels, fontsize=LABEL_SIZE, rotation=LABEL_ROTATION)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(labels, fontsize=LABEL_SIZE)
    ax.set_ylabel('True labels')
    ax.set_xlabel('Predicted labels')
    # Hide the little tick marks but keep their labels.
    ax.tick_params(axis='both', which='both', length=0)

    # VALUES
    # ======
    # Integer indices (float indices raise IndexError in NumPy); the grid
    # follows the matrix shape, not the number of labels.
    n_rows, n_cols = matrix.shape
    for row in range(n_rows):
        for col in range(n_cols):
            raw = matrix[row, col]
            if not np.isfinite(raw):
                continue
            value = int(round(raw))
            # Only outline/label cells that are not 0 after rounding.
            if value == 0:
                continue
            ax.add_patch(
                patches.Rectangle(
                    (col - 0.5, row - 0.5),  # lower-left corner of the cell
                    1,  # width
                    1,  # height
                    fill=False,
                    edgecolor=(0.8, 0.8, 0.8),
                )
            )
            # White text on dark cells, black on light ones; 1-9 stays unlabelled.
            if value > 50:
                text_color = (1, 1, 1)
            elif value >= 10:
                text_color = (0, 0, 0)
            else:
                continue
            ax.text(col, row, str(value), va='center', ha='center',
                    fontsize=VALUE_SIZE, color=text_color)

    # Save as an image.  Higher DPI avoids mis-alignment of boxes and text.
    fig.savefig(fname, dpi=240)
    # Show the plot in a window.
    plt.show()
def row_percentages(matrix):
    """Return a float copy of *matrix* with every row scaled to sum to 100.

    The arithmetic is done in floating point.  Assigning back into the integer
    array returned by ``confusion_matrix`` would silently truncate values
    (66.67 -> 66).  Rows whose sum is zero, e.g. a class that only appears
    among the predictions, stay all zeros instead of becoming NaN.

    Args:
        matrix: 2-D array-like of non-negative counts.

    Returns:
        A ``numpy`` float array of the same shape, on a 0-100 scale per row.
    """
    counts = np.asarray(matrix, dtype=float)
    row_sums = counts.sum(axis=1, keepdims=True)
    fractions = np.zeros_like(counts)
    np.divide(counts, row_sums, out=fractions, where=row_sums != 0)
    return fractions * 100


def main():
    """Read label/true/pred files, build a row-percentage confusion matrix, and plot it."""
    parser = argparse.ArgumentParser(description='Plot confusion matrix from EdgeVib experiment.')
    parser.add_argument('data_labels', help='Labels of data, used for axis ticks')
    parser.add_argument('data_true', help='True label of data')
    parser.add_argument('data_pred', help='Predicted label of data')
    parser.add_argument('title', help='Plot title')
    parser.add_argument('output', help='Output plot filename')
    args = parser.parse_args()

    # Read from files, one entry per line.
    with open(args.data_labels) as f:
        data_labels = f.read().splitlines()
    with open(args.data_true) as f:
        data_true = f.read().splitlines()
    with open(args.data_pred) as f:
        data_pred = f.read().splitlines()

    # Calculate the confusion matrix (raw counts).
    # NOTE(review): without labels=, confusion_matrix orders classes by the
    # sorted union of true/pred values, not by the order of data_labels;
    # the tick labels assume the two orders agree. Confirm with the data files.
    matrix = confusion_matrix(data_true, data_pred)
    print(matrix)
    # Transform each row to percentages.
    matrix = row_percentages(matrix)
    print(matrix)
    # Plot the matrix and save it as an image.
    plotMatrix(matrix, data_labels, args.title, args.output)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment