Skip to content

Instantly share code, notes, and snippets.

@mayhewsw
Created April 22, 2016 02:18
Show Gist options
  • Save mayhewsw/2f5f8697cfd54c201d90a080da5e921f to your computer and use it in GitHub Desktop.
Save mayhewsw/2f5f8697cfd54c201d90a080da5e921f to your computer and use it in GitHub Desktop.
Plot a confusion matrix.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patches as patches
# Global matplotlib styling for the figure produced by this script.
# All text is typeset in Times New Roman.
mpl.rc('font', family="Times New Roman")
# Optional: pin the base font size as well.
#mpl.rcParams['font.size'] = "11"
# Per the matplotlib rcParams docs, pdf.fonttype 42 embeds TrueType fonts in
# the PDF output (the other documented value, 3, selects Type 3 fonts).
mpl.rcParams['pdf.fonttype'] = 42
# Raw score table: one row per language; the first entry is the row's
# language code, the remaining entries are that row's scores (as fraction
# strings), one per column, in the same language order as the rows.
_SCORE_TABLE = [
    ['en', '0.842', '0.508', '0.425', '0.471', '0.316', '0.252', '0.166', '0.291'],
    ['es', '0.628', '0.662', '0.349', '0.445', '0.29', '0.125', '0.019', '0.36'],
    ['de', '0.683', '0.491', '0.505', '0.455', '0.312', '0.353', '0.265', '0.289'],
    ['nl', '0.666', '0.528', '0.407', '0.563', '0.308', '0.151', '0.047', '0.299'],
    ['tr', '0.657', '0.518', '0.341', '0.411', '0.472', '0.231', '0.118', '0.358'],
    ['uz', '0.565', '0.446', '0.367', '0.391', '0.332', '0.481', '0.305', '0.265'],
    ['bn', '0.604', '0.442', '0.236', '0.316', '0.314', '0.418', '0.363', '0.309'],
    ['ha', '0.583', '0.358', '0.255', '0.29', '0.266', '0.123', '0.059', '0.403'],
]

# Axis labels: the language code that leads each row.
# Built with list comprehensions rather than map() so that the results are
# real, re-usable lists on Python 3 as well (there map() returns a one-shot
# iterator, which breaks len() and indexing further down).
langs = [row[0] for row in _SCORE_TABLE]

# Numeric matrix: every score converted from a fraction string to a
# percentage (e.g. '0.842' -> 84.2).
cm = [[100 * float(score) for score in row[1:]] for row in _SCORE_TABLE]
def _cell_label(value):
    """Return the text written inside one matrix cell.

    Floats are rendered with 12 significant digits, which matches what
    ``str()`` produced for floats on Python 2. Plain ``str()`` on Python 3
    would expose floating-point noise (e.g. ``28.999999999999996`` instead
    of ``29.0``). Non-float values are passed through ``str()`` unchanged.
    """
    if isinstance(value, float):
        text = '%.12g' % value
        # '%g' drops a trailing ".0" from whole numbers; restore it so that
        # whole-valued floats still read as floats ("29.0", not "29").
        if text.lstrip('-').isdigit():
            text += '.0'
        return text
    return str(value)


def plot_confusion_matrix(cm, title='Confusion matrix'):
    """Draw ``cm`` as an annotated heat map on the current pyplot figure.

    The y axis is labelled "Train" and the x axis "Test" (ticks and label
    moved to the top of the plot). Both axes take their tick labels from
    the module-level ``langs`` sequence, so ``cm`` is expected to be square
    with one row and one column per entry of ``langs``. Every cell is
    annotated with its value, and in every column the cell holding the
    second-largest value is outlined in red.

    Args:
        cm: 2-D table of numbers given as a sequence of rows.
        title: Accepted for interface compatibility; it is not drawn.
    """
    # Materialise the inputs once: under Python 3 they may be one-shot
    # iterators (e.g. produced by map()), which cannot be indexed or sized.
    rows = [list(row) for row in cm]
    labels = list(langs)
    matrix = np.array(rows)

    plt.imshow(matrix, interpolation='nearest', cmap=plt.get_cmap("YlGn"))

    tick_marks = np.arange(len(labels))
    plt.xticks(tick_marks, labels, rotation='vertical')
    plt.yticks(tick_marks, labels)
    plt.tick_params(length=0)
    plt.ylabel('Train', fontsize=18)

    ax = plt.gca()
    ax.xaxis.tick_top()
    ax.set_xlabel('Test', fontsize=18)
    ax.xaxis.set_label_position('top')

    # (column, row) of the runner-up in every column. Only the runner-up is
    # highlighted, not the maximum -- presumably because the maximum usually
    # sits on the diagonal (same train and test language); confirm intent.
    runner_up_cells = set()
    if len(rows) >= 2:  # a runner-up needs at least two rows
        for col_idx, column in enumerate(matrix.T):
            ranking = np.argsort(column)
            runner_up_cells.add((col_idx, int(ranking[-2])))

    for row_idx, row in enumerate(rows):
        for col_idx, value in enumerate(row):
            # Plot coordinates are (x=column, y=row).
            plt.annotate(_cell_label(value), xy=(col_idx, row_idx),
                         horizontalalignment='center',
                         verticalalignment='center')
            if (col_idx, row_idx) in runner_up_cells:
                # Each cell spans one unit centred on its integer
                # coordinates, hence the half-unit offset of the corner.
                ax.add_patch(
                    patches.Rectangle(
                        (col_idx - 0.5, row_idx - 0.5),
                        1,
                        1,
                        fill=False,
                        edgecolor="red",
                        linewidth=4
                    )
                )
# Script driver: draw the matrix on a fresh figure, export it, display it.

# Limits the precision NumPy uses when printing arrays. Nothing below prints
# an array, so this is presumably a leftover; kept as in the original.
np.set_printoptions(precision=2)

# Start a new figure so the plot does not land on a previously open one.
plt.figure()
plot_confusion_matrix(cm, 'Confusion matrix')

# Write the PDF before show(), then open the interactive window.
plt.savefig("confmat.pdf", format='pdf')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment