@ilyarudyak
Created August 26, 2019 16:39
{
"cells": [
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"from nltk.util import ngrams\n",
"from collections import Counter\n",
"from fractions import Fraction\n",
"from nltk.translate.bleu_score import sentence_bleu, closest_ref_length, brevity_penalty\n",
"import numpy as np\n",
"import math\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BLEU score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## example from the original paper"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's take the example data from the paper: three reference translations and two candidate translations."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"rt_raw = [\n",
" 'It is a guide to action that ensures that the military will forever heed Party commands',\n",
" 'It is the guiding principle which guarantees the military forces always being under the command of the Party',\n",
" 'It is the practical guide for the army always to heed the directions of the party'\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"ct_raw = [\n",
" 'It is a guide to action which ensures that the military always obeys the commands of the party',\n",
" 'It is to insure the troops forever hearing the activity guidebook that party direct'\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def process_trans(t):\n",
"    return t.lower().split()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"rt = [process_trans(t) for t in rt_raw]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"ct = [process_trans(t) for t in ct_raw]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"c1, c2 = ct[0], ct[1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### modified unigram precision"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"def get_unigram_count_clip(u, c_count, rt_counts):\n",
"    # clip the candidate count of u by its largest count in any single reference\n",
"    max_ref_count = max([rt_count[u] for rt_count in rt_counts])\n",
"    return min(c_count[u], max_ref_count)\n",
"\n",
"def get_unigram_modified_precision(c, rt):\n",
"    c_count = Counter(ngrams(c, 1))\n",
"    rt_counts = [Counter(ngrams(r, 1)) for r in rt]\n",
"    clipped_counts = sum([get_unigram_count_clip(u, c_count, rt_counts) for u in c_count])\n",
"    total_counts = len(c)\n",
"    return Fraction(clipped_counts, total_counts)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"17/18 4/7\n"
]
}
],
"source": [
"print(get_unigram_modified_precision(c1, rt), get_unigram_modified_precision(c2, rt))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### modified ngram precision"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"def get_count_clip(u, c_count, rt_counts):\n",
"    # clip the candidate count of n-gram u by its largest count in any single reference\n",
"    max_ref_count = max([rt_count[u] for rt_count in rt_counts])\n",
"    return min(c_count[u], max_ref_count)\n",
"\n",
"def get_mp(c, rt, n):\n",
"    # modified precisions for n-gram orders 1..n, returned as a list of Fractions\n",
"    mps = []\n",
"    for i in range(1, n+1):\n",
"        c_count = Counter(ngrams(c, i))\n",
"        rt_counts = [Counter(ngrams(r, i)) for r in rt]\n",
"        clipped_counts = sum([get_count_clip(u, c_count, rt_counts) for u in c_count])\n",
"        total_counts = sum(c_count.values())\n",
"        mps.append(Fraction(clipped_counts, total_counts))\n",
"    return mps"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([Fraction(17, 18)], [Fraction(4, 7)])"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_mp(c1, rt, 1), get_mp(c2, rt, 1)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([Fraction(17, 18), Fraction(10, 17)], [Fraction(4, 7), Fraction(1, 13)])"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_mp(c1, rt, 2), get_mp(c2, rt, 2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### brevity penalty"
]
},
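{
"cell_type": "markdown",
"metadata": {},
"source": [
"The paper is not fully explicit about how to compute the brevity penalty when there are several references, so we follow the `nltk` implementation (see its source [code](https://www.nltk.org/_modules/nltk/translate/bleu_score.html)). First we pick the reference length **closest to the candidate length**; if two references are equally close, the shorter one wins (this tie-break is important). For `c1` (length `18`) the closest reference length is `18`, for `c2` (length `14`) it is `16`. The penalty is `1` when the candidate is longer than that reference length `r`, and `exp(1 - r / c)` otherwise, where `c` is the candidate length. So `c1` gets no brevity penalty, while `c2` gets `exp(1 - 16 / 14)`, about `0.867`."
]
},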
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([16, 18, 16], 18, 14)"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[len(r) for r in rt], len(c1), len(c2)"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
"c1_closest, c2_closest = closest_ref_length(rt, len(c1)), closest_ref_length(rt, len(c2))"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(18, 16)"
]
},
"execution_count": 118,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c1_closest, c2_closest"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 119,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"brevity_penalty(c1_closest, len(c1))"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8668778997501817"
]
},
"execution_count": 120,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"brevity_penalty(c2_closest, len(c2))"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8668778997501817"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.exp(1 - c2_closest / len(c2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### BLEU score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's compute the score ourselves and compare it with `nltk`. We use only unigrams and bigrams with equal weights, i.e. `weights=(.5, .5, 0, 0)`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### for `c1`"
]
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.7453559924999299"
]
},
"execution_count": 141,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence_bleu(rt, c1, weights=(.5, .5, 0, 0))"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
"p1, p2 = get_mp(c1, rt, 2)"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Fraction(17, 18), Fraction(10, 17))"
]
},
"execution_count": 130,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"p1, p2"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9444444444444444"
]
},
"execution_count": 131,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"float(p1)"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.7453559924999299"
]
},
"execution_count": 134,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.exp(.5 * math.log(float(p1)) + .5 * math.log(float(p2)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### for `c2`"
]
},
{
"cell_type": "code",
"execution_count": 143,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.18174699151949172"
]
},
"execution_count": 143,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence_bleu(rt, c2, weights=(.5, .5, 0, 0))"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [],
"source": [
"BP = brevity_penalty(c2_closest, len(c2))"
]
},
{
"cell_type": "code",
"execution_count": 138,
"metadata": {},
"outputs": [],
"source": [
"p1, p2 = get_mp(c2, rt, 2)"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Fraction(4, 7), Fraction(1, 13))"
]
},
"execution_count": 139,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"p1, p2"
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.18174699151949172"
]
},
"execution_count": 140,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"BP * math.exp(.5 * math.log(float(p1)) + .5 * math.log(float(p2)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This concludes our analysis of the `BLEU` score."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
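The steps of the notebook (clipped n-gram counts, closest reference length, brevity penalty, weighted geometric mean) can be put together in one small script that does not need `nltk`. This is only a minimal sketch under a few assumptions: the helper names `ngram_counts`, `modified_precision`, `closest_length_penalty` and `bleu` are made up for this illustration and are not part of `nltk`, no smoothing is applied, and a zero precision simply gives a score of `0`.

```python
import math
from collections import Counter
from fractions import Fraction


def ngram_counts(tokens, n):
    """Count the n-grams (as tuples) of a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def modified_precision(candidate, references, n):
    """Clipped n-gram matches divided by the number of candidate n-grams."""
    cand = ngram_counts(candidate, n)
    refs = [ngram_counts(r, n) for r in references]
    total = sum(cand.values())
    if total == 0:  # candidate shorter than n tokens
        return Fraction(0)
    # clip each candidate count by its largest count in any single reference
    clipped = sum(min(count, max(r[gram] for r in refs))
                  for gram, count in cand.items())
    return Fraction(clipped, total)


def closest_length_penalty(candidate, references):
    """Brevity penalty based on the reference length closest to the candidate."""
    c = len(candidate)
    if c == 0:
        return 0.0
    # closest reference length; ties go to the shorter reference
    r = min((len(ref) for ref in references),
            key=lambda length: (abs(length - c), length))
    return 1.0 if c > r else math.exp(1 - r / c)


def bleu(candidate, references, weights=(0.5, 0.5)):
    """Brevity penalty times the weighted geometric mean of the precisions."""
    precisions = [modified_precision(candidate, references, n)
                  for n in range(1, len(weights) + 1)]
    if any(p == 0 for p in precisions):
        return 0.0  # assumption: no smoothing in this sketch
    log_mean = sum(w * math.log(float(p)) for w, p in zip(weights, precisions))
    return closest_length_penalty(candidate, references) * math.exp(log_mean)


# the example data from the notebook, lowercased and split on whitespace
refs = [s.lower().split() for s in [
    'It is a guide to action that ensures that the military will forever heed Party commands',
    'It is the guiding principle which guarantees the military forces always being under the command of the Party',
    'It is the practical guide for the army always to heed the directions of the party',
]]
c1 = 'It is a guide to action which ensures that the military always obeys the commands of the party'.lower().split()
c2 = 'It is to insure the troops forever hearing the activity guidebook that party direct'.lower().split()

print(bleu(c1, refs), bleu(c2, refs))
```

With the default `weights=(0.5, 0.5)` this should agree with the values computed in the notebook for `c1` and `c2` (about `0.745` and `0.182`).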