Skip to content

Instantly share code, notes, and snippets.

@funktor
Created November 6, 2018 13:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save funktor/976dd950bbb14617d96bf38019b47bc6 to your computer and use it in GitHub Desktop.
Save funktor/976dd950bbb14617d96bf38019b47bc6 to your computer and use it in GitHub Desktop.
from collections import defaultdict
import numpy as np, random, re, math, pickle
import data_reader as dr
def get_sequence(labels):
    """Map the start index of each labelled span in ``labels`` to its end index.

    ``labels`` is one sentence's tag sequence, e.g. ``['B-PER', 'I-PER', 'O']``.
    Rules:

    - ``'O'`` marks a token outside any span and closes the open span.
    - A tag whose first character is ``'B'`` or ``'S'`` opens a new span.
    - Any other non-``'O'`` tag extends the currently open span. If no span
      is open (sentence start, or right after ``'O'``), it opens a new span,
      so spans are always contiguous.

    Example: ``['B-PER', 'I-PER', 'O', 'S-LOC']`` -> ``{0: 1, 3: 3}``.

    Returns:
        dict mapping span start index -> inclusive span end index.
    """
    seq_label = {}
    start = None  # start index of the span currently open, or None
    for idx, label in enumerate(labels):
        if label == 'O':
            # An outside token ends any open span; a later continuation tag
            # must not be merged into it across the gap.
            start = None
            continue
        if label[0] in ('B', 'S') or start is None:
            start = idx
        seq_label[start] = idx
    return seq_label
def get_classification_score(test_labels, pred_labels):
    """Compute entity-level precision, recall and F1 for tagged sequences.

    Each element of ``test_labels`` / ``pred_labels`` is one sentence's tag
    sequence (e.g. ``['B-PER', 'I-PER', 'O']``). Spans are extracted with
    ``get_sequence``. A span's tag is the text after the two-character prefix
    of the label at its start index (``'PER'`` from ``'B-PER'``). A predicted
    span is correct only when its start, end and tag all match a gold span.

    For every gold label, one line ``label precision recall f1 support`` is
    printed.

    Args:
        test_labels: gold tag sequences, one per sentence.
        pred_labels: predicted tag sequences, aligned with ``test_labels``.

    Returns:
        Tuple ``(precision, recall, f1, support)``. The first three are the
        per-label scores averaged with gold-support weights. ``support`` is
        the total number of gold spans. If there are no gold spans at all,
        every value is 0.0.

    Raises:
        ValueError: if the two lists hold a different number of sequences.
    """
    if len(test_labels) != len(pred_labels):
        raise ValueError(
            'test_labels and pred_labels differ in length: %d vs %d'
            % (len(test_labels), len(pred_labels)))

    tp, fp, fn = defaultdict(float), defaultdict(float), defaultdict(float)
    support = defaultdict(float)

    for true_label, pred_label in zip(test_labels, pred_labels):
        true_seq, pred_seq = get_sequence(true_label), get_sequence(pred_label)

        # Recall side: each gold span is either matched exactly (tp) or missed (fn).
        for start, end in true_seq.items():
            true_tag = true_label[start][2:]
            pred_tag = pred_label[start][2:]
            support[true_tag] += 1
            if pred_seq.get(start) == end and pred_tag == true_tag:
                tp[true_tag] += 1
            else:
                fn[true_tag] += 1

        # Precision side: a predicted span is a false positive unless it
        # matches a gold span exactly. This must look up true_seq; checking
        # pred_seq (the dict being iterated) would always succeed and only
        # catch tag mismatches, never wrong span boundaries.
        for start, end in pred_seq.items():
            true_tag = true_label[start][2:]
            pred_tag = pred_label[start][2:]
            if true_seq.get(start) != end or pred_tag != true_tag:
                fp[pred_tag] += 1

    tot_precision = tot_recall = tot_f1 = 0.0
    sum_sup = 0.0
    # Only labels that occur in the gold data are scored and averaged;
    # labels seen only in predictions affect no average.
    for label, sup in support.items():
        hits = tp[label]  # float, so the divisions below are true division
        precision = hits / (hits + fp[label]) if hits else 0.0
        recall = hits / (hits + fn[label]) if hits else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        print('%s %s %s %s %s' % (label, precision, recall, f1, sup))
        tot_precision += sup * precision
        tot_recall += sup * recall
        tot_f1 += sup * f1
        sum_sup += sup

    if sum_sup == 0:
        # No gold spans: the weighted averages are undefined.
        return 0.0, 0.0, 0.0, 0.0
    return tot_precision / sum_sup, tot_recall / sum_sup, tot_f1 / sum_sup, sum_sup
def get_accuracy(test_labels, pred_labels):
    """Return the fraction of positions where the prediction equals the gold item.

    Items are compared with ``==``. If the items are whole tag sequences, this
    is exact-sequence accuracy. Items of ``pred_labels`` beyond
    ``len(test_labels)`` are ignored. A shorter ``pred_labels`` raises
    IndexError, and an empty ``test_labels`` raises ZeroDivisionError.
    """
    size = len(test_labels)
    hits = np.sum([gold == pred_labels[pos] for pos, gold in enumerate(test_labels)])
    # Convert to a Python float before dividing so the result is a plain float
    # (and an empty input fails loudly instead of yielding nan).
    return float(hits) / size
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment