Skip to content

Instantly share code, notes, and snippets.

@varvara-l
Last active April 27, 2016 16:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save varvara-l/4d95a7ba110904d8ced26ca8f14b0b28 to your computer and use it in GitHub Desktop.
Save varvara-l/4d95a7ba110904d8ced26ca8f14b0b28 to your computer and use it in GitHub Desktop.
from __future__ import print_function

import io
from argparse import ArgumentParser

from sklearn.metrics import accuracy_score
###########################################################################################################
#
# Script to compute Sequence Correlation score for word-level binary QE system output in WMT-15 format:
# <METHOD NAME> <SEGMENT NUMBER> <WORD INDEX> <WORD> <BINARY SCORE>
#
###########################################################################################################
# check that two lists of sequences have the same number of elements
def check_word_tag(words_seq, tags_seq, dataset_name=''):
    """Validate that words and tags are parallel: the two lists contain the
    same number of sequences, and each word sequence has one tag per word.

    Raises AssertionError (with the dataset name in the message) otherwise.
    """
    n_words, n_tags = len(words_seq), len(tags_seq)
    assert n_words == n_tags, \
        "Number of word and tag sequences doesn't match in dataset {}".format(dataset_name)
    for idx in range(n_words):
        words, tags = words_seq[idx], tags_seq[idx]
        assert len(words) == len(tags), \
            "Numbers of words and tags don't match in sequence {} of dataset {}".format(idx, dataset_name)
def parse_submission(ref_txt_file, ref_tags_file, submission):
    """Read the reference text, reference tags and a submission in WMT-15
    word-level format and return (true_tags, test_tags): two parallel lists
    of per-sentence binary label lists (1 = 'OK', 0 = 'BAD').

    ref_txt_file  -- test target text, one whitespace-tokenised sentence per line
    ref_tags_file -- gold tags, one space-separated tag sequence per line
    submission    -- tab-separated lines: <METHOD> <SEG NUM> <WORD IDX> <WORD> <TAG>

    Raises AssertionError if sequence counts/lengths disagree, KeyError on an
    unknown tag label, IndexError if a segment number is out of range.
    """
    tag_map = {'OK': 1, 'BAD': 0}
    # Reference tokens, one list per sentence.  io.open decodes to unicode on
    # both Python 2 and 3 (the original str.decode crashed on Python 3), and
    # rstrip (rather than line[:-1]) does not eat the last character of a
    # file with no trailing newline.
    true_words = []
    with io.open(ref_txt_file, encoding='utf-8') as txt:
        for line in txt:
            true_words.append(line.rstrip('\r\n').split())
    # Reference tags, mapped to binary labels.
    true_tags = []
    with io.open(ref_tags_file, encoding='utf-8') as tags:
        for line in tags:
            true_tags.append([tag_map[t] for t in line.rstrip('\r\n').split()])
    check_word_tag(true_words, true_tags, dataset_name='reference')
    # Parse the submission: group predicted tags by segment number (column 1);
    # within a segment, lines are assumed to appear in word order.
    test_tags = [[] for _ in range(len(true_tags))]
    with io.open(submission, encoding='utf-8') as sub:
        for line in sub:
            chunks = line.rstrip('\r\n').split('\t')
            test_tags[int(chunks[1])].append(tag_map[chunks[4]])
    check_word_tag(true_words, test_tags, dataset_name='prediction')
    return true_tags, test_tags
# counts span boundaries for the sequence-correlation span penalty
def _count_boundaries(seq):
    """Number of boundaries between runs of identical tags in seq, i.e. the
    number of contiguous same-tag spans minus one (0 for a uniform sequence,
    -1 for an empty one — mirrors the original span arithmetic)."""
    prev = None
    boundaries = -1  # the first span is not a boundary
    for tag in seq:
        if tag != prev:
            boundaries += 1
            prev = tag
    return boundaries


# accuracy weighted by the ratio of the number of spans in the reference and the hypothesis
def sequence_correlation_simple(true_tags, test_tags, bad_weight=0.5):
    """Sequence Correlation score for word-level binary QE output.

    For each sentence: class-balanced weighted accuracy of the predicted tags,
    multiplied by a penalty for any difference in the number of contiguous
    tag spans between reference and hypothesis.

    true_tags / test_tags -- parallel lists of per-sentence binary tag lists
    bad_weight            -- share of the total weight given to BAD (0) tokens;
                             OK (1) tokens share the remaining 1 - bad_weight

    Returns (per-sentence scores list, their mean).
    """
    seq_corr_all = []
    for true_seq, test_seq in zip(true_tags, test_tags):
        # Per-token weights: each class's total mass (bad_weight for BAD,
        # 1 - bad_weight for OK) is spread evenly over its tokens.
        n_tags_1 = sum(1 for t in true_seq if t == 1)
        n_tags_0 = sum(1 for t in true_seq if t == 0)
        # float() guards against integer division under Python 2
        lambda_0 = bad_weight * len(true_seq) / float(n_tags_0) if n_tags_0 != 0 else 0
        lambda_1 = (1 - bad_weight) * len(true_seq) / float(n_tags_1) if n_tags_1 != 0 else 0
        weights = []
        for t in true_seq:
            if t == 1:
                weights.append(lambda_1)
            elif t == 0:
                weights.append(lambda_0)
            else:
                print("Unknown reference tag: {}".format(t))
        # bug fix: the original message swapped/misreported the two lengths
        assert len(weights) == len(true_seq), \
            "Expected weights array len {}, got {}".format(len(true_seq), len(weights))
        # Weighted accuracy — exact pure-Python equivalent of sklearn's
        # accuracy_score(true_seq, test_seq, sample_weight=weights):
        # weight mass on correctly tagged tokens over total weight mass.
        correct_weight = sum(w for w, t, p in zip(weights, true_seq, test_seq) if t == p)
        acc = correct_weight / sum(weights)
        # Penalise any difference in span structure between reference and
        # hypothesis (ratio of boundary counts, whichever way is <= 1).
        n_spans_true = _count_boundaries(true_seq)
        n_spans_pred = _count_boundaries(test_seq)
        if n_spans_true == 0 and n_spans_pred == 0:
            ratio = 1
        elif n_spans_true == 0 or n_spans_pred == 0:
            ratio = 0
        else:
            # float() guards against integer division under Python 2
            ratio = min(float(n_spans_pred) / n_spans_true,
                        float(n_spans_true) / n_spans_pred)
        seq_corr_all.append(acc * ratio)
    # bug fix: the original called np.average but numpy was never imported
    # (guaranteed NameError); a plain mean needs no dependency
    return seq_corr_all, sum(seq_corr_all) / float(len(seq_corr_all))
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("ref_txt", action="store", help="test target text (one line per sentence)")
    parser.add_argument("ref_tags", action="store", help="test tags (one line per sentence)")
    parser.add_argument("submission", action="store", help="submission (wmt15 format)")
    parser.add_argument("--sequence", dest='seq_corr', default=None, help="file to store sentence-level sequence correlation scores (not saved if no file provided)")
    # bug fix: 'bad_weight' was referenced below but never defined — expose it
    # as an option with the scorer's own default of 0.5
    parser.add_argument("--bad_weight", dest='bad_weight', type=float, default=0.5,
                        help="total weight share assigned to BAD tokens (default: 0.5)")
    # bug fix: the original never called parse_args(), so 'args' was undefined
    args = parser.parse_args()
    true_tags, test_tags = parse_submission(args.ref_txt, args.ref_tags, args.submission)
    seq_corr = sequence_correlation_simple(true_tags, test_tags, bad_weight=args.bad_weight)
    if args.seq_corr is not None:
        # 'with' guarantees the scores file is flushed and closed
        with open(args.seq_corr, 'w') as seq_corr_out:
            for sc in seq_corr[0]:
                seq_corr_out.write("%f\n" % sc)
    print("Sequence correlation: {}".format(seq_corr[1]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment