Skip to content

Instantly share code, notes, and snippets.

@DavidFricker
Created September 25, 2018 11:05
Show Gist options
  • Save DavidFricker/4b7c01013306657e48a6467f875b971c to your computer and use it in GitHub Desktop.
Save DavidFricker/4b7c01013306657e48a6467f875b971c to your computer and use it in GitHub Desktop.
def classification_report(y_true, y_pred, labels):
    """Print per-class recall, precision, F1 score and support.

    Similar in spirit to ``sklearn.metrics.classification_report``, followed
    by a support-weighted "avg / total" summary line.

    Args:
        y_true: Ground-truth targets; anything ``numpy.asarray`` accepts.
            The array is flattened with ``ravel()``.
        y_pred: Predicted targets, compared element-wise with ``y_true``
            after flattening.
        labels: Display names for the classes. Row ``i`` reports the class
            whose target value equals the integer ``i`` (lookups below are by
            position, not by the name), so the targets are presumably integer
            class indices ``0..len(labels)-1`` -- confirm against callers.

    Returns:
        None. The report is written to stdout.
    """
    # Local imports keep the function self-contained.
    import numpy
    from collections import Counter

    y_true = numpy.asarray(y_true).ravel()
    y_pred = numpy.asarray(y_pred).ravel()

    # Per-class count of correct predictions, true occurrences and predictions.
    corrects = Counter(yt for yt, yp in zip(y_true, y_pred) if yt == yp)
    y_true_counts = Counter(y_true)
    y_pred_counts = Counter(y_pred)

    report = []
    for i, lab in enumerate(labels):
        # max(1, ...) avoids ZeroDivisionError for classes that never occur.
        recall = corrects[i] / max(1, y_true_counts[i])
        precision = corrects[i] / max(1, y_pred_counts[i])
        # Harmonic mean; 1e-9 floor avoids 0/0 when both scores are zero.
        f1 = 2 * recall * precision / max(1e-9, recall + precision)
        report.append((lab, recall, precision, f1, y_true_counts[i]))

    print('{:<15}{:>10}{:>10}{:>10}{:>10}\n'.format('', 'recall', 'precision', 'f1-score', 'support'))
    formatter = '{:<15}{:>10.2f}{:>10.2f}{:>10.2f}{:>10d}'.format
    for row in report:
        print(formatter(*row))
    print('')

    # Support-weighted averages, normalised by the total number of samples.
    # Summing directly (instead of indexing a zip object, which is an
    # iterator in Python 3) also works when ``labels`` is empty; max(1, n)
    # guards against an empty ``y_true``.
    n = len(y_true)
    denom = max(1, n)
    avg_recall = sum(r * s for _, r, _, _, s in report) / denom
    avg_precision = sum(p * s for _, _, p, _, s in report) / denom
    avg_f1 = sum(f1 * s for _, _, _, f1, s in report) / denom
    print(formatter('avg / total', avg_recall, avg_precision, avg_f1, n) + '\n')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment