@loretoparisi
Last active November 4, 2022 13:21
Calculate FastText Classifier Confusion Matrix
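#!/usr/local/bin/python3
# @author cpuhrsch https://github.com/cpuhrsch
# @author Loreto Parisi loreto@musixmatch.com
import argparse
import numpy as np
from sklearn.metrics import confusion_matrix

PREFIX = '__label__'  # FastText's default label prefix (9 characters)


def parse_labels(path):
    # Read whitespace-separated labels (one per example) and drop the
    # leading prefix. Assumes every token carries the __label__ prefix.
    with open(path, 'r') as f:
        return np.array([x[len(PREFIX):] for x in f.read().split()])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Display confusion matrix.')
    parser.add_argument('test', help='Path to test labels')
    parser.add_argument('predict', help='Path to predictions')
    args = parser.parse_args()

    test_labels = parse_labels(args.test)
    pred_labels = parse_labels(args.predict)
    # Labels are compared position by position, so the counts must match.
    if len(test_labels) != len(pred_labels):
        parser.error('%d test labels but %d predictions'
                     % (len(test_labels), len(pred_labels)))

    eq = test_labels == pred_labels
    print("Accuracy: " + str(eq.sum() / len(test_labels)))
    # Rows are true labels, columns are predicted labels, in sorted label order.
    print(confusion_matrix(test_labels, pred_labels))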
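The two numbers the script prints are related: accuracy equals the sum of the confusion matrix diagonal divided by the total number of examples. A minimal sketch on made-up toy labels (hypothetical, not taken from the gist) checks this, and also shows that scikit-learn orders rows and columns by sorted label name when no labels argument is passed:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical toy labels, already stripped of the '__label__' prefix.
test_labels = np.array(["spam", "ham", "ham", "spam", "ham"])
pred_labels = np.array(["spam", "ham", "spam", "spam", "ham"])

cm = confusion_matrix(test_labels, pred_labels)
# With no labels argument, rows (true) and columns (predicted) follow
# sorted label order, here ['ham', 'spam'].
order = np.unique(np.concatenate([test_labels, pred_labels]))

# Accuracy computed the way the script does it...
accuracy = (test_labels == pred_labels).mean()
# ...equals the matrix trace divided by the total number of examples.
accuracy_from_cm = np.trace(cm) / cm.sum()

print(order)
print(cm)
print(accuracy, accuracy_from_cm)
```

Knowing the ordering matters when reading the script's output, since it prints the matrix without row or column names.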

loretoparisi commented Oct 31, 2017

Usage example:
Suppose the test set has to be normalized first because its labels are plain strings without a prefix, while FastText expects labels to carry its default __label__ prefix:

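#!/bin/bash
# Usage: ./confusion.sh DATASET MODEL
DATASET=$1
MODEL=$2
ROOT=/root
echo "Normalizing dataset $DATASET..."
# Prefix the first tab-separated column with __label__ and lowercase it.
awk 'BEGIN{FS=OFS="\t"}{ $1 = "__label__" tolower($1) }1' "$DATASET" > "$ROOT/norm"
# Keep only the label column.
cut -f 1 -d$'\t' "$ROOT/norm" > "$ROOT/normlabels"

echo "Calculating predictions..."
# Predict on the normalized file: fastText treats __label__ tokens as labels,
# so the raw label is not read as an input word.
fasttext predict "$MODEL" "$ROOT/norm" > "$ROOT/pexp"
./confusion.py "$ROOT/normlabels" "$ROOT/pexp"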
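For readers who prefer to stay in Python, the two normalization steps (the awk call that lowercases the first column and adds the __label__ prefix, and the cut call that keeps only that column) can be sketched as below. normalize_line and label_of are hypothetical helpers, and the sketch assumes tab-separated rows with the label in the first field, as the shell script does:

```python
PREFIX = "__label__"  # FastText's default label prefix


def normalize_line(line: str) -> str:
    """Lowercase and prefix the first tab-separated field (the awk step)."""
    fields = line.rstrip("\n").split("\t")
    fields[0] = PREFIX + fields[0].lower()
    return "\t".join(fields)


def label_of(normalized_line: str) -> str:
    """Keep only the first tab-separated field (the cut -f 1 step)."""
    return normalized_line.split("\t", 1)[0]


# Hypothetical raw rows: label, TAB, text.
raw = ["Positive\tgreat song\n", "NEGATIVE\tnot my thing\n"]
norm = [normalize_line(r) for r in raw]
labels = [label_of(n) for n in norm]
print(norm)
print(labels)
```

Like the awk version, this rewrites only the first field and leaves the rest of each row untouched.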

Running it gives:

$ ./confusion.sh dataset_test.csv model.bin 
Accuracy: 0.998852852227
[[5003    21]
 [    6 14008]]
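The matrix can be read further. Assuming scikit-learn's convention (rows are true labels, columns are predictions, both in sorted label order), per-class precision and recall follow from the column and row sums. A short sketch using the numbers printed above; the class indices stand in for the actual label names, which the output does not show:

```python
import numpy as np

# The confusion matrix printed above: rows are true labels, columns are
# predicted labels, both in scikit-learn's sorted label order.
cm = np.array([[5003, 21],
               [6, 14008]])

tp = np.diag(cm)                 # correctly classified examples per class
recall = tp / cm.sum(axis=1)     # share of each true class that was found
precision = tp / cm.sum(axis=0)  # share of each predicted class that was right

for i, (p, r) in enumerate(zip(precision, recall)):
    print(f"class {i}: precision={p:.4f} recall={r:.4f}")
```

Per-class figures are useful here because the two classes are imbalanced, and a single accuracy number can hide weaker performance on the smaller class.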
