import sys
import math
from argparse import ArgumentParser
from multiprocessing import Pool

import numpy as np

##########################################################################
#
# Compute the statistical significance level using a randomization test [1]
# for any number of QE system outputs in the WMT-15 format.
#
# [1] Alexander Yeh. (2000) More accurate tests for the statistical
#     significance of result differences. In Coling-2000.
#
##########################################################################
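
# High-level flow, as implemented below: per-sentence confusion statistics
# (tp, fp, tn, fn, span counts) are extracted for every submission; the chosen
# metric is computed from those statistics; then, for every pair of systems,
# the statistics of the sentences where the two systems disagree are randomly
# swapped between them many times, and the p-value is the fraction of shuffles
# whose metric difference exceeds the observed difference.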

def parse_reference(reference):
    """Read the reference file: one sentence per line, space-separated tags."""
    tags = []
    tag_map = {'OK': 1, 'BAD': 0}
    with open(reference, encoding='utf-8') as f:
        for line in f:
            tags.append([tag_map[t] for t in line.strip().split()])
    return tags
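
# A hypothetical reference line "OK BAD OK" is parsed into [1, 0, 1], so the
# whole file becomes a list with one list of 0/1 tags per sentence.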

def parse_submission_seq(submission):
    """Read a submission in the WMT-15 format: one tab-separated line per
    word, with the sentence index in the second field, the word index in the
    third field, and the predicted tag in the last field."""
    sub = []
    tag_map = {'OK': 1, 'BAD': 0}
    prev_sent = -1
    with open(submission, encoding='utf-8') as f:
        for line in f:
            chunks = line.strip().split('\t')
            s_idx, w_idx = int(chunks[1]), int(chunks[2])
            # start a new sentence when the sentence index changes
            # and the word index is reset to 0
            if s_idx != prev_sent and w_idx == 0:
                sub.append([])
            sub[-1].append(tag_map[chunks[-1]])
            prev_sent = s_idx
    return sub
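
# Hypothetical submission lines (tab-separated; field contents other than the
# indices and the final tag are not inspected by the parser):
#   method1	0	0	word	OK
#   method1	0	1	word	BAD
# would be parsed into [[1, 0]]: one list of 0/1 tags per sentence.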

# extract tp, fp, tn, fn and the number of spans (optional) for a system output
def get_statistics(true_tags, test_tags, spans=False):
    seq_stats = []
    for true_seq, test_seq in zip(true_tags, test_tags):
        # extract the number of spans for the sequence correlation metric:
        # counting tag changes yields the number of runs, and subtracting
        # the start of the sequence leaves the number of span boundaries
        n_spans_true, n_spans_pred = 0, 0
        if spans:
            prev_pred, prev_true = None, None
            for tag in test_seq:
                if tag != prev_pred:
                    n_spans_pred += 1
                prev_pred = tag
            for tag in true_seq:
                if tag != prev_true:
                    n_spans_true += 1
                prev_true = tag
            # subtract the beginning of the sequence, which shouldn't be counted
            n_spans_true -= 1
            n_spans_pred -= 1
        tp, fp, tn, fn = 0, 0, 0, 0
        for true, test in zip(true_seq, test_seq):
            if true == 1 and test == 1:
                tp += 1
            elif true == 1 and test == 0:
                fn += 1
            elif true == 0 and test == 0:
                tn += 1
            elif true == 0 and test == 1:
                fp += 1
            else:
                print("Wrong combination of tags: {} and {}".format(true, test))
        seq_stats.append((tp, fp, tn, fn, n_spans_true, n_spans_pred))
    return seq_stats
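
# Worked example (hypothetical tags): true = [1, 1, 0], pred = [1, 0, 0]
# gives tp=1, fn=1, tn=1, fp=0; with spans=True both sequences contain two
# runs, i.e. one span boundary each, so the tuple is (1, 0, 1, 1, 1, 1).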

# compare the per-sentence statistics of two systems
def all_equal(sent1, sent2):
    assert len(sent1) == len(sent2)
    return all(s1 == s2 for s1, s2 in zip(sent1, sent2))

# compute the multiplied F1 score (F1-OK * F1-BAD) from the statistics
def f1_multiply_statistics(system):
    tp = sum(s[0] for s in system)
    fp = sum(s[1] for s in system)
    tn = sum(s[2] for s in system)
    fn = sum(s[3] for s in system)
    f1_ok = (2 * tp) / (2 * tp + fn + fp)
    f1_bad = (2 * tn) / (2 * tn + fn + fp)
    return f1_ok * f1_bad
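
# For the worked example above (tp=1, fp=0, tn=1, fn=1):
# f1_ok = 2/(2 + 1 + 0) = 2/3, f1_bad = 2/(2 + 1 + 0) = 2/3,
# so the multiplied F1 is 4/9 ~ 0.444.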

# compute F1-BAD from the statistics
def f1_bad_statistics(system):
    fp = sum(s[1] for s in system)
    tn = sum(s[2] for s in system)
    fn = sum(s[3] for s in system)
    return (2 * tn) / (2 * tn + fn + fp)

# compute the sequence correlation metric from the statistics
def seq_cor_statistics(system):
    seq_cor = []
    for seq in system:
        len_seq = seq[0] + seq[1] + seq[2] + seq[3]
        # class weights that balance the OK (1) and BAD (0) classes
        lambda_1 = 0.5 * len_seq / (seq[0] + seq[3]) if (seq[0] + seq[3]) > 0 else 0
        lambda_0 = 0.5 * len_seq / (seq[1] + seq[2]) if (seq[1] + seq[2]) > 0 else 0
        # weighted word-level accuracy
        acc_w = (lambda_1 * seq[0] + lambda_0 * seq[2]) / len_seq
        # ratio of the span-boundary counts in the reference and the prediction
        if seq[4] == 0 and seq[5] == 0:
            r = 1.0
        elif seq[4] == 0 or seq[5] == 0:
            r = 0.0
        else:
            r = min(seq[4] / seq[5], seq[5] / seq[4])
        seq_cor.append(r * acc_w)
    return np.average(seq_cor)
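
# For the worked example (tp=1, fp=0, tn=1, fn=1, one boundary in each):
# lambda_1 = 0.5*3/2 = 0.75, lambda_0 = 0.5*3/1 = 1.5,
# acc_w = (0.75*1 + 1.5*1)/3 = 0.75, r = min(1/1, 1/1) = 1.0,
# so this sentence contributes 0.75 to the average.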

# compute the Matthews correlation coefficient (MCC) from the statistics
def matthews_statistics(system):
    tp = sum(s[0] for s in system)
    fp = sum(s[1] for s in system)
    tn = sum(s[2] for s in system)
    fn = sum(s[3] for s in system)
    try:
        return (tp * tn - fp * fn) / math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    except ZeroDivisionError:
        return 0
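
# For the worked example (tp=1, fp=0, tn=1, fn=1):
# MCC = (1*1 - 0*1) / sqrt(1 * 2 * 1 * 2) = 0.5.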

# one randomization sample; takes a single tuple argument so that it can be
# used with multiprocessing.Pool.map
def one_random(args):
    idx, function, ref, sys1, sys2, common_tags = args
    set_length = len(sys1)
    if idx % 100000 == 0:
        sys.stderr.write('.')
    random_test1, random_test2 = [], []
    # choose whether to swap each sentence between the systems or leave it in place
    choice = np.random.binomial(1, 0.5, size=set_length)
    for i, n in enumerate(choice):
        if n == 0:
            random_test1.append(sys1[i])
            random_test2.append(sys2[i])
        else:
            random_test1.append(sys2[i])
            random_test2.append(sys1[i])
    assert len(random_test1) + len(common_tags) == len(ref)
    assert len(random_test2) + len(common_tags) == len(ref)
    score1 = function(random_test1 + common_tags)
    score2 = function(random_test2 + common_tags)
    return abs(score1 - score2)
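
# Hypothetical run with three differing sentences and choice = [1, 0, 1]:
# random_test1 = [sys2[0], sys1[1], sys2[2]] and random_test2 gets the
# complementary statistics, so each shuffle keeps the pooled data intact.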

# approximate randomization test for a pair of systems
def bootstrap_binary(function, true_tags, sys1_tags, sys2_tags, folds=1000000, threads=1):
    # metric scores of the two systems and their observed difference
    score_sys1 = function(sys1_tags)
    score_sys2 = function(sys2_tags)
    diff_true = abs(score_sys1 - score_sys2)
    # separate the sentences where the two taggings differ: only these can
    # change the scores when swapped, the matching ones are added back as-is
    sys1_tags_diff, sys2_tags_diff, common_tags = [], [], []
    for seq1, seq2 in zip(sys1_tags, sys2_tags):
        if all_equal(seq1, seq2):
            common_tags.append(seq1)
        else:
            sys1_tags_diff.append(seq1)
            sys2_tags_diff.append(seq2)
    if threads > 1:
        with Pool(processes=threads) as pool:
            differences = pool.map(one_random, [(i, function, true_tags, sys1_tags_diff, sys2_tags_diff, common_tags) for i in range(folds)])
    else:
        differences = [one_random((i, function, true_tags, sys1_tags_diff, sys2_tags_diff, common_tags)) for i in range(folds)]
    # p-value: the share of shuffles whose difference exceeds the observed one
    p_val = sum(1 for s in differences if s > diff_true) / folds
    return p_val
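
# Hypothetical call:
#   p = bootstrap_binary(f1_multiply_statistics, ref, stats_a, stats_b,
#                        folds=10000, threads=4)
# returns the share of the 10,000 shuffles whose score difference exceeds
# the observed difference between the two systems.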

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("submissions", nargs="+", help="submissions (WMT-15 format)")
    parser.add_argument("reference", help="true tags (one sentence per line)")
    parser.add_argument("--metric", required=True,
                        help="metric to use: f1_multiply, f1_bad, seq_cor, matthews")
    parser.add_argument("--threads", type=int, default=1,
                        help="number of threads to use (default 1)")
    parser.add_argument("--folds", type=int, default=1000000,
                        help="number of runs of the randomization test (default 1,000,000)")
    parser.add_argument("--alpha", type=float, default=0.05,
                        help="significance level (default 0.05)")
    args = parser.parse_args()
    submission_names = args.submissions
    submissions = [parse_submission_seq(sub) for sub in submission_names]
    ref = parse_reference(args.reference)
    spans = args.metric in ('seq_cor', 'doc_seq_cor')
    try:
        function = globals()[args.metric + '_statistics']
    except KeyError:
        print("Unknown function name: {}".format(args.metric + '_statistics'))
        sys.exit(1)
    # compute the sentence-level statistics: tp, fp, tn, fn, number of spans
    submission_statistics = [get_statistics(ref, sub, spans=spans) for sub in submissions]
    # compute and print the metric scores, best system first
    scores = [function(sub) for sub in submission_statistics]
    sorted_scores = sorted(enumerate(scores), key=lambda kv: kv[1], reverse=True)
    sys.stdout.write("--------%s metric values--------\n" % args.metric)
    for idx, sc in sorted_scores:
        sys.stdout.write("%s\t%f\n" % (submission_names[idx], sc))
    init_ranking = [i for i, v in sorted_scores]
    n_systems = len(init_ranking)
    n_comparisons = (n_systems**2 - n_systems) // 2
    # pairwise significance tests between all systems
    p_vals = []
    for idx1, sys1 in enumerate(init_ranking):
        sys.stderr.write('\n[%s]' % submission_names[sys1])
        for idx2, sys2 in enumerate(init_ranking):
            if idx2 > idx1:
                p_vals.append((submission_names[sys1], submission_names[sys2],
                               bootstrap_binary(function, ref,
                                                submission_statistics[sys1],
                                                submission_statistics[sys2],
                                                folds=args.folds, threads=args.threads)))
    sys.stdout.write("\n\nNumber of trials: %d\n" % n_comparisons)
    # Bonferroni correction: divide the significance level by the number of comparisons
    new_alpha = args.alpha / n_comparisons
    sys.stdout.write("Significance level with Bonferroni correction: %f\n\n" % new_alpha)
    sys.stdout.write("P-values:\n")
    for name1, name2, val in p_vals:
        sys.stdout.write("%s\t%s\t\t\t%f\n" % (name1, name2, val))