Last active
August 17, 2016 10:40
-
-
Save yohokuno/a6361d8292f9b1009419a6a98c9b483c to your computer and use it in GitHub Desktop.
BLEU: automatic evaluation for machine translation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
import math | |
def ngrams(sentence, n):
    """Yield every contiguous window of ``n + 1`` tokens from *sentence*.

    The order is zero-based: ``n == 0`` yields unigrams, ``n == 1`` yields
    bigrams, and so on. Each window is yielded as a tuple, in left-to-right
    order. Nothing is yielded when the sentence has ``n`` tokens or fewer.
    """
    window = n + 1
    stop = len(sentence) - n
    start = 0
    while start < stop:
        yield tuple(sentence[start:start + window])
        start += 1
def precision(references, system, n):
    """Return the corpus-level modified n-gram precision.

    Args:
        references: one entry per sentence; each entry is a list of
            reference translations, each a list of tokens.
        system: one tokenized system translation per sentence, aligned
            with *references*.
        n: zero-based n-gram order (``0`` is unigram, ``1`` is bigram, ...).

    Returns:
        Clipped matches divided by the number of system n-grams, summed over
        all sentences. Returns ``0.0`` when the system output contains no
        n-gram of the requested order.
    """
    from collections import Counter

    all_matched_ngrams = 0
    all_system_ngrams = 0
    for references_sentence, system_sentence in zip(references, system):
        # For each n-gram, keep the largest count seen in any single
        # reference; Counter's ``|`` operator takes the per-key maximum.
        max_reference_counts = Counter()
        for reference_sentence in references_sentence:
            max_reference_counts |= Counter(ngrams(reference_sentence, n))
        system_counts = Counter(ngrams(system_sentence, n))
        # Counter's ``&`` operator takes the per-key minimum, which clips
        # each system n-gram count to its maximum reference count.
        clipped_counts = system_counts & max_reference_counts
        all_matched_ngrams += sum(clipped_counts.values())
        all_system_ngrams += sum(system_counts.values())
    if all_system_ngrams == 0:
        # No system n-grams of this order: report zero precision rather
        # than dividing by zero.
        return 0.0
    return all_matched_ngrams / all_system_ngrams
def closest_min_length(references_sentence, system_sentence):
    """Return the length of the reference closest in length to the system sentence.

    When several references are equally close, the shortest of them wins.
    """
    target = len(system_sentence)

    def rank(reference_sentence):
        # Sort key: distance from the system length first, then the
        # reference's own length so that ties resolve to the shorter one.
        size = len(reference_sentence)
        return (abs(size - target), size)

    return len(min(references_sentence, key=rank))
def brevity_penalty(references, system):
    """Return the BLEU brevity penalty for the whole corpus.

    Args:
        references: one entry per sentence; each entry is a list of
            tokenized reference translations.
        system: one tokenized system translation per sentence, aligned
            with *references*.

    Returns:
        ``min(1.0, exp(1 - r / c))`` where ``r`` is the sum of the
        closest reference lengths and ``c`` is the total number of system
        tokens. Returns ``0.0`` when the system output has no tokens.
    """
    reference_length = 0
    for references_sentence, system_sentence in zip(references, system):
        reference_length += closest_min_length(references_sentence, system_sentence)
    system_length = sum(len(system_sentence) for system_sentence in system)
    if system_length == 0:
        # An empty hypothesis would otherwise divide by zero; score it 0.
        return 0.0
    return min(1.0, math.exp(1.0 - reference_length / system_length))
def bleu(references, system, order=4):
    """Return the BLEU score of *system* against *references*.

    The score is the product of the n-gram precisions, each raised to
    ``1 / order``, multiplied by the brevity penalty.
    """
    geometric_mean = 1.0
    # Orders are zero-based here: 0 is unigram, 1 is bigram, and so on.
    for gram_order in range(order):
        gram_precision = precision(references, system, gram_order)
        geometric_mean *= gram_precision ** (1.0 / order)
    penalty = brevity_penalty(references, system)
    return geometric_mean * penalty
if __name__ == '__main__':
    # Command line: one system output file followed by one or more reference
    # files, all with one sentence per line and tokens separated by whitespace.
    parser = argparse.ArgumentParser()
    parser.add_argument('system', type=argparse.FileType('r'))
    parser.add_argument('references', nargs='+', type=argparse.FileType('r'))
    parser.add_argument('--order', type=int, default=4)
    args = parser.parse_args()

    # split() with no argument drops the trailing newline and collapses runs
    # of whitespace, so the last token of a line is not polluted with '\n'
    # and repeated spaces do not create empty tokens.
    system = [sentence.split() for sentence in args.system]
    references = []
    # zip(*files) groups the i-th line of every reference file together.
    for references_sentence in zip(*args.references):
        references.append([sentence.split() for sentence in references_sentence])

    # The scoring functions pair sentences with zip(), which would silently
    # drop unmatched sentences; fail loudly instead.
    if len(references) != len(system):
        parser.error('system has {} sentences but references have {}'.format(
            len(system), len(references)))

    print('BLEU:', bleu(references, system, args.order))
    print('Brevity Penalty:', brevity_penalty(references, system))
    # Orders are zero-based internally, hence n+1 in the label.
    for n in range(args.order):
        print('{}-gram precision:'.format(n+1), precision(references, system, n))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import unittest | |
import bleu | |
class Test(unittest.TestCase):
    """Unit tests for the bleu module using one hypothesis and four references."""

    def setUp(self):
        # One system sentence, tokenized on single spaces.
        hypothesis = "Well , I 'd like to stay five nights beginning June sixth ."
        # Four reference translations for that single sentence.
        reference_texts = (
            "I 'd like to stay there for five nights , from June sixth .",
            "I want to stay for five nights , from June sixth .",
            "I 'd like to stay for five nights from June sixth .",
            "I would like to reserve a room for five nights from June sixth .",
        )
        self.system = [hypothesis.split(' ')]
        self.references = [[text.split(' ') for text in reference_texts]]

    def test_ngrams(self):
        # Order 1 is zero-based, i.e. bigrams.
        words = ['I', 'am', 'a', 'beautiful', 'person']
        expected = [('I', 'am'), ('am', 'a'), ('a', 'beautiful'), ('beautiful', 'person')]
        actual = list(bleu.ngrams(words, 1))
        self.assertEqual(actual, expected)

    def test_precision(self):
        # Expected precision for orders 0..3 (unigram through 4-gram).
        expected_by_order = (11/13, 7/12, 4/11, 2/10)
        for order, expected in enumerate(expected_by_order):
            actual = bleu.precision(self.references, self.system, order)
            self.assertAlmostEqual(actual, expected)

    def test_brevity_penalty(self):
        penalty = bleu.brevity_penalty(self.references, self.system)
        self.assertAlmostEqual(penalty, 1)

    def test_bleu(self):
        score = bleu.bleu(self.references, self.system)
        self.assertAlmostEqual(score, 0.4353, places=4)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment