Last active
April 27, 2016 16:20
-
-
Save varvara-l/4d95a7ba110904d8ced26ca8f14b0b28 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function

import io
from argparse import ArgumentParser

import numpy as np
from sklearn.metrics import accuracy_score
########################################################################################################### | |
# | |
# Script to compute Sequence Correlation score for word-level binary QE system output in WMT-15 format: | |
# <METHOD NAME> <SEGMENT NUMBER> <WORD INDEX> <WORD> <BINARY SCORE> | |
# | |
########################################################################################################### | |
# check that two lists of sequences have the same number of elements | |
def check_word_tag(words_seq, tags_seq, dataset_name=''):
    """Verify that word sequences and tag sequences line up one-to-one.

    Args:
        words_seq: list of sentences, each a list of words.
        tags_seq: list of sentences, each a list of tags.
        dataset_name: label inserted into the error message.

    Raises:
        AssertionError: if the two lists hold a different number of
            sequences, or if any pair of sequences differs in length.
    """
    # Raised explicitly instead of via `assert` so that the validation is
    # not stripped when Python runs with -O; the exception type is kept so
    # existing callers that catch AssertionError keep working.
    if len(words_seq) != len(tags_seq):
        raise AssertionError(
            "Number of word and tag sequences doesn't match in dataset {}".format(dataset_name)
        )
    for idx, (words, tags) in enumerate(zip(words_seq, tags_seq)):
        if len(words) != len(tags):
            raise AssertionError(
                "Numbers of words and tags don't match in sequence {} of dataset {}".format(idx, dataset_name)
            )
def parse_submission(ref_txt_file, ref_tags_file, submission):
    """Load the reference words/tags and the predicted tags of a submission.

    Args:
        ref_txt_file: path to the reference text, one whitespace-tokenised
            sentence per line.
        ref_tags_file: path to the reference tags ('OK'/'BAD'), one
            sentence per line.
        submission: path to the tab-separated submission file; column 1
            (0-based) is the segment number and column 4 is the tag.

    Returns:
        A pair (true_tags, test_tags): two lists of per-sentence tag lists
        where 'OK' is mapped to 1 and 'BAD' to 0.

    Raises:
        KeyError: on a tag other than 'OK' or 'BAD'.
        IndexError: on a submission line with fewer than five columns, or a
            segment number outside the reference.
        AssertionError: raised by check_word_tag when words and tags do not
            line up.
    """
    tag_map = {'OK': 1, 'BAD': 0}

    # io.open decodes UTF-8 on both Python 2 and 3 (str.decode only exists on
    # Python 2), and the `with` blocks make sure every file gets closed.

    # parse reference text; split() also discards the line terminator, so a
    # final line without a trailing newline keeps its last character
    with io.open(ref_txt_file, encoding='utf-8') as ref_txt:
        true_words = [line.split() for line in ref_txt]

    # parse reference tags
    with io.open(ref_tags_file, encoding='utf-8') as ref_tags:
        true_tags = [[tag_map[t] for t in line.split()] for line in ref_tags]

    check_word_tag(true_words, true_tags, dataset_name='reference')

    # parse and check the submission
    test_tags = [[] for _ in true_tags]
    with io.open(submission, encoding='utf-8') as predictions:
        for line in predictions:
            line = line.rstrip('\r\n')
            if not line.strip():
                # tolerate blank lines (e.g. a trailing empty line)
                continue
            chunks = line.split('\t')
            test_tags[int(chunks[1])].append(tag_map[chunks[4]])

    check_word_tag(true_words, test_tags, dataset_name='prediction')
    return true_tags, test_tags
# accuracy weighted by the ratio of the number of spans in the reference and the hypothesis | |
def _count_tag_switches(tags):
    """Return how many positions in *tags* differ from their predecessor."""
    tags = list(tags)
    return sum(1 for prev, cur in zip(tags, tags[1:]) if prev != cur)


def sequence_correlation_simple(true_tags, test_tags, bad_weight=0.5):
    """Compute per-sentence and average sequence correlation scores.

    For every sentence the score is a weighted accuracy multiplied by a
    penalty for a differing number of tag switches (borders between runs of
    equal tags) in the reference and in the hypothesis.

    Args:
        true_tags: list of reference tag sequences (1 = OK, 0 = BAD).
        test_tags: list of predicted tag sequences, aligned with true_tags.
        bad_weight: share of the total sample weight given to the BAD
            tokens of a sentence; the OK tokens share 1 - bad_weight.

    Returns:
        A pair (scores, average): the list of per-sentence scores and their
        mean as computed by np.average.

    Raises:
        AssertionError: if a reference sequence contains a tag other than
            0 or 1 (the offending tag is printed first).
    """
    seq_corr_all = []
    for true_seq, test_seq in zip(true_tags, test_tags):
        # float so the divisions below never truncate on Python 2
        seq_len = float(len(true_seq))

        # number of good and bad tags in the reference
        n_tags_1 = sum(1 for t in true_seq if t == 1)
        n_tags_0 = sum(1 for t in true_seq if t == 0)
        lambda_0 = bad_weight * seq_len / n_tags_0 if n_tags_0 != 0 else 0
        lambda_1 = (1 - bad_weight) * seq_len / n_tags_1 if n_tags_1 != 0 else 0

        # weights list contains lambdas for good and bad tags
        weights = []
        for t in true_seq:
            if t == 1:
                weights.append(lambda_1)
            elif t == 0:
                weights.append(lambda_0)
            else:
                print("Unknown reference tag: {}".format(t))
        # explicit raise so the check survives `python -O`; the message now
        # reports the expected length first and the actual length second
        if len(weights) != len(true_seq):
            raise AssertionError(
                "Expected weights array len {}, got {}".format(len(true_seq), len(weights))
            )
        acc = accuracy_score(true_seq, test_seq, sample_weight=weights)

        # number of borders between tag runs in reference and hypothesis
        n_spans_true = _count_tag_switches(true_seq)
        n_spans_pred = _count_tag_switches(test_seq)

        # penalises any difference in the number of spans between the
        # reference and the hypothesis
        if n_spans_true == 0 and n_spans_pred == 0:
            ratio = 1
        elif n_spans_true == 0 or n_spans_pred == 0:
            ratio = 0
        else:
            # same value as min(pred/true, true/pred), with float division
            # forced so Python 2 does not truncate the ratio to 0 or 1
            ratio = min(n_spans_true, n_spans_pred) / float(max(n_spans_true, n_spans_pred))
        seq_corr_all.append(acc * ratio)
    return seq_corr_all, np.average(seq_corr_all)
if __name__ == "__main__":
    # Command-line entry point: score a WMT-15 format submission against the
    # reference text and tags, optionally saving the per-sentence scores.
    parser = ArgumentParser()
    parser.add_argument("ref_txt", action="store", help="test target text (one line per sentence)")
    parser.add_argument("ref_tags", action="store", help="test tags (one line per sentence)")
    parser.add_argument("submission", action="store", help="submission (wmt15 format)")
    parser.add_argument("--sequence", dest='seq_corr', default=None, help="file to store sentence-level sequence correlation scores (not saved if no file provided)")
    # bad_weight was referenced below without ever being defined; expose it
    # as an option whose default matches sequence_correlation_simple
    parser.add_argument("--bad_weight", dest='bad_weight', type=float, default=0.5, help="share of the weight assigned to BAD tags (default: 0.5)")
    # the arguments were never parsed, so `args` did not exist
    args = parser.parse_args()

    true_tags, test_tags = parse_submission(args.ref_txt, args.ref_tags, args.submission)
    seq_corr = sequence_correlation_simple(true_tags, test_tags, bad_weight=args.bad_weight)
    if args.seq_corr is not None:
        # `with` closes (and flushes) the score file once it is written
        with open(args.seq_corr, 'w') as seq_corr_out:
            for sc in seq_corr[0]:
                seq_corr_out.write("%f\n" % sc)
    print("Sequence correlation: {}".format(seq_corr[1]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment