@terrycojones
Created March 9, 2015 15:54
Confusion matrix
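from collections import defaultdict


def confusion(trueLabels, clusterLabels):
    # counts[trueLabel][clusterLabel] is the number of items with that
    # true label that were assigned to that cluster label.
    counts = defaultdict(lambda: defaultdict(int))
    # Union of both label sets, so the matrix is square.
    allLabels = sorted(set(trueLabels) | set(clusterLabels))
    for trueLabel, clusterLabel in zip(trueLabels, clusterLabels):
        counts[trueLabel][clusterLabel] += 1
    return allLabels, counts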
@terrycojones
Author

Usage (with ugly printing) looks something like this:

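labels, counts = confusion([0, 0, 1], [1, 2, 1])

# Header row: one column per label.
print('  ', end=' ')
for label in labels:
    print(label, end=' ')
print()

# One row per true label, one count per cluster label.
for trueLabel in labels:
    print('%s: ' % trueLabel, end=' ')
    for clusterLabel in labels:
        print(counts[trueLabel][clusterLabel], end=' ')
    print()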

Producing

   0 1 2
0:  0 1 1
1:  0 1 0
2:  0 0 0
