Created
July 10, 2023 15:25
-
-
Save erip/45898d616cc35527adf25887f4dbd11f to your computer and use it in GitHub Desktop.
Checks agreement between pronouns of a reference and MT system output.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import spacy | |
from statistics import mean | |
from argparse import ArgumentParser, FileType | |
def setup_argparse():
    """Build the command-line parser for the pronoun-agreement script.

    Options:
        -m / --model       name of the spaCy model to load
                           (default ``en_core_web_lg``).
        -r / --ref         reference file, opened for reading (required).
        -hyp / --hyp       system output file, opened for reading (required).
        -b / --batch-size  batch size forwarded to ``nlp.pipe`` (default 64).

    Returns:
        The configured ``ArgumentParser``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="en_core_web_lg",
        help="The spaCy model to use for evaluation",
    )
    # Both inputs are mandatory: without ``required=True`` a missing option
    # leaves ``None`` in the namespace and the script only fails later with
    # an opaque TypeError when it tries to iterate over it.
    parser.add_argument(
        "-r",
        "--ref",
        type=FileType("r"),
        required=True,
        help="The reference file",
    )
    parser.add_argument(
        "-hyp",
        "--hyp",
        type=FileType("r"),
        required=True,
        help="The system output file",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=64,
        help="Number of lines handed to spaCy per batch",
    )
    return parser
def get_match(ref_doc, hyp_doc):
    """Score how well the pronouns of a hypothesis match those of a reference.

    Each argument is iterated once; every item whose ``pos_`` equals
    ``"PRON"`` contributes its ``text`` to a set. The score is the Jaccard
    index of the two sets (size of intersection over size of union).

    Returns:
        A float in ``[0.0, 1.0]``. When neither side contains a pronoun the
        score is ``1.0``.
    """

    def pronoun_forms(doc):
        # Distinct surface forms of all tokens tagged as pronouns.
        return {token.text for token in doc if token.pos_ == "PRON"}

    ref_forms = pronoun_forms(ref_doc)
    hyp_forms = pronoun_forms(hyp_doc)
    combined = ref_forms | hyp_forms

    if combined:
        shared = ref_forms & hyp_forms
        return len(shared) / len(combined)

    # No pronouns on either side: treated as perfect agreement.
    return 1.0
if __name__ == "__main__":
    # Local import: only the script entry point needs it.
    from itertools import zip_longest

    # Print the mean per-line pronoun agreement between reference and
    # system output, pairing the two files line by line.
    parser = setup_argparse()
    args = parser.parse_args()
    if args.ref is None or args.hyp is None:
        parser.error("both --ref and --hyp must be provided")

    nlp = spacy.load(args.model)

    def _parse_lines(handle):
        """Lazily run the spaCy pipeline over each stripped, lower-cased line."""
        lines = (line.strip().lower() for line in handle)
        return nlp.pipe(lines, batch_size=args.batch_size)

    # Sentinel used to detect that one file ran out of lines before the
    # other; a plain ``zip`` would silently drop the surplus lines and
    # report a score for a truncated test set.
    exhausted = object()
    scores = []

    # Close both input files once every line has been consumed.
    with args.ref, args.hyp:
        pairs = zip_longest(
            _parse_lines(args.ref),
            _parse_lines(args.hyp),
            fillvalue=exhausted,
        )
        for ref_doc, hyp_doc in pairs:
            if ref_doc is exhausted or hyp_doc is exhausted:
                raise SystemExit(
                    "error: reference and system output have a different number of lines"
                )
            scores.append(get_match(ref_doc, hyp_doc))

    # ``mean`` raises StatisticsError on an empty sequence; report the
    # actual problem instead.
    if not scores:
        raise SystemExit("error: input files are empty, nothing to evaluate")

    print(mean(scores))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment