#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Authors: Chantal Amrhein and Mathias Müller
# University of Zurich

# imported modules
import argparse  # for argument handling
import codecs  # for opening files
from nltk import ngrams  # for creating ngrams
from nltk.tokenize import word_tokenize  # for tokenization
import numpy as np  # for exp function
from collections import Counter  # for counting ngrams
from typing import List, Tuple

################################################################################

Token = str
Sentence = List[Token]
Sentences = List[Sentence]

################################################################################
def parse_args() -> argparse.Namespace:
    """ Parse arguments given via the command line
    Keyword arguments: None
    Returns: args - object of class argparse.Namespace
    """
    parser = argparse.ArgumentParser()

    # required arguments: a reference file in the target language, and the
    # translation of the source file into the target language
    parser.add_argument("-trg", "--target", required=True, action="store", dest="trg", help="target text")
    parser.add_argument("-trans", "--translation", required=True, action="store", dest="trans", help="source translation")
    parser.add_argument("-v", "--verbose", required=False, action="store_true", help="Print more verbose statistics")

    args = parser.parse_args()

    return args
################################################################################

def tokenize_file(filename: str) -> Sentences:
    """ Open a file, read and tokenize its content
    Keyword arguments: filename - string, name of a file
    Returns: sentences - list of tokenized sentences
    """
    sentences = []

    with codecs.open(filename, "r", "utf-8") as infile:
        for line in infile:
            tokenized = tokenize(line)
            sentences.append(tokenized)

    return sentences
################################################################################

def tokenize(sentence: str) -> Sentence:
    """ Tokenize a given sentence
    Keyword arguments: sentence - string, a sentence
    Returns: tokenized - list of tokens
    """
    # tokenize sentence with nltk tokenizer
    tokenized = word_tokenize(sentence)

    return tokenized
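# Note: word_tokenize relies on NLTK's "punkt" tokenizer models. If they are
# not installed yet, a one-time download is needed before running this script,
# e.g.:
#
#     python3 -c "import nltk; nltk.download('punkt')"
#
# A quick illustration of the expected output:
#
#     tokenize("This is a sentence.")  # -> ['This', 'is', 'a', 'sentence', '.']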
################################################################################
def compute_ngram_counts_per_sentence(hypothesis: Sentence, reference: Sentence, n: int) -> Tuple[int, int]:
    """ Compute clipped ngram matches for a single pair of sentences
    Keyword arguments: hypothesis - list of tokens, a translated sentence
                       reference - list of tokens, reference translation of the sentence
                       n - integer, size of the ngrams considered
    Returns: correct - integer, number of hypothesis ngrams that also occur in the reference (clipped)
             total - integer, total number of ngrams in the hypothesis
    """
    # create ngrams of size n, initialise frequency dictionary for clipping
    ngrams_ref = ngrams(reference, n)
    ngrams_sent = ngrams(hypothesis, n)
    clip_dictionary = Counter(ngrams_ref)

    # initialise counters
    total = 0
    correct = 0

    # check how many ngrams in the sentence also occur in the reference (includes clipping)
    for ngram in ngrams_sent:
        total += 1
        if clip_dictionary[ngram] > 0:
            correct += 1
            clip_dictionary[ngram] -= 1

    return correct, total
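# A minimal sketch of how clipping behaves, assuming unigrams (n=1): if the
# hypothesis repeats a token more often than the reference contains it, the
# surplus occurrences are not counted as correct.
#
#     compute_ngram_counts_per_sentence(["the", "the", "the"], ["the", "cat"], 1)
#     # -> (1, 3)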
################################################################################
def compute_ngram_counts(hypothesis: Sentences, reference: Sentences, n: int) -> Tuple[int, int]:
    """ Compute clipped ngram counts over a whole corpus for a given ngram size
    Keyword arguments: hypothesis - list of tokenized sentences, the translated text
                       reference - list of tokenized sentences, the reference translation
                       n - integer, size of the ngrams considered
    Returns: correct - integer, clipped ngram matches summed over all sentence pairs
             total - integer, total number of hypothesis ngrams
    """
    correct, total = 0, 0

    # note: zip truncates to the shorter of the two texts, so hypothesis and
    # reference are assumed to be line-aligned and of equal length
    for hyp_sentence, ref_sentence in zip(hypothesis, reference):
        correct_per_sentence, total_per_sentence = compute_ngram_counts_per_sentence(hyp_sentence, ref_sentence, n)
        correct += correct_per_sentence
        total += total_per_sentence

    return correct, total
################################################################################

def compute_length(sentences: Sentences) -> int:
    """ Compute the total number of tokens in a tokenized text
    Keyword arguments: sentences - list of tokenized sentences
    Returns: length - integer, total number of tokens
    """
    length = 0

    for sentence in sentences:
        length += len(sentence)

    return length
################################################################################

def compute_bleu_score(hypothesis: Sentences,
                       reference: Sentences,
                       n: int = 4) -> Tuple[float, List[Tuple[int, int]], List[float], List[int], float]:
    """ Compute the corpus-level BLEU score
    Keyword arguments: hypothesis - list of tokenized sentences, the translated text
                       reference - list of tokenized sentences, the reference translation
                       n - integer, maximum ngram order considered (default = 4)
    Returns: score - float, BLEU score for the given corpus
             counts - list of tuples, (correct, total) ngram counts per ngram order
             precisions - list of floats, ngram precision per ngram order
             lengths - list of integers, [reference length, hypothesis length]
             bp - float, brevity penalty
    """
    counts = []
    precisions = []

    for ngram_order in range(1, n + 1):
        correct, total = compute_ngram_counts(hypothesis, reference, ngram_order)

        # keep counts for verbose output
        counts.append((correct, total))

        # compute precision for this ngram order
        # (total is 0, and this division fails, only if every hypothesis
        # sentence is shorter than ngram_order tokens)
        precision = correct / float(total)
        precisions.append(precision)

    # compute brevity penalty
    hyp_length = compute_length(hypothesis)
    ref_length = compute_length(reference)
    lengths = [ref_length, hyp_length]
    bp = min(1.0, np.exp(1 - ref_length / hyp_length))

    # compute geometric mean of ngram precisions
    combined_precision = np.prod(precisions)
    p = combined_precision ** (1 / n)

    # compute final bleu score
    score = bp * p

    return score, counts, precisions, lengths, bp
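# A quick sanity check for a perfect match (hypothetical input, not part of the
# original script): with hypothesis == reference, all ngram precisions and the
# brevity penalty are 1.0, so the score is 1.0.
#
#     hyp = [["the", "cat", "sat", "down"]]
#     score, counts, precisions, lengths, bp = compute_bleu_score(hyp, hyp)
#     # score == 1.0, counts == [(4, 4), (3, 3), (2, 2), (1, 1)], bp == 1.0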
################################################################################

def main(args: argparse.Namespace) -> None:
    """ Main function to compute the BLEU score of a translation file against a reference translation
    Keyword arguments: args - object of class argparse.Namespace
    Returns: None
    """
    # read target and translation files and tokenize text
    trg = tokenize_file(args.trg)
    trans = tokenize_file(args.trans)

    # compute bleu score
    bleu_score, counts, precisions, lengths, bp = compute_bleu_score(trans, trg)

    # print information
    print("BLEU score for {0:s}:".format(args.trans))

    if args.verbose:
        print("BLEU={score} COUNTS={counts} PRECISIONS={precisions} LENGTHS={lengths} BP={bp}".format(
            score="%.4f" % bleu_score,
            counts=["%d/%d" % (correct, total) for correct, total in counts],
            precisions=["%.4f" % p for p in precisions],
            lengths=lengths,
            bp="%.4f" % bp
        ))
    else:
        print("\t", "%.4f" % bleu_score)
################################################################################

if __name__ == "__main__":
    args = parse_args()
    main(args)
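# Example invocation (script and file names are placeholders; the inputs are
# plain text files with one sentence per line):
#
#     python3 bleu.py --target reference.txt --translation translation.txt --verbose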