This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division

import io
import math
import sys
from argparse import ArgumentParser
from multiprocessing import Pool

import numpy as np
########################################################################## | |
# | |
# Compute statistical significance level using a randomization test [1] | |
# for two QE system outputs in WMT-15 format | |
# | |
# [1] Alexander Yeh. (2000) More accurate tests for the statistical | |
# significance of result differences. In Coling-2000. | |
# | |
########################################################################## | |
def parse_reference(reference):
    """Read a reference tag file: one line per sentence, space-separated OK/BAD tags.

    Args:
        reference: path to the reference file.

    Returns:
        List of tag sequences, one per sentence, with OK -> 1 and BAD -> 0.
    """
    tags = []
    tag_map = {'OK': 1, 'BAD': 0}
    # io.open yields unicode lines on both Python 2 and 3 and the `with`
    # block guarantees the file handle is closed (the original leaked it).
    with io.open(reference, encoding='utf-8') as ref_file:
        for line in ref_file:
            tags.append([tag_map[t] for t in line.strip().split()])
    return tags
def parse_submission_seq(submission):
    """Parse a WMT-15 format submission into per-sentence tag sequences.

    Each line is tab-separated; field 1 is the sentence index, field 2 the
    word index, and the last field the OK/BAD tag.

    Args:
        submission: path to the submission file.

    Returns:
        List of tag sequences, one per sentence, with OK -> 1 and BAD -> 0.
    """
    sub = []
    tag_map = {'OK': 1, 'BAD': 0}
    prev_sent = -1
    # io.open yields unicode lines on both Python 2 and 3 and the `with`
    # block guarantees the file handle is closed (the original leaked it).
    with io.open(submission, encoding='utf-8') as sub_file:
        for line in sub_file:
            chunks = line.strip().split('\t')
            s_idx, w_idx = int(chunks[1]), int(chunks[2])
            # a word index of 0 in a new sentence starts a new sequence
            if s_idx != prev_sent and w_idx == 0:
                sub.append([])
            sub[-1].append(tag_map[chunks[-1]])
            prev_sent = s_idx
    return sub
# extract tp, fp, tn, fn and number of spans (optional) for a system output | |
def get_statistics(true_tags, test_tags, spans=False, submission='sub'): | |
tp_all, fp_all, tn_all, fn_all = 0, 0, 0, 0 | |
all_spans_true, all_spans_pred = 0, 0 | |
seq_stats = [] | |
for true_seq, test_seq in zip(true_tags, test_tags): | |
# extract number of spans for sequence correlation | |
n_spans_true, n_spans_pred = 0, 0 | |
if spans: | |
prev_pred, prev_true = None, None | |
for tag in test_seq: | |
if tag != prev_pred: | |
n_spans_pred += 1 | |
all_spans_pred += 1 | |
prev_pred = tag | |
for tag in true_seq: | |
if tag != prev_true: | |
n_spans_true += 1 | |
all_spans_true += 1 | |
prev_true = tag | |
# subtract the beginning of sequence which shouldn't be counted | |
n_spans_true -= 1 | |
n_spans_pred -= 1 | |
all_spans_true -= 1 | |
all_spans_pred -= 1 | |
tp, fp, tn, fn = 0, 0, 0, 0 | |
for true, test in zip(true_seq, test_seq): | |
if true == 1 and test == 1: | |
tp += 1 | |
tp_all += 1 | |
elif true == 1 and test == 0: | |
fn += 1 | |
fn_all += 1 | |
elif true == 0 and test == 0: | |
tn += 1 | |
tn_all += 1 | |
elif true == 0 and test == 1: | |
fp += 1 | |
fp_all += 1 | |
else: | |
print("Wrong combination of tags: {} and {}".format(true, test)) | |
seq_stats.append((tp, fp, tn, fn, n_spans_true, n_spans_pred)) | |
return seq_stats | |
# compare the statistics for a pair of sentences | |
# compare the statistics for a pair of sentences
def all_equal(sent1, sent2):
    """Return True iff the two equal-length sequences match element-wise."""
    assert(len(sent1) == len(sent2))
    for left, right in zip(sent1, sent2):
        if left != right:
            return False
    return True
# compute f1 from statistics | |
def f1_multiply_statistics(system): | |
tp = sum([s[0] for s in system]) | |
fp = sum([s[1] for s in system]) | |
tn = sum([s[2] for s in system]) | |
fn = sum([s[3] for s in system]) | |
f1_ok = (2*tp)/(2*tp + fn + fp) | |
f1_bad = (2*tn)/(2*tn + fn + fp) | |
return f1_ok*f1_bad | |
# f1-bad from statistics | |
def f1_bad_statistics(system): | |
fp = sum([s[1] for s in system]) | |
tn = sum([s[2] for s in system]) | |
fn = sum([s[3] for s in system]) | |
return (2*tn)/(2*tn + fn + fp) | |
# compute Sequence Correlation from statistics | |
def seq_cor_statistics(system): | |
seq_cor = [] | |
for seq in system: | |
len_seq = seq[0] + seq[1] + seq[2] + seq[3] | |
lambda_1 = 0.5*len_seq/(seq[0] + seq[3]) if (seq[0] + seq[3]) > 0 else 0 | |
lambda_0 = 0.5*len_seq/(seq[1] + seq[2]) if (seq[1] + seq[2]) > 0 else 0 | |
acc_w = (lambda_1*seq[0] + lambda_0*seq[2])/len_seq | |
r = 0.0 | |
if seq[4] == 0 and seq[5] == 0: | |
r = 1.0 | |
elif seq[4] == 0 or seq[5] == 0: | |
r = 0.0 | |
else: | |
r = min(seq[4]/seq[5], seq[5]/seq[4]) | |
seq_cor.append(r*acc_w) | |
return np.average(seq_cor) | |
# compute MCC from statistics | |
def matthews_statistics(system): | |
tp = sum([s[0] for s in system]) | |
fp = sum([s[1] for s in system]) | |
tn = sum([s[2] for s in system]) | |
fn = sum([s[3] for s in system]) | |
try: | |
matt = (tp*tn - fp*fn)/math.sqrt((tp + fp)*(tp + fn)*(tn + fp)*(tn + fn)) | |
return matt | |
except ZeroDivisionError: | |
return 0 | |
# one bootstrap sample | |
# function for multi-threading | |
def one_random(args):
    """Run one randomization-test trial and return the absolute score gap.

    Takes a single tuple argument so it can be dispatched via Pool.map.
    (The original used Python-2-only tuple parameter unpacking, which is a
    SyntaxError on Python 3 — PEP 3113.)

    Args:
        args: tuple of (idx, function, ref, sys1, sys2, common_tags) where
            idx is the trial number, function the metric over statistics,
            ref the reference sequences, sys1/sys2 the per-sentence stats
            where the systems disagree, and common_tags the shared stats.

    Returns:
        abs difference of the metric between the two randomly shuffled sets.
    """
    idx, function, ref, sys1, sys2, common_tags = args
    set_length = len(sys1)
    # lightweight progress indicator
    if idx % 100000 == 0:
        sys.stderr.write('.')
    random_test1, random_test2 = [], []
    # choose whether to swap the current sample or leave in place
    choice = np.random.binomial(1, 0.5, size=set_length)
    # renamed loop variable: the original shadowed the outer `idx`
    for pos, swap in enumerate(choice):
        if swap == 0:
            random_test1.append(sys1[pos])
            random_test2.append(sys2[pos])
        else:
            random_test1.append(sys2[pos])
            random_test2.append(sys1[pos])
    assert(len(random_test1) + len(common_tags) == len(ref))
    assert(len(random_test2) + len(common_tags) == len(ref))
    score1 = function(random_test1 + common_tags)
    score2 = function(random_test2 + common_tags)
    return abs(score1 - score2)
# correct bootstrap test | |
def bootstrap_binary(function, true_tags, sys1_tags, sys2_tags, folds=1000000, bad_weight=0.5, threads=10): | |
# metric-sys1 | |
score_sys1 = function(sys1_tags) | |
# metric-sys2 | |
score_sys2 = function(sys2_tags) | |
# difference | |
diff_true = abs(score_sys1 - score_sys2) | |
# find the non-matching lines in two taggings | |
sys1_tags_diff, sys2_tags_diff, common_tags = [], [], [] | |
for seq1, seq2 in zip(sys1_tags, sys2_tags): | |
if all_equal(seq1, seq2): | |
common_tags.append(seq1) | |
else: | |
sys1_tags_diff.append(seq1) | |
sys2_tags_diff.append(seq2) | |
if threads > 1: | |
pool = Pool(processes=threads) | |
differences = pool.map(one_random, [(i, function, true_tags, sys1_tags_diff, sys2_tags_diff, common_tags) for i in range(folds)]) | |
else: | |
differences = [] | |
for i in range(folds): | |
differences.append(one_random((i, function, true_tags, sys1_tags_diff, sys2_tags_diff, common_tags))) | |
p_val = sum([1 for s in differences if s > diff_true])/folds | |
return p_val | |
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("submission1", help="first submission to compare (WMT-15 format)")
    parser.add_argument("submission2", help="second submission to compare (WMT-15 format)")
    parser.add_argument("reference", help="reference tags")
    parser.add_argument("--metric", help="metric to use: f1_multiply, f1_bad, seq_cor, matthews (multiplication of F1-OK and F1-BAD scores, F1-BAD score, Sequence Correlation, Matthews correlation coefficient)")
    parser.add_argument("--threads", default='1', help="number of threads to use (default 1)")
    parser.add_argument("--folds", default='1000000', help="number of runs of randomized test (default 1,000,000)")
    args = parser.parse_args()
    threads = int(args.threads)
    folds = int(args.folds)
    sub1 = parse_submission_seq(args.submission1)
    sub2 = parse_submission_seq(args.submission2)
    ref = parse_reference(args.reference)
    # Sequence Correlation is the only metric that needs span counts
    spans = args.metric == 'seq_cor'
    # explicit dispatch table instead of the original fragile locals()
    # lookup wrapped in a bare except
    metrics = {
        'f1_multiply': f1_multiply_statistics,
        'f1_bad': f1_bad_statistics,
        'seq_cor': seq_cor_statistics,
        'matthews': matthews_statistics,
    }
    try:
        function = metrics[args.metric]
    except KeyError:
        print("Unknown function name: {}".format('{}_statistics'.format(args.metric)))
        # exit with a non-zero code on error (the original exited with 0)
        sys.exit(1)
    # compute the sentence-level statistics: tp, tn, fp, fn, number of spans
    sub1_stats = get_statistics(ref, sub1, spans=spans)
    sub2_stats = get_statistics(ref, sub2, spans=spans)
    # compute scores
    sys.stdout.write("%s\t%f\n" % (args.submission1, function(sub1_stats)))
    sys.stdout.write("%s\t%f\n" % (args.submission2, function(sub2_stats)))
    # compute p-value
    p_val = bootstrap_binary(function, ref, sub1_stats, sub2_stats, folds=folds, threads=threads)
    sys.stdout.write("\nP-value: %.7f\n" % p_val)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment