Created
August 26, 2019 16:39
-
-
Save ilyarudyak/148efb9d438484cdcd3064a27c22dd95 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 122, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"The autoreload extension is already loaded. To reload it, use:\n", | |
" %reload_ext autoreload\n" | |
] | |
} | |
], | |
"source": [ | |
"from nltk.util import ngrams\n", | |
"from collections import Counter\n", | |
"from fractions import Fraction\n", | |
"from nltk.translate.bleu_score import sentence_bleu, closest_ref_length, brevity_penalty\n", | |
"import numpy as np\n", | |
"import math\n", | |
"\n", | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# BLEU score" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## example from the original paper" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### data" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"First, let's take the example data from the paper: three reference translations and two candidate translations." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Reference translations from the paper's example (raw, untokenized text).\n", | |
"rt_raw = [\n", | |
"    'It is a guide to action that ensures that the military will forever heed Party commands',\n", | |
"    'It is the guiding principle which guarantees the military forces always being under the command of the Party',\n", | |
"    'It is the practical guide for the army always to heed the directions of the party',\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Candidate translations from the paper's example (raw, untokenized text).\n", | |
"ct_raw = [\n", | |
"    'It is a guide to action which ensures that the military always obeys the commands of the party',\n", | |
"    'It is to insure the troops forever hearing the activity guidebook that party direct',\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def process_trans(t):\n", | |
"    \"\"\"Tokenize a raw translation: lowercase it and split it on whitespace.\"\"\"\n", | |
"    return [token for token in t.lower().split()]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Tokenize every reference translation.\n", | |
"rt = list(map(process_trans, rt_raw))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Tokenize every candidate translation.\n", | |
"ct = list(map(process_trans, ct_raw))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The two tokenized candidates from the paper's example.\n", | |
"c1 = ct[0]\n", | |
"c2 = ct[1]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### modified unigram precision" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_unigram_count_clip(u, c_count, rt_counts):\n", | |
"    \"\"\"Return the clipped count of unigram `u`.\n", | |
"\n", | |
"    The candidate count `c_count[u]` is capped by the largest count of `u`\n", | |
"    in any single reference (one Counter per reference in `rt_counts`).\n", | |
"    \"\"\"\n", | |
"    # `default=0` keeps this well-defined when there are no references.\n", | |
"    max_ref_count = max((rt_count[u] for rt_count in rt_counts), default=0)\n", | |
"    return min(c_count[u], max_ref_count)\n", | |
"\n", | |
"def get_unigram_modified_precision(c, rt):\n", | |
"    \"\"\"Modified unigram precision of candidate `c` against references `rt`.\n", | |
"\n", | |
"    c: candidate as a sequence of tokens; rt: list of reference token sequences.\n", | |
"    Returns a Fraction: (sum of clipped unigram counts) / (number of candidate tokens).\n", | |
"    An empty candidate yields Fraction(0) instead of raising ZeroDivisionError.\n", | |
"    \"\"\"\n", | |
"    c_count = Counter(ngrams(c, 1))\n", | |
"    rt_counts = [Counter(ngrams(r, 1)) for r in rt]\n", | |
"    clipped_counts = sum(get_unigram_count_clip(u, c_count, rt_counts) for u in c_count)\n", | |
"    # max(1, ...) avoids Fraction(0, 0) for an empty candidate (the numerator is 0 then).\n", | |
"    total_counts = max(1, len(c))\n", | |
"    return Fraction(clipped_counts, total_counts)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"17/18 4/7\n" | |
] | |
} | |
], | |
"source": [ | |
"# Modified unigram precision of each candidate, printed on one line.\n", | |
"print(*(get_unigram_modified_precision(c, rt) for c in (c1, c2)))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### modified ngram precision" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 51, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_count_clip(u, c_count, rt_counts):\n", | |
"    \"\"\"Clipped count of n-gram `u`: min(candidate count, max count in any single reference).\"\"\"\n", | |
"    # `default=0` keeps this well-defined when there are no references.\n", | |
"    max_ref_count = max((rt_count[u] for rt_count in rt_counts), default=0)\n", | |
"    return min(c_count[u], max_ref_count)\n", | |
"\n", | |
"def get_mp(c, rt, n):\n", | |
"    \"\"\"Modified n-gram precisions p_1..p_n of candidate `c` against references `rt`.\n", | |
"\n", | |
"    c: candidate as a sequence of tokens; rt: list of reference token sequences;\n", | |
"    n: highest n-gram order. Returns a list of n Fractions whose element i-1 is\n", | |
"    (sum of clipped i-gram counts) / (number of i-grams in the candidate).\n", | |
"    If the candidate has no i-grams (it is shorter than i tokens), p_i is Fraction(0).\n", | |
"    \"\"\"\n", | |
"    mps = []\n", | |
"    for i in range(1, n + 1):\n", | |
"        c_count = Counter(ngrams(c, i))\n", | |
"        rt_counts = [Counter(ngrams(r, i)) for r in rt]\n", | |
"        clipped_counts = sum(get_count_clip(u, c_count, rt_counts) for u in c_count)\n", | |
"        # A candidate shorter than i tokens has no i-grams; avoid Fraction(0, 0).\n", | |
"        total_counts = max(1, sum(c_count.values()))\n", | |
"        mps.append(Fraction(clipped_counts, total_counts))\n", | |
"    return mps" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 52, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([Fraction(17, 18)], [Fraction(4, 7)])" | |
] | |
}, | |
"execution_count": 52, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Unigram precisions via the general function; they should match the unigram-only version above.\n", | |
"tuple(get_mp(c, rt, 1) for c in (c1, c2))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 53, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([Fraction(17, 18), Fraction(10, 17)], [Fraction(4, 7), Fraction(1, 13)])" | |
] | |
}, | |
"execution_count": 53, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Unigram and bigram modified precisions for each candidate.\n", | |
"tuple(get_mp(c, rt, 2) for c in (c1, c2))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### brevity penalty" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
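"It's not entirely clear from the paper how to compute the brevity penalty for a single sentence, so we follow the `nltk` algorithm (which is easy to read in its source [code](https://www.nltk.org/_modules/nltk/translate/bleu_score.html)). First we find the **closest reference length** $r$: the length of the reference that is closest in length to the candidate, with ties broken in favour of the **shorter** reference (this last detail matters). For `c1` (18 tokens) the closest reference length is `18`, and for `c2` (14 tokens) it is `16`. The penalty is $BP = 1$ if $c > r$ and $BP = e^{1 - r/c}$ otherwise, where $c$ is the candidate length. So `c1` gets no brevity penalty ($e^{0} = 1$), while `c2` gets $e^{1 - 16/14} \\approx 0.867$." | |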
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 110, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([16, 18, 16], 18, 14)" | |
] | |
}, | |
"execution_count": 110, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Reference lengths, followed by the lengths of c1 and c2.\n", | |
"list(map(len, rt)), len(c1), len(c2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 117, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Pass each candidate's actual length (not a hard-coded 18) so this stays correct if the data changes.\n", | |
"c1_closest, c2_closest = closest_ref_length(rt, len(c1)), closest_ref_length(rt, len(c2))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 118, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(18, 16)" | |
] | |
}, | |
"execution_count": 118, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Closest reference length r for each candidate.\n", | |
"(c1_closest, c2_closest)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 119, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.0" | |
] | |
}, | |
"execution_count": 119, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# c1 is as long as its closest reference, so no penalty is expected.\n", | |
"brevity_penalty(closest_ref_len=c1_closest, hyp_len=len(c1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 120, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.8668778997501817" | |
] | |
}, | |
"execution_count": 120, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# c2 is shorter than its closest reference, so it is penalized.\n", | |
"brevity_penalty(closest_ref_len=c2_closest, hyp_len=len(c2))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 123, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.8668778997501817" | |
] | |
}, | |
"execution_count": 123, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# By hand: BP = exp(1 - r / c) for a candidate of length c not longer than r.\n", | |
"math.exp(1 - (c2_closest / len(c2)))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### BLEU score" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Finally, let's compute the score by hand and compare it with `nltk`'s `sentence_bleu`, using only unigram and bigram precisions with weight 0.5 each." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### for `c1`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 141, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.7453559924999299" | |
] | |
}, | |
"execution_count": 141, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# nltk's BLEU for c1 using only unigram and bigram precisions (weights 0.5 each).\n", | |
"sentence_bleu(references=rt, hypothesis=c1, weights=(0.5, 0.5, 0, 0))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 129, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Unpack c1's unigram and bigram modified precisions.\n", | |
"p1, p2 = get_mp(c1, rt, n=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 130, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Fraction(17, 18), Fraction(10, 17))" | |
] | |
}, | |
"execution_count": 130, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"(p1, p2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 131, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.9444444444444444" | |
] | |
}, | |
"execution_count": 131, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# The unigram precision as a float.\n", | |
"p1.numerator / p1.denominator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 134, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.7453559924999299" | |
] | |
}, | |
"execution_count": 134, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# BLEU = BP * exp(0.5 * log p1 + 0.5 * log p2).\n", | |
"# c1's brevity penalty is 1.0, but include it so the formula matches the c2 computation below.\n", | |
"brevity_penalty(c1_closest, len(c1)) * math.exp(.5 * math.log(float(p1)) + .5 * math.log(float(p2)))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### for `c2`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 143, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.18174699151949172" | |
] | |
}, | |
"execution_count": 143, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# nltk's BLEU for c2 with the same unigram/bigram weights.\n", | |
"sentence_bleu(references=rt, hypothesis=c2, weights=(0.5, 0.5, 0, 0))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 136, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Brevity penalty for c2, reused in the manual BLEU computation.\n", | |
"BP = brevity_penalty(closest_ref_len=c2_closest, hyp_len=len(c2))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 138, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Unpack c2's unigram and bigram modified precisions (overwrites c1's p1, p2).\n", | |
"p1, p2 = get_mp(c2, rt, n=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 139, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Fraction(4, 7), Fraction(1, 13))" | |
] | |
}, | |
"execution_count": 139, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"(p1, p2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 140, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.18174699151949172" | |
] | |
}, | |
"execution_count": 140, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# BLEU = BP * exp(sum of 0.5 * log p_n over the unigram and bigram precisions).\n", | |
"BP * math.exp(sum(0.5 * math.log(float(p)) for p in (p1, p2)))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
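"For both candidates the manual computation matches `nltk`'s `sentence_bleu` (0.7453559924999299 for `c1` and 0.18174699151949172 for `c2`). This concludes our analysis of the `BLEU` score." | |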
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment