Last active
March 12, 2016 14:07
-
-
Save tanyuan/1a7271991d3774995e19 to your computer and use it in GitHub Desktop.
Plot confusion matrix in grey scale.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
============================== | |
EdgeVib: Plot Confusion matrix | |
============================== | |
Author: tanyuan | |
""" | |
import argparse | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.cm as cm | |
import matplotlib.colors as colors | |
import matplotlib.patches as patches | |
from sklearn.metrics import confusion_matrix | |
# Typography settings used by plotMatrix:
#   LABEL_SIZE     - font size of the axis tick labels (both axes)
#   VALUE_SIZE     - font size of the numbers drawn inside the matrix cells
#   LABEL_ROTATION - rotation, in degrees, of the x-axis tick labels
LABEL_SIZE, VALUE_SIZE, LABEL_ROTATION = 14, 12, 0
# Plot the matrix: show and save as image
def plotMatrix(cm, labels, title, fname):
    """Draw a percentage confusion matrix in grey scale, save it, then show it.

    Cells are shaded on a fixed 0-100 scale split into ten bins (white for
    0-10 up to near-black for 90-100). Every cell whose rounded value is
    non-zero gets a light-grey outline; cells whose rounded value is at least
    10 are also labelled with that value, in white on dark cells (> 50) and
    in black otherwise.

    Args:
        cm: Square 2-D array-like of percentages. ``cm[i][j]`` is drawn at
            row ``i`` (true label) and column ``j`` (predicted label).
        labels: One tick label per row/column, used for both axes.
        title: Plot title.
        fname: Path the figure is written to (at 240 DPI) before it is shown.

    Raises:
        ValueError: If ``cm`` is not an ``len(labels)`` x ``len(labels)``
            matrix.
    """
    print("> Plot confusion matrix to", fname, "...")
    cm = np.asarray(cm)
    n_labels = len(labels)
    if cm.shape != (n_labels, n_labels):
        raise ValueError(
            "confusion matrix of shape {} does not match {} labels".format(
                cm.shape, n_labels))

    # Create the single axes up front and draw everything on it. Calling
    # fig.add_subplot(1, 1, 1) after plt.imshow() (as this code used to do)
    # creates a second, blank axes over the image in matplotlib >= 3.6.
    fig, ax = plt.subplots()

    # COLORS
    # ======
    # Ten fixed grey shades 0.1 ... 1.0, reversed so that 0 % is white and
    # 100 % is near-black.
    shades = [(i * 0.1, i * 0.1, i * 0.1) for i in range(1, 11)]
    color_map = colors.ListedColormap(shades[::-1])
    # Bin edges 0, 10, ..., 100. Passing the norm to imshow pins the scale to
    # 0-100; with only vmax=100, vmin fell back to the smallest cell and the
    # shades shifted whenever no cell was 0.
    bounds = list(range(0, 110, 10))
    norm = colors.BoundaryNorm(bounds, color_map.N, clip=True)
    image = ax.imshow(cm, interpolation='nearest', cmap=color_map, norm=norm)
    ax.set_title(title)
    fig.colorbar(image, ax=ax, ticks=bounds)

    # LABELS
    # ======
    # Same labels on both axes.
    tick_marks = np.arange(n_labels)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(labels, fontsize=LABEL_SIZE, rotation=LABEL_ROTATION)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(labels, fontsize=LABEL_SIZE)
    ax.set_ylabel('True labels')
    ax.set_xlabel('Predicted labels')
    # Hide the little tick marks but keep their labels (assigning
    # Tick.tick1On / tick2On has no effect on current matplotlib).
    ax.tick_params(axis='both', which='both', length=0)

    # VALUES
    # ======
    # Integer indices: indexing a NumPy array with the floats produced by
    # np.arange(0., n, 1.) raises IndexError.
    for row in range(n_labels):
        for col in range(n_labels):
            raw = cm[row][col]
            if not np.isfinite(raw):
                continue  # e.g. NaN from a row without samples
            value = int(round(raw))
            if value == 0:
                continue  # leave empty cells unmarked
            ax.add_patch(
                patches.Rectangle(
                    (col - 0.5, row - 0.5),  # lower-left corner of the cell
                    1,  # width
                    1,  # height
                    fill=None,
                    edgecolor=(0.8, 0.8, 0.8),
                )
            )
            # NOTE(review): values 1-9 get an outline but no number (kept from
            # the original); confirm this is intended.
            if value < 10:
                continue
            # Lighter text for the dark background of high values.
            text_color = (1, 1, 1) if value > 50 else (0, 0, 0)
            ax.text(col, row, value, va='center', ha='center',
                    fontsize=VALUE_SIZE, color=text_color)

    # Save as an image
    # Higher DPI to avoid mis-alignment
    fig.savefig(fname, dpi=240)
    # Show the plot in a window
    plt.show()
def toPercentage(cm):
    """Return *cm* as floats with every row rescaled to sum to 100.

    The input is not modified. Rows that sum to zero (a class that never
    appears as a true label) come back as all zeros instead of NaN.

    >>> toPercentage([[1, 3], [0, 0]]).tolist()
    [[25.0, 75.0], [0.0, 0.0]]
    """
    counts = np.asarray(cm, dtype=float)
    row_sums = counts.sum(axis=1, keepdims=True)
    ratios = np.zeros_like(counts)
    # Divide in float and skip empty rows. Assigning the quotient back into
    # sklearn's integer matrix truncated every percentage (66.7 -> 66).
    np.divide(counts, row_sums, out=ratios, where=row_sums != 0)
    return ratios * 100


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Plot confusion matrix from EdgeVib experiment.')
    parser.add_argument('data_labels', help='Labels of data, used for axis ticks')
    parser.add_argument('data_true', help='True label of data')
    parser.add_argument('data_pred', help='Predicted label of data')
    parser.add_argument('title', help='Plot title')
    parser.add_argument('output', help='Output plot filename')
    args = parser.parse_args()

    # Read from files: one entry per line
    with open(args.data_labels) as f:
        data_labels = f.read().splitlines()
    with open(args.data_true) as f:
        data_true = f.read().splitlines()
    with open(args.data_pred) as f:
        data_pred = f.read().splitlines()

    # Calculate confusion matrix (raw counts). Named conf_matrix so that it
    # does not shadow the module-level `cm` (matplotlib.cm) import.
    conf_matrix = confusion_matrix(data_true, data_pred)
    print(conf_matrix)
    # NOTE(review): sklearn orders rows/columns by the sorted set of labels
    # found in data_true/data_pred; the tick labels in data_labels are
    # assumed to be listed in that same order -- confirm.
    if len(data_labels) != len(conf_matrix):
        raise SystemExit(
            "{} lists {} labels, but the data contain {} classes".format(
                args.data_labels, len(data_labels), len(conf_matrix)))

    # Transform to percentage of each true label's samples
    conf_matrix = toPercentage(conf_matrix)
    print(conf_matrix)

    # Plot the matrix and save as image
    plotMatrix(conf_matrix, data_labels, args.title, args.output)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment