SemEval command-line game
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/env python3 | |
""" | |
Play the STS (Semantic Textual Similarity) evaluation game! Test your ability
to predict the gold standard label for each sentence pair in a TSV file
containing gold labels and sentence pairs.

Try loading the question-question set from SemEval Semantic Textual Similarity.
https://github.com/brmson/dataset-sts
""" | |
import argparse | |
import csv | |
import sys | |
# https://stackoverflow.com/a/5713856 | |
def pearsonr(x, y):
    """Return the Pearson correlation coefficient of two numeric sequences.

    The coefficient is computed from running sums (sum of values, sum of
    squares and sum of products) rather than from explicit means.

    Args:
        x: Sequence of numbers.
        y: Sequence of numbers with the same length as ``x``.

    Returns:
        The correlation coefficient, or 0 when it is undefined: both
        sequences are empty, or at least one of them has no variation.

    Raises:
        AssertionError: If ``x`` and ``y`` differ in length.
    """
    assert len(x) == len(y)
    n = len(x)
    if n == 0:
        # No observations: every division by n below would fail.
        return 0
    sum_x = float(sum(x))
    sum_y = float(sum(y))
    sum_x_sq = sum(x_i ** 2 for x_i in x)
    sum_y_sq = sum(y_i ** 2 for y_i in y)
    psum = sum(x_i * y_i for x_i, y_i in zip(x, y))
    num = psum - (sum_x * sum_y / n)
    # Each term is a sum of squared deviations and therefore non-negative in
    # exact arithmetic.  Floating-point rounding can push it just below zero,
    # which would turn ``** 0.5`` into a complex number, so clamp at zero.
    spread_x = max(sum_x_sq - (sum_x ** 2) / n, 0.0)
    spread_y = max(sum_y_sq - (sum_y ** 2) / n, 0.0)
    den = (spread_x * spread_y) ** 0.5
    if den == 0:
        # A constant sequence has no defined correlation; report 0.
        return 0
    return num / den
def main(argv=None):
    """Run the interactive rating game over a TSV of labelled sentence pairs.

    Each usable row holds a gold score and two sentences separated by tabs.
    For every such row the sentences are shown, the player is asked for a
    whole-number rating from 0 to 5, and the gold score is then revealed.
    When the input is exhausted, or the player quits with Ctrl-C or
    end-of-file, the Pearson correlation between the gold scores and the
    player's ratings is printed (provided at least one pair was rated).

    Args:
        argv: Command-line arguments excluding the program name.  ``None``
            makes argparse read ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    # NOTE: without a file argument the TSV is read from stdin, which is also
    # where input() reads the player's answers.  Pass a file path for
    # interactive play.
    parser.add_argument('tsv', nargs='?', type=argparse.FileType('r'), default=sys.stdin)
    args = parser.parse_args(argv)

    gold_labels = []
    pred_labels = []
    try:
        for row in csv.reader(args.tsv, delimiter='\t'):
            # Blank lines come back as empty lists and truncated rows as
            # short ones; neither can be unpacked into label + two sentences.
            if len(row) < 3:
                continue
            gold, s1, s2 = row[:3]
            if not gold:
                # Pair without a gold label: nothing to compare against.
                continue
            try:
                # Parse as float so fractional scores such as "3.8" work, and
                # do it before prompting so a bad label never wastes an answer.
                gold_score = float(gold)
            except ValueError:
                # Non-numeric label (e.g. a header row): skip it.
                continue
            print('Sentence 1: ' + s1)
            print('Sentence 2: ' + s2)
            pred = input('Rate [0-5]: ').strip()
            # isdecimal() only accepts characters that int() can convert,
            # unlike isdigit(), which also accepts e.g. superscript digits.
            while not pred.isdecimal() or not (0 <= int(pred) <= 5):
                pred = input('Rate [0-5]: ').strip()
            print('STS Rating: ' + gold)
            print()
            gold_labels.append(gold_score)
            pred_labels.append(int(pred))
    except (KeyboardInterrupt, EOFError):
        # Quitting early is a normal way to end the game; fall through to the
        # summary instead of showing a traceback.
        pass
    finally:
        if args.tsv is not sys.stdin:
            args.tsv.close()
        if gold_labels and pred_labels:
            print()
            print('Correlation: ' + str(pearsonr(gold_labels, pred_labels)))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment