Skip to content

Instantly share code, notes, and snippets.

@varvara-l
Created April 27, 2016 17:00
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save varvara-l/d66450db8da44b8584c02f4b6c79745c to your computer and use it in GitHub Desktop.
from __future__ import division
import io
import math
import sys
from argparse import ArgumentParser
from multiprocessing import Pool

import numpy as np
##########################################################################
#
# Compute statistical significance level using a randomization test [1]
# for ANY NUMBER of QE system outputs in WMT-15 format
#
# [1] Alexander Yeh. (2000) More accurate tests for the statistical
# significance of result differences. In Coling-2000.
#
##########################################################################
def parse_reference(reference):
    """Read reference tags: one sentence per line, whitespace-separated tags.

    Returns a list with one list of tags per line, mapping 'OK' -> 1 and
    'BAD' -> 0; any other token raises KeyError.  Blank lines yield empty
    lists so that line numbers stay aligned with sentence numbers.
    """
    tag_map = {'OK': 1, 'BAD': 0}
    # io.open decodes UTF-8 on both Python 2 and 3, and the with-block closes
    # the file.  split() already drops the line terminator, so a final line
    # without '\n' is parsed correctly (slicing off the last character was not).
    with io.open(reference, encoding='utf-8') as ref_file:
        return [[tag_map[t] for t in line.split()] for line in ref_file]
def parse_submission_seq(submission):
    """Parse a WMT-15 format word-level QE submission.

    Each non-empty line is tab-separated: column 1 holds the sentence index
    and the last column the tag ('OK' or 'BAD').  Consecutive lines that
    share a sentence index form one sentence; a change of index starts a new
    one.  Blank lines are skipped.

    Returns a list with one list of tags per sentence, mapping OK -> 1 and
    BAD -> 0; any other tag raises KeyError.
    """
    tag_map = {'OK': 1, 'BAD': 0}
    sub = []
    prev_sent = None  # sentence index of the previous line
    with io.open(submission, encoding='utf-8') as sub_file:
        for line in sub_file:
            line = line.strip()
            if not line:
                continue
            chunks = line.split('\t')
            s_idx = int(chunks[1])
            # Split on the sentence index alone: requiring word index 0 as well
            # crashed on files whose first word index is not 0 and silently
            # merged such sentences into the previous one.
            if s_idx != prev_sent:
                sub.append([])
                prev_sent = s_idx
            sub[-1].append(tag_map[chunks[-1]])
    return sub
# Per-sentence confusion counts (and optional tag-change counts) for one system.
def _count_tag_changes(seq):
    """Return how many times the tag value changes along *seq* (0 if empty)."""
    seq = list(seq)
    return sum(1 for prev, cur in zip(seq, seq[1:]) if prev != cur)


def get_statistics(true_tags, test_tags, spans=False, submission='sub'):
    """Compute per-sentence statistics for one tagged system output.

    Args:
        true_tags: reference tags, one list of 0/1 tags per sentence
            (1 = OK, 0 = BAD, as produced by parse_reference).
        test_tags: system tags in the same layout.
        spans: if True, also count the tag changes inside each sentence
            (used by seq_cor_statistics); otherwise those fields are 0.
        submission: unused; kept for backward compatibility.

    Returns:
        One tuple per sentence pair:
        (tp, fp, tn, fn, n_spans_true, n_spans_pred), with OK (1) as the
        positive class.  The "span" fields are the number of tag changes in
        the sentence (runs of identical tags minus one, never negative).
        Pairs are formed with zip, so extra sentences/words are ignored.
    """
    seq_stats = []
    for true_seq, test_seq in zip(true_tags, test_tags):
        if spans:
            n_spans_true = _count_tag_changes(true_seq)
            n_spans_pred = _count_tag_changes(test_seq)
        else:
            n_spans_true, n_spans_pred = 0, 0
        tp, fp, tn, fn = 0, 0, 0, 0
        for true, test in zip(true_seq, test_seq):
            if true == 1 and test == 1:
                tp += 1
            elif true == 1 and test == 0:
                fn += 1
            elif true == 0 and test == 0:
                tn += 1
            elif true == 0 and test == 1:
                fp += 1
            else:
                print("Wrong combination of tags: {} and {}".format(true, test))
        seq_stats.append((tp, fp, tn, fn, n_spans_true, n_spans_pred))
    return seq_stats
# Check whether two sentences' statistics tuples are identical.
def all_equal(sent1, sent2):
    """Return True when *sent1* and *sent2* agree element by element.

    The two sequences must have the same length; this is enforced with an
    assert.
    """
    assert len(sent1) == len(sent2)
    for left, right in zip(sent1, sent2):
        if not left == right:
            return False
    return True
# F1(OK) * F1(BAD) from per-sentence statistics.
def f1_multiply_statistics(system):
    """Return the product of the OK-class and BAD-class F1 scores.

    *system* is a list of per-sentence tuples whose first four fields are
    (tp, fp, tn, fn), with OK as the positive class; counts are summed over
    all sentences.  For the BAD class tn plays the role of true positives.
    A class that occurs in neither the reference nor the prediction has an
    undefined F1 (0/0); it contributes 0.0 instead of raising
    ZeroDivisionError.
    """
    tp = sum(s[0] for s in system)
    fp = sum(s[1] for s in system)
    tn = sum(s[2] for s in system)
    fn = sum(s[3] for s in system)
    denom_ok = 2 * tp + fn + fp
    denom_bad = 2 * tn + fn + fp
    f1_ok = (2 * tp) / denom_ok if denom_ok else 0.0
    f1_bad = (2 * tn) / denom_bad if denom_bad else 0.0
    return f1_ok * f1_bad
# F1 of the BAD class from per-sentence statistics.
def f1_bad_statistics(system):
    """Return the F1 score of the BAD class.

    *system* is a list of per-sentence tuples whose first four fields are
    (tp, fp, tn, fn) with OK as the positive class, so tn counts correctly
    predicted BAD tags.  When BAD occurs in neither the reference nor the
    prediction the score is undefined (0/0) and 0.0 is returned.
    """
    fp = sum(s[1] for s in system)
    tn = sum(s[2] for s in system)
    fn = sum(s[3] for s in system)
    denom = 2 * tn + fn + fp
    return (2 * tn) / denom if denom else 0.0
def seq_cor_statistics(system):
    """Return the sequence-correlation score averaged over sentences.

    *system* is a list of per-sentence tuples
    (tp, fp, tn, fn, n_spans_true, n_spans_pred) as built by get_statistics
    with spans=True.  For each sentence:

    * acc_w is a class-weighted accuracy.  Each correct OK tag is weighted by
      0.5 * len / #reference-OK and each correct BAD tag by
      0.5 * len / #reference-BAD.  When both classes occur, this equals the
      mean of the two per-class recalls.
    * r compares the span counts: 1.0 if both are 0, 0.0 if exactly one is
      0, otherwise the smaller count divided by the larger.

    The per-sentence score is r * acc_w.  An empty sentence has no tags to
    weight, so its acc_w is 0.0 instead of dividing by zero.
    """
    seq_cor = []
    for seq in system:
        tp, fp, tn, fn, n_true, n_pred = seq[:6]
        len_seq = tp + fp + tn + fn
        n_ok_true = tp + fn    # reference OK tags
        n_bad_true = fp + tn   # reference BAD tags
        lambda_1 = 0.5 * len_seq / n_ok_true if n_ok_true > 0 else 0
        lambda_0 = 0.5 * len_seq / n_bad_true if n_bad_true > 0 else 0
        acc_w = (lambda_1 * tp + lambda_0 * tn) / len_seq if len_seq > 0 else 0.0
        if n_true == 0 and n_pred == 0:
            r = 1.0
        elif n_true == 0 or n_pred == 0:
            r = 0.0
        else:
            r = min(n_true / n_pred, n_pred / n_true)
        seq_cor.append(r * acc_w)
    return np.average(seq_cor)
# Matthews correlation coefficient (MCC) from per-sentence statistics.
def matthews_statistics(system):
    """Return the corpus-level Matthews correlation coefficient.

    *system* holds per-sentence tuples whose first four fields are
    (tp, fp, tn, fn).  Those counts are summed over all sentences before the
    coefficient is computed.  If any marginal total is zero the coefficient
    is undefined and 0 is returned.
    """
    true_pos = sum(stats[0] for stats in system)
    false_pos = sum(stats[1] for stats in system)
    true_neg = sum(stats[2] for stats in system)
    false_neg = sum(stats[3] for stats in system)
    denominator = math.sqrt((true_pos + false_pos) * (true_pos + false_neg)
                            * (true_neg + false_pos) * (true_neg + false_neg))
    if denominator == 0:
        return 0
    return (true_pos * true_neg - false_pos * false_neg) / denominator
# One sample of the approximate randomization test (Yeh, 2000).
# Takes a single tuple argument so it can be passed to Pool.map.
def one_random(args):
    """Score one random re-assignment of two systems' outputs.

    *args* is a tuple (idx, function, ref, sys1, sys2, common_tags):
      idx         -- sample number; a '.' is written to stderr every 100000
                     samples as a progress indicator.
      function    -- metric computed from a list of per-sentence statistics.
      ref         -- reference tags; only its length is used, to check that
                     no sentence was lost.
      sys1, sys2  -- per-sentence statistics of the two systems on the
                     sentences where they differ.
      common_tags -- statistics of the sentences where both systems agree.

    For every differing sentence a fair coin drawn from numpy's global RNG
    decides whether the two systems' statistics are swapped.  Returns the
    absolute difference between the metric values of the two shuffled
    systems.
    """
    # Unpack explicitly: the original tuple-parameter syntax is Python-2 only.
    idx, function, ref, sys1, sys2, common_tags = args
    if idx % 100000 == 0:
        sys.stderr.write('.')
    # 1 = swap this sentence between the two systems, 0 = keep it in place
    swaps = np.random.binomial(1, 0.5, size=len(sys1))
    random_test1, random_test2 = [], []
    for swap, stats1, stats2 in zip(swaps, sys1, sys2):
        if swap:
            stats1, stats2 = stats2, stats1
        random_test1.append(stats1)
        random_test2.append(stats2)
    assert len(random_test1) + len(common_tags) == len(ref)
    assert len(random_test2) + len(common_tags) == len(ref)
    score1 = function(random_test1 + common_tags)
    score2 = function(random_test2 + common_tags)
    return abs(score1 - score2)
# Approximate randomization significance test (Yeh, 2000) for two systems.
# Despite the function's name this is a randomization (permutation) test,
# not a bootstrap: samples swap outputs between the systems rather than
# resampling sentences with replacement.
def _reseed_worker():
    """Re-seed numpy's global RNG from OS entropy inside a pool worker.

    With the fork start method every worker inherits an identical copy of
    the parent's RNG state.  Without re-seeding, all workers would draw the
    same sequence of swaps and the samples would not be independent.
    """
    np.random.seed()


def bootstrap_binary(function, true_tags, sys1_tags, sys2_tags, folds=1000000, bad_weight=0.5, threads=1):
    """Return the p-value for the difference in *function* between two systems.

    Args:
        function: metric computed from a list of per-sentence statistics
            (e.g. f1_multiply_statistics).
        true_tags: reference tags; forwarded to one_random for a length check.
        sys1_tags, sys2_tags: per-sentence statistics of the two systems.
        folds: number of random samples to draw.
        bad_weight: unused; kept for backward compatibility.
        threads: if > 1, samples are computed in a multiprocessing Pool with
            this many worker processes.

    Returns:
        (nc + 1) / (folds + 1), where nc is the number of samples whose
        absolute metric difference is at least the observed one.  This is
        the estimator given by Yeh (2000): ties count as "at least as
        extreme" and one is added to both counts.  The result is never 0,
        and identical systems get 1.
    """
    diff_true = abs(function(sys1_tags) - function(sys2_tags))

    # Only sentences whose statistics differ need shuffling; swapping
    # identical statistics changes nothing, so those stay fixed.
    sys1_tags_diff, sys2_tags_diff, common_tags = [], [], []
    for seq1, seq2 in zip(sys1_tags, sys2_tags):
        if all_equal(seq1, seq2):
            common_tags.append(seq1)
        else:
            sys1_tags_diff.append(seq1)
            sys2_tags_diff.append(seq2)

    tasks = ((i, function, true_tags, sys1_tags_diff, sys2_tags_diff, common_tags)
             for i in range(folds))
    if threads > 1:
        pool = Pool(processes=threads, initializer=_reseed_worker)
        try:
            differences = pool.map(one_random, tasks)
        finally:
            # This function runs once per pair of systems, so a pool that is
            # not shut down leaks its worker processes on every call.
            pool.terminate()
            pool.join()
    else:
        differences = map(one_random, tasks)

    n_extreme = sum(1 for diff in differences if diff >= diff_true)
    return (n_extreme + 1) / (folds + 1)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("submissions", nargs="+", help="submissions (wmt15 format)")
parser.add_argument("reference", help="true tags (one sentence per line)")
parser.add_argument("--metric", help="metric to use: f1_multiply, f1_bad, seq_cor, matthews")
parser.add_argument("--threads", default='1', help="number of threads to use (default 1)")
parser.add_argument("--folds", default='1000000', help="number of runs of randomized test (default 1,000,000)")
parser.add_argument("--alpha", default='0.05', help="significance level (default 0.05)")
args = parser.parse_args()
threads = int(args.threads)
folds = int(args.folds)
alpha = float(args.alpha)
submissions = []
submission_names = args.submissions
for sub in submission_names:
submissions.append(parse_submission_seq(sub))
ref = parse_reference(args.reference)
spans = True if args.metric == 'seq_cor' or args.metric == 'doc_seq_cor' else False
function = ''
try:
function = locals()[args.metric + '_statistics']
except:
print("Unknown function name: {}".format(args.metric + '_statistics'))
sys.exit()
# compute the sentence-level statistics: tp, tn, fp, fn, number of spans
submission_statistics = []
for sub in submissions:
sub_stats = get_statistics(ref, sub, spans=spans)
submission_statistics.append(sub_stats)
# compute the scores
scores = []
for sub in submission_statistics:
scores.append(function(sub))
sorted_scores = sorted(enumerate(scores), key=lambda(k, v): v, reverse=True)
sys.stdout.write("--------%s metric values--------\n" % args.metric)
for (idx, sc) in sorted_scores:
sys.stdout.write("%s\t%f\n" % (submission_names[idx], sc))
init_ranking = [i for (i, v) in sorted_scores]
n_systems = len(init_ranking)
n_comparisons = (n_systems**2 - n_systems)/2
p_vals = []
for idx1, sys1 in enumerate(init_ranking):
sys.stderr.write('\n[%s]' % submission_names[sys1])
for idx2, sys2 in enumerate(init_ranking):
if idx2 > idx1:
p_vals.append((submission_names[sys1], submission_names[sys2], bootstrap_binary(function, ref, submission_statistics[sys1], submission_statistics[sys2], folds=folds, threads=threads)))
sys.stdout.write("\n\nNumber of trials: %d\n" % n_comparisons)
new_alpha = alpha/n_comparisons
sys.stdout.write("Significance level with Bonferroni correction: %f\n\n" % new_alpha)
sys.stdout.write("P-values:\n")
for idx1, idx2, val in p_vals:
sys.stdout.write("%s\t%s\t\t\t%f\n" % (idx1, idx2, val))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment