@EdwardJRoss
Created January 14, 2019 12:13
Character Level Classification of Surname Ethnicity using Fastai
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classifying Name Ethnicity with a Character level RNN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is heavily modeled on the PyTorch tutorial:\n",
"https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html\n",
"\n",
"We use the fastai library extensively to make data loading and training easier."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download the PyTorch tutorial data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a list of surnames and their ethnicities"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#!wget https://download.pytorch.org/tutorial/data.zip"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#!unzip -o data.zip"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the data\n",
"fastai imports pandas and all sorts of other goodies."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from fastai import *\n",
"from fastai.text import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from unidecode import unidecode\n",
"import string"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Limit DataFrame display to 20 rows so that tables don't take up too much of the notebook."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"pd.options.display.max_rows = 20"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Read in the data; the names for each language are in a separate file."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"path = Path('data/names')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arabic.txt English.txt Irish.txt\tPolish.txt\tSpanish.txt\r\n",
"Chinese.txt French.txt Italian.txt\tPortuguese.txt\tVietnamese.txt\r\n",
"Czech.txt German.txt Japanese.txt\tRussian.txt\r\n",
"Dutch.txt Greek.txt\t Korean.txt\tScottish.txt\r\n"
]
}
],
"source": [
"!ls {path}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Khoury\r\n",
"Nahas\r\n",
"Daher\r\n",
"Gerges\r\n",
"Nazari\r\n"
]
}
],
"source": [
"!head -n5 {path}/Arabic.txt"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
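"names = []\n",
"for p in path.glob('*.txt'):\n",
"    lang = p.stem  # file name without the .txt extension\n",
"    # The files contain accented characters, so read them explicitly as UTF-8\n",
"    with open(p, encoding='utf-8') as f:\n",
"        names += [(lang, l.strip()) for l in f]\n",
"df = pd.DataFrame(names, columns=['cl', 'name'])"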
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check the Data\n",
"It's always worth doing some sanity checks on your data (even supposedly clean tutorial data).\n",
"\n",
"No matter how good your model is: garbage in, garbage out."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Korean</td>\n",
" <td>Ahn</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Korean</td>\n",
" <td>Baik</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Korean</td>\n",
" <td>Bang</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Korean</td>\n",
" <td>Byon</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Korean</td>\n",
" <td>Cha</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name\n",
"0 Korean Ahn\n",
"1 Korean Baik\n",
"2 Korean Bang\n",
"3 Korean Byon\n",
"4 Korean Cha"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"20074"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Character Set\n",
"What letters outside of ASCII are in the names?"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(' ', 115),\n",
" (\"'\", 87),\n",
" ('-', 25),\n",
" ('ö', 24),\n",
" ('é', 22),\n",
" ('í', 14),\n",
" ('ó', 13),\n",
" ('ä', 13),\n",
" ('á', 12),\n",
" ('ü', 11),\n",
" ('à', 10),\n",
" ('ß', 9),\n",
" ('ú', 7),\n",
" ('ñ', 6),\n",
" ('ò', 3),\n",
" ('Ś', 3),\n",
" ('1', 3),\n",
" (',', 3),\n",
" ('è', 2),\n",
" ('ã', 2),\n",
" ('ù', 1),\n",
" ('ì', 1),\n",
" ('ż', 1),\n",
" ('ń', 1),\n",
" ('ł', 1),\n",
" ('ą', 1),\n",
" ('Ż', 1),\n",
" ('/', 1),\n",
" (':', 1),\n",
" ('Á', 1),\n",
" ('\\xa0', 1),\n",
" ('õ', 1),\n",
" ('É', 1),\n",
" ('ê', 1),\n",
" ('ç', 1)]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"foreign_chars = Counter(c for c in ''.join(df.name) if c not in string.ascii_letters)\n",
"foreign_chars.most_common()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A few of these look suspicious. (Note that `contains` treats its argument as a regular expression, so joining the characters with `|` checks for any of them.)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2494</th>\n",
" <td>Czech</td>\n",
" <td>Maxa/B</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2590</th>\n",
" <td>Czech</td>\n",
" <td>Rafaj1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2703</th>\n",
" <td>Czech</td>\n",
" <td>Urbanek1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2732</th>\n",
" <td>Czech</td>\n",
" <td>Whitmire1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3214</th>\n",
" <td>Chinese</td>\n",
" <td>Lu:</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14506</th>\n",
" <td>Russian</td>\n",
" <td>Jevolojnov,</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15347</th>\n",
" <td>Russian</td>\n",
" <td>Lysansky,</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15366</th>\n",
" <td>Russian</td>\n",
" <td>Lytkin,</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18052</th>\n",
" <td>Russian</td>\n",
" <td>To The First  Page</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name\n",
"2494 Czech Maxa/B\n",
"2590 Czech Rafaj1\n",
"2703 Czech Urbanek1\n",
"2732 Czech Whitmire1\n",
"3214 Chinese Lu:\n",
"14506 Russian Jevolojnov,\n",
"15347 Russian Lysansky,\n",
"15366 Russian Lytkin,\n",
"18052 Russian To The First  Page"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"suss_chars = [':', '/', '\\xa0', ',', '1']\n",
"df[df.name.str.contains('|'.join(suss_chars))]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Most of these look like legitimate names with extra junk (except 'To The First Page').\n",
"Since there are so few of them, it's easiest just to drop them."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"df = df[~df.name.str.contains('|'.join(suss_chars))]"
]
},
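{
"cell_type": "markdown",
"metadata": {},
"source": [
"Joining raw characters with `|` works here only because none of them is a regex metacharacter. Below is a minimal sketch of a safer variant, assuming the standard-library `re.escape` is acceptable for this check. The names `chars_to_flag`, `pattern` and `sample` are illustrative and not part of the original analysis, and the extra `.` is added only to show why escaping matters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"# Same suspicious characters as above, plus '.' (added here only to show why escaping matters)\n",
"chars_to_flag = [':', '/', '\\xa0', ',', '1', '.']\n",
"# re.escape neutralises metacharacters such as '.', which would otherwise match any character\n",
"pattern = '|'.join(re.escape(c) for c in chars_to_flag)\n",
"\n",
"sample = ['Maxa/B', 'Rafaj1', 'Khoury', 'Lu:']  # illustrative names\n",
"flagged = [name for name in sample if re.search(pattern, name)]\n",
"print(flagged)"
]
},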
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Apostrophes and spaces are common and appear in legitimate names, so we keep them."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>369</th>\n",
" <td>Italian</td>\n",
" <td>D'ambrosio</td>\n",
" </tr>\n",
" <tr>\n",
" <th>371</th>\n",
" <td>Italian</td>\n",
" <td>D'amore</td>\n",
" </tr>\n",
" <tr>\n",
" <th>372</th>\n",
" <td>Italian</td>\n",
" <td>D'angelo</td>\n",
" </tr>\n",
" <tr>\n",
" <th>373</th>\n",
" <td>Italian</td>\n",
" <td>D'antonio</td>\n",
" </tr>\n",
" <tr>\n",
" <th>374</th>\n",
" <td>Italian</td>\n",
" <td>De angelis</td>\n",
" </tr>\n",
" <tr>\n",
" <th>375</th>\n",
" <td>Italian</td>\n",
" <td>De campo</td>\n",
" </tr>\n",
" <tr>\n",
" <th>376</th>\n",
" <td>Italian</td>\n",
" <td>De felice</td>\n",
" </tr>\n",
" <tr>\n",
" <th>377</th>\n",
" <td>Italian</td>\n",
" <td>De filippis</td>\n",
" </tr>\n",
" <tr>\n",
" <th>378</th>\n",
" <td>Italian</td>\n",
" <td>De fiore</td>\n",
" </tr>\n",
" <tr>\n",
" <th>379</th>\n",
" <td>Italian</td>\n",
" <td>De laurentis</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18161</th>\n",
" <td>Russian</td>\n",
" <td>V'Yurkov</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19061</th>\n",
" <td>Russian</td>\n",
" <td>Zasyad'Ko</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19740</th>\n",
" <td>Portuguese</td>\n",
" <td>D'cruz</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19741</th>\n",
" <td>Portuguese</td>\n",
" <td>D'cruze</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19743</th>\n",
" <td>Portuguese</td>\n",
" <td>De santigo</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19858</th>\n",
" <td>French</td>\n",
" <td>D'aramitz</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19864</th>\n",
" <td>French</td>\n",
" <td>De la fontaine</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19869</th>\n",
" <td>French</td>\n",
" <td>De sauveterre</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20051</th>\n",
" <td>French</td>\n",
" <td>St martin</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20052</th>\n",
" <td>French</td>\n",
" <td>St pierre</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>150 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" cl name\n",
"369 Italian D'ambrosio\n",
"371 Italian D'amore\n",
"372 Italian D'angelo\n",
"373 Italian D'antonio\n",
"374 Italian De angelis\n",
"375 Italian De campo\n",
"376 Italian De felice\n",
"377 Italian De filippis\n",
"378 Italian De fiore\n",
"379 Italian De laurentis\n",
"... ... ...\n",
"18161 Russian V'Yurkov\n",
"19061 Russian Zasyad'Ko\n",
"19740 Portuguese D'cruz\n",
"19741 Portuguese D'cruze\n",
"19743 Portuguese De santigo\n",
"19858 French D'aramitz\n",
"19864 French De la fontaine\n",
"19869 French De sauveterre\n",
"20051 French St martin\n",
"20052 French St pierre\n",
"\n",
"[150 rows x 2 columns]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[df.name.str.contains(\"'| \")]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since hyphens mainly join multiple last names (and are pretty rare) we won't lose heaps by dropping them."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2982</th>\n",
" <td>Chinese</td>\n",
" <td>Au-Yong</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3088</th>\n",
" <td>Chinese</td>\n",
" <td>Ou-Yang</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3089</th>\n",
" <td>Chinese</td>\n",
" <td>Ow-Yang</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10156</th>\n",
" <td>Russian</td>\n",
" <td>Abdank-Kossovsky</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10639</th>\n",
" <td>Russian</td>\n",
" <td>Amet-Han</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11221</th>\n",
" <td>Russian</td>\n",
" <td>Bagai-Ool</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11757</th>\n",
" <td>Russian</td>\n",
" <td>Bei-Bienko</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11787</th>\n",
" <td>Russian</td>\n",
" <td>Beknazar-Yuzbashev</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11790</th>\n",
" <td>Russian</td>\n",
" <td>Bekovich-Cherkassky</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11904</th>\n",
" <td>Russian</td>\n",
" <td>Bestujev-Lada</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11952</th>\n",
" <td>Russian</td>\n",
" <td>Bim-Bad</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12209</th>\n",
" <td>Russian</td>\n",
" <td>Chyrgal-Ool</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13071</th>\n",
" <td>Russian</td>\n",
" <td>Galkin-Vraskoi</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13307</th>\n",
" <td>Russian</td>\n",
" <td>Gorbunov-Posadov</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16687</th>\n",
" <td>Russian</td>\n",
" <td>Porai-Koshits</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17222</th>\n",
" <td>Russian</td>\n",
" <td>Shah-Nazaroff</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17430</th>\n",
" <td>Russian</td>\n",
" <td>Shirinsky-Shikhmatov</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17748</th>\n",
" <td>Russian</td>\n",
" <td>Tsann-Kay-Si</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17999</th>\n",
" <td>Russian</td>\n",
" <td>Tzann-Kay-Si</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18315</th>\n",
" <td>Russian</td>\n",
" <td>Van-Puteren</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>23 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" cl name\n",
"2982 Chinese Au-Yong\n",
"3088 Chinese Ou-Yang\n",
"3089 Chinese Ow-Yang\n",
"10156 Russian Abdank-Kossovsky\n",
"10639 Russian Amet-Han\n",
"11221 Russian Bagai-Ool\n",
"11757 Russian Bei-Bienko\n",
"11787 Russian Beknazar-Yuzbashev\n",
"11790 Russian Bekovich-Cherkassky\n",
"11904 Russian Bestujev-Lada\n",
"... ... ...\n",
"11952 Russian Bim-Bad\n",
"12209 Russian Chyrgal-Ool\n",
"13071 Russian Galkin-Vraskoi\n",
"13307 Russian Gorbunov-Posadov\n",
"16687 Russian Porai-Koshits\n",
"17222 Russian Shah-Nazaroff\n",
"17430 Russian Shirinsky-Shikhmatov\n",
"17748 Russian Tsann-Kay-Si\n",
"17999 Russian Tzann-Kay-Si\n",
"18315 Russian Van-Puteren\n",
"\n",
"[23 rows x 2 columns]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[df.name.str.contains('-')]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"df = df[~df.name.str.contains('-')]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Normalising non-ASCII Characters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's normalise all non-ASCII characters to ASCII equivalents.\n",
"\n",
"This makes our classification problem harder in practice: any name containing a ß is almost surely German, whereas \"ss\" could occur in many languages. It also reduces the set of characters we need to represent the names."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>100</th>\n",
" <td>Italian</td>\n",
" <td>Abbà</td>\n",
" <td>Abba</td>\n",
" </tr>\n",
" <tr>\n",
" <th>112</th>\n",
" <td>Italian</td>\n",
" <td>Abelló</td>\n",
" <td>Abello</td>\n",
" </tr>\n",
" <tr>\n",
" <th>160</th>\n",
" <td>Italian</td>\n",
" <td>Airò</td>\n",
" <td>Airo</td>\n",
" </tr>\n",
" <tr>\n",
" <th>195</th>\n",
" <td>Italian</td>\n",
" <td>Alò</td>\n",
" <td>Alo</td>\n",
" </tr>\n",
" <tr>\n",
" <th>238</th>\n",
" <td>Italian</td>\n",
" <td>Azzarà</td>\n",
" <td>Azzara</td>\n",
" </tr>\n",
" <tr>\n",
" <th>300</th>\n",
" <td>Italian</td>\n",
" <td>Bovér</td>\n",
" <td>Bover</td>\n",
" </tr>\n",
" <tr>\n",
" <th>445</th>\n",
" <td>Italian</td>\n",
" <td>Giùgovaz</td>\n",
" <td>Giugovaz</td>\n",
" </tr>\n",
" <tr>\n",
" <th>461</th>\n",
" <td>Italian</td>\n",
" <td>Làconi</td>\n",
" <td>Laconi</td>\n",
" </tr>\n",
" <tr>\n",
" <th>462</th>\n",
" <td>Italian</td>\n",
" <td>Laganà</td>\n",
" <td>Lagana</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>Italian</td>\n",
" <td>Lagomarsìno</td>\n",
" <td>Lagomarsino</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19912</th>\n",
" <td>French</td>\n",
" <td>Géroux</td>\n",
" <td>Geroux</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19920</th>\n",
" <td>French</td>\n",
" <td>Guérin</td>\n",
" <td>Guerin</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19924</th>\n",
" <td>French</td>\n",
" <td>Hébert</td>\n",
" <td>Hebert</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19949</th>\n",
" <td>French</td>\n",
" <td>Lécuyer</td>\n",
" <td>Lecuyer</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19951</th>\n",
" <td>French</td>\n",
" <td>Lefévre</td>\n",
" <td>Lefevre</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19955</th>\n",
" <td>French</td>\n",
" <td>Lémieux</td>\n",
" <td>Lemieux</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19960</th>\n",
" <td>French</td>\n",
" <td>Lévêque</td>\n",
" <td>Leveque</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19961</th>\n",
" <td>French</td>\n",
" <td>Lévesque</td>\n",
" <td>Levesque</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19965</th>\n",
" <td>French</td>\n",
" <td>Maçon</td>\n",
" <td>Macon</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20047</th>\n",
" <td>French</td>\n",
" <td>Séverin</td>\n",
" <td>Severin</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>156 rows × 3 columns</p>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name\n",
"100 Italian Abbà Abba\n",
"112 Italian Abelló Abello\n",
"160 Italian Airò Airo\n",
"195 Italian Alò Alo\n",
"238 Italian Azzarà Azzara\n",
"300 Italian Bovér Bover\n",
"445 Italian Giùgovaz Giugovaz\n",
"461 Italian Làconi Laconi\n",
"462 Italian Laganà Lagana\n",
"463 Italian Lagomarsìno Lagomarsino\n",
"... ... ... ...\n",
"19912 French Géroux Geroux\n",
"19920 French Guérin Guerin\n",
"19924 French Hébert Hebert\n",
"19949 French Lécuyer Lecuyer\n",
"19951 French Lefévre Lefevre\n",
"19955 French Lémieux Lemieux\n",
"19960 French Lévêque Leveque\n",
"19961 French Lévesque Levesque\n",
"19965 French Maçon Macon\n",
"20047 French Séverin Severin\n",
"\n",
"[156 rows x 3 columns]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df['ascii_name'] = df.name.apply(unidecode)\n",
"df[df.name != df.ascii_name]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
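"Let's check case: I expect each name to start with a capital letter, with any further capitals appearing only after a space, apostrophe or hyphen.\n",
"\n",
"The names that break this pattern seem to be mistakes."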
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2508</th>\n",
" <td>Czech</td>\n",
" <td>MonkoAustria</td>\n",
" <td>MonkoAustria</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2677</th>\n",
" <td>Czech</td>\n",
" <td>StrakaO</td>\n",
" <td>StrakaO</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3266</th>\n",
" <td>Vietnamese</td>\n",
" <td>an</td>\n",
" <td>an</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name\n",
"2508 Czech MonkoAustria MonkoAustria\n",
"2677 Czech StrakaO StrakaO\n",
"3266 Vietnamese an an"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[~df.ascii_name.str.contains(\"^[A-Z][^A-Z]*(?:[' -][A-Z][^A-Z]*)*$\")]"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"df = df[df.ascii_name.str.contains(\"^[A-Z][^A-Z]*(?:[' -][A-Z][^A-Z]*)*$\")]"
]
},
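{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the case check concrete, here is a small sketch that applies the same pattern with the standard-library `re` module to a few names taken from the tables above. The `is_well_cased` helper is a hypothetical name introduced for illustration; it uses `re.search`, which is how pandas `str.contains` evaluates a regex pattern."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"# A capital letter first; any later capital must follow a space, apostrophe or hyphen\n",
"CASE_PATTERN = re.compile(\"^[A-Z][^A-Z]*(?:[' -][A-Z][^A-Z]*)*$\")\n",
"\n",
"def is_well_cased(name):  # hypothetical helper, not part of the notebook\n",
"    return CASE_PATTERN.search(name) is not None\n",
"\n",
"for name in ['De angelis', \"V'Yurkov\", 'MonkoAustria', 'StrakaO', 'an']:\n",
"    print(name, is_well_cased(name))"
]
},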
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's lowercase the ascii_names"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"df['ascii_name'] = df.ascii_name.str.lower()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check that we've normalised correctly."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('a', 16511),\n",
" ('o', 11120),\n",
" ('e', 10768),\n",
" ('i', 10416),\n",
" ('n', 9943),\n",
" ('r', 8245),\n",
" ('s', 7980),\n",
" ('h', 7673),\n",
" ('k', 6902),\n",
" ('l', 6704),\n",
" ('v', 6301),\n",
" ('t', 5939),\n",
" ('u', 4725),\n",
" ('m', 4343),\n",
" ('d', 3894),\n",
" ('b', 3641),\n",
" ('y', 3604),\n",
" ('g', 3209),\n",
" ('c', 3068),\n",
" ('z', 1928),\n",
" ('f', 1774),\n",
" ('p', 1707),\n",
" ('j', 1346),\n",
" ('w', 1125),\n",
" (' ', 112),\n",
" ('q', 98),\n",
" (\"'\", 87),\n",
" ('x', 72)]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ascii_chars = Counter(''.join(list(df.ascii_name)))\n",
"ascii_chars.most_common()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How many classes does each name have?\n",
"\n",
"In practice a surname could have multiple ethnicities, but we'd have to be really careful about how we use this in training.\n",
"\n",
"If we end up with e.g. 'Michel' as French in the training dataset, but German in the validation set our model has no hope of getting it right (and we may discard an actually good model).\n",
"\n",
"We could handle this by:\n",
"1) Allowing multiple class labels\n",
"2) Picking the country that the name is most commonly associated with\n",
"3) Dropping ambiguous cases\n",
"\n",
"Without any information about frequency we can't do (2) and (1) is a harder problem, so we'll stick to (3)."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ascii_name\n",
"michel 6\n",
"adam 5\n",
"albert 5\n",
"abel 5\n",
"martin 5\n",
"simon 5\n",
"ventura 4\n",
"costa 4\n",
"jordan 4\n",
"han 4\n",
"salomon 4\n",
"samuel 4\n",
"klein 4\n",
"franco 4\n",
"wang 4\n",
"oliver 4\n",
"garcia 3\n",
"horn 3\n",
"lim 3\n",
"rose 3\n",
"Name: cl, dtype: int64"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"name_classes = df.\\\n",
" groupby('ascii_name').\\\n",
" nunique().cl.sort_values(ascending=False)\n",
"name_classes.head(20)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>872</th>\n",
" <td>Polish</td>\n",
" <td>Michel</td>\n",
" <td>michel</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2077</th>\n",
" <td>Dutch</td>\n",
" <td>Michel</td>\n",
" <td>michel</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3489</th>\n",
" <td>Spanish</td>\n",
" <td>Michel</td>\n",
" <td>michel</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6163</th>\n",
" <td>German</td>\n",
" <td>Michel</td>\n",
" <td>michel</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8709</th>\n",
" <td>English</td>\n",
" <td>Michel</td>\n",
" <td>michel</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19978</th>\n",
" <td>French</td>\n",
" <td>Michel</td>\n",
" <td>michel</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name\n",
"872 Polish Michel michel\n",
"2077 Dutch Michel michel\n",
"3489 Spanish Michel michel\n",
"6163 German Michel michel\n",
"8709 English Michel michel\n",
"19978 French Michel michel"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[df.name == 'Michel']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"About 1 in 36 of our names (2.7%) have multiple classes (most of them do before normalisation too)."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(17380, 0.027445339470655927)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(name_classes), sum(name_classes > 1) / len(name_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some names, like [Abel](https://en.wikipedia.org/wiki/Abel_(surname)), do seem to occur commonly in multiple countries, but other labels look doubtful. For example:\n",
"- [Adamson](https://en.wikipedia.org/wiki/Adamson_(surname)) is very unlikely to be Russian\n",
"- [Wong](https://en.wikipedia.org/wiki/Wong_(surname)) is much more prevalent in Chinese than English\n",
"- [Yang](https://en.wikipedia.org/wiki/Yang_(surname)) is very rare in English\n",
"\n",
"It seems like Korean and Chinese have a lot of overlap, as do English and Scottish.\n",
"While this makes some linguistic sense, it will make it hard to build a reliable classifier.\n",
"\n",
"Note that most names only occur once, so we can't pick a \"most common\" class by frequency."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" name\n",
"ascii_name cl \n",
"abel English 1\n",
" French 1\n",
" German 1\n",
" Russian 1\n",
" Spanish 1\n",
"abello Italian 1\n",
" Spanish 1\n",
"abraham English 1\n",
" French 1\n",
"abreu Portuguese 1\n",
" Spanish 1\n",
"adam English 1\n",
" French 1\n",
" German 1\n",
" Irish 1\n",
" Russian 1\n",
"adams English 1\n",
" Russian 1\n",
"adamson English 1\n",
" Russian 1\n",
"adler English 1\n",
" German 1\n",
" Russian 1\n",
"aitken English 1\n",
" Scottish 1\n",
"albert English 1\n",
" French 1\n",
" German 1\n",
" Russian 1\n",
" Spanish 1\n",
"... ...\n",
"wilson English 1\n",
" Scottish 1\n",
"winter English 1\n",
" German 1\n",
"wolf English 1\n",
" German 1\n",
"wong Chinese 1\n",
" English 1\n",
"woo Chinese 1\n",
" Korean 1\n",
"wood Czech 1\n",
" English 1\n",
" Scottish 1\n",
"wright English 1\n",
" Scottish 1\n",
"yan Chinese 2\n",
" Russian 1\n",
"yang Chinese 1\n",
" English 1\n",
" Korean 1\n",
"yim Chinese 1\n",
" Korean 1\n",
"you Chinese 1\n",
" Korean 1\n",
"young English 1\n",
" Scottish 1\n",
"yun Chinese 1\n",
" Korean 1\n",
"zambrano Italian 1\n",
" Spanish 1\n",
"\n",
"[1051 rows x 1 columns]\n"
]
}
],
"source": [
"with pd.option_context('display.max_rows', 60):\n",
" print(df[df.ascii_name.isin(name_classes[name_classes > 1].index)].groupby(['ascii_name', 'cl']).count())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rather than finding the \"right\" ethnicity the easy thing to do is to remove all ambiguous cases."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"df = df[~df.ascii_name.isin(name_classes[name_classes > 1].index)]"
]
},
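{
"cell_type": "markdown",
"metadata": {},
"source": [
"The filtering logic is easier to see on a toy example. The sketch below re-implements the same idea in plain Python rather than pandas (a substitution made only to keep it self-contained): collect the set of classes seen for each name, then keep only names with exactly one class. The `rows` data is made up for illustration. Note that repeated rows within a single class survive this step; those are a separate issue."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
"# (class, ascii_name) pairs; illustrative data, not the real dataset\n",
"rows = [('French', 'michel'), ('German', 'michel'),\n",
"        ('Arabic', 'khoury'), ('Arabic', 'khoury'),\n",
"        ('Korean', 'ahn')]\n",
"\n",
"classes_per_name = defaultdict(set)\n",
"for cl, name in rows:\n",
"    classes_per_name[name].add(cl)\n",
"\n",
"# A name is ambiguous when it appears under more than one class\n",
"ambiguous = {name for name, cls in classes_per_name.items() if len(cls) > 1}\n",
"kept = [(cl, name) for cl, name in rows if name not in ambiguous]\n",
"print(kept)"
]
},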
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How often do (class, name) pairs occur?\n",
"\n",
"We need exactly one row per pair; if separate copies appear in the training and validation set we'll get a higher validation accuracy than is reasonable."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some names occur very frequently."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>n</th>\n",
" </tr>\n",
" <tr>\n",
" <th>ascii_name</th>\n",
" <th>cl</th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>tahan</th>\n",
" <th>Arabic</th>\n",
" <td>28</td>\n",
" <td>28</td>\n",
" </tr>\n",
" <tr>\n",
" <th>fakhoury</th>\n",
" <th>Arabic</th>\n",
" <td>28</td>\n",
" <td>28</td>\n",
" </tr>\n",
" <tr>\n",
" <th>koury</th>\n",
" <th>Arabic</th>\n",
" <td>27</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>nader</th>\n",
" <th>Arabic</th>\n",
" <td>27</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>sarraf</th>\n",
" <th>Arabic</th>\n",
" <td>26</td>\n",
" <td>26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>hadad</th>\n",
" <th>Arabic</th>\n",
" <td>26</td>\n",
" <td>26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>kassis</th>\n",
" <th>Arabic</th>\n",
" <td>26</td>\n",
" <td>26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>antar</th>\n",
" <th>Arabic</th>\n",
" <td>26</td>\n",
" <td>26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>shadid</th>\n",
" <th>Arabic</th>\n",
" <td>25</td>\n",
" <td>25</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cham</th>\n",
" <th>Arabic</th>\n",
" <td>25</td>\n",
" <td>25</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mifsud</th>\n",
" <th>Arabic</th>\n",
" <td>25</td>\n",
" <td>25</td>\n",
" </tr>\n",
" <tr>\n",
" <th>nahas</th>\n",
" <th>Arabic</th>\n",
" <td>24</td>\n",
" <td>24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gerges</th>\n",
" <th>Arabic</th>\n",
" <td>24</td>\n",
" <td>24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ganim</th>\n",
" <th>Arabic</th>\n",
" <td>23</td>\n",
" <td>23</td>\n",
" </tr>\n",
" <tr>\n",
" <th>tuma</th>\n",
" <th>Arabic</th>\n",
" <td>23</td>\n",
" <td>23</td>\n",
" </tr>\n",
" <tr>\n",
" <th>to the first page</th>\n",
" <th>Russian</th>\n",
" <td>23</td>\n",
" <td>23</td>\n",
" </tr>\n",
" <tr>\n",
" <th>atiyeh</th>\n",
" <th>Arabic</th>\n",
" <td>23</td>\n",
" <td>23</td>\n",
" </tr>\n",
" <tr>\n",
" <th>malouf</th>\n",
" <th>Arabic</th>\n",
" <td>23</td>\n",
" <td>23</td>\n",
" </tr>\n",
" <tr>\n",
" <th>sayegh</th>\n",
" <th>Arabic</th>\n",
" <td>22</td>\n",
" <td>22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>naifeh</th>\n",
" <th>Arabic</th>\n",
" <td>22</td>\n",
" <td>22</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name n\n",
"ascii_name cl \n",
"tahan Arabic 28 28\n",
"fakhoury Arabic 28 28\n",
"koury Arabic 27 27\n",
"nader Arabic 27 27\n",
"sarraf Arabic 26 26\n",
"hadad Arabic 26 26\n",
"kassis Arabic 26 26\n",
"antar Arabic 26 26\n",
"shadid Arabic 25 25\n",
"cham Arabic 25 25\n",
"mifsud Arabic 25 25\n",
"nahas Arabic 24 24\n",
"gerges Arabic 24 24\n",
"ganim Arabic 23 23\n",
"tuma Arabic 23 23\n",
"to the first page Russian 23 23\n",
"atiyeh Arabic 23 23\n",
"malouf Arabic 23 23\n",
"sayegh Arabic 22 22\n",
"naifeh Arabic 22 22"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"counts = df.assign(n=1).groupby(['ascii_name', 'cl']).count().sort_values('n', ascending=False)\n",
"counts.head(n=20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's remove the \"To The First Page\" junk (probably some artifact of where the data was scraped from)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"df = df[df.ascii_name != 'to the first page']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are no multiples in English, and a lot in Arabic. It seems like a data entry error rather than meaningful."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>n</th>\n",
" <th>multiple</th>\n",
" <th>rows</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cl</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Russian</th>\n",
" <td>9326</td>\n",
" <td>9326</td>\n",
" <td>35.0</td>\n",
" <td>9263</td>\n",
" </tr>\n",
" <tr>\n",
" <th>English</th>\n",
" <td>3359</td>\n",
" <td>3359</td>\n",
" <td>0.0</td>\n",
" <td>3359</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Arabic</th>\n",
" <td>1892</td>\n",
" <td>1892</td>\n",
" <td>103.0</td>\n",
" <td>103</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Japanese</th>\n",
" <td>983</td>\n",
" <td>983</td>\n",
" <td>1.0</td>\n",
" <td>982</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Italian</th>\n",
" <td>665</td>\n",
" <td>665</td>\n",
" <td>5.0</td>\n",
" <td>660</td>\n",
" </tr>\n",
" <tr>\n",
" <th>German</th>\n",
" <td>613</td>\n",
" <td>613</td>\n",
" <td>33.0</td>\n",
" <td>578</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Czech</th>\n",
" <td>480</td>\n",
" <td>480</td>\n",
" <td>16.0</td>\n",
" <td>464</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Dutch</th>\n",
" <td>255</td>\n",
" <td>255</td>\n",
" <td>10.0</td>\n",
" <td>244</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Chinese</th>\n",
" <td>219</td>\n",
" <td>219</td>\n",
" <td>19.0</td>\n",
" <td>200</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Spanish</th>\n",
" <td>214</td>\n",
" <td>214</td>\n",
" <td>2.0</td>\n",
" <td>212</td>\n",
" </tr>\n",
" <tr>\n",
" <th>French</th>\n",
" <td>213</td>\n",
" <td>213</td>\n",
" <td>3.0</td>\n",
" <td>210</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Greek</th>\n",
" <td>195</td>\n",
" <td>195</td>\n",
" <td>2.0</td>\n",
" <td>192</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Irish</th>\n",
" <td>170</td>\n",
" <td>170</td>\n",
" <td>6.0</td>\n",
" <td>164</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Polish</th>\n",
" <td>124</td>\n",
" <td>124</td>\n",
" <td>1.0</td>\n",
" <td>123</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Korean</th>\n",
" <td>61</td>\n",
" <td>61</td>\n",
" <td>0.0</td>\n",
" <td>61</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Vietnamese</th>\n",
" <td>56</td>\n",
" <td>56</td>\n",
" <td>1.0</td>\n",
" <td>55</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Portuguese</th>\n",
" <td>32</td>\n",
" <td>32</td>\n",
" <td>0.0</td>\n",
" <td>32</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Scottish</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name n multiple rows\n",
"cl \n",
"Russian 9326 9326 35.0 9263\n",
"English 3359 3359 0.0 3359\n",
"Arabic 1892 1892 103.0 103\n",
"Japanese 983 983 1.0 982\n",
"Italian 665 665 5.0 660\n",
"German 613 613 33.0 578\n",
"Czech 480 480 16.0 464\n",
"Dutch 255 255 10.0 244\n",
"Chinese 219 219 19.0 200\n",
"Spanish 214 214 2.0 212\n",
"French 213 213 3.0 210\n",
"Greek 195 195 2.0 192\n",
"Irish 170 170 6.0 164\n",
"Polish 124 124 1.0 123\n",
"Korean 61 61 0.0 61\n",
"Vietnamese 56 56 1.0 55\n",
"Portuguese 32 32 0.0 32\n",
"Scottish 1 1 0.0 1"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"counts.assign(multiple=counts.n > 1, rows=1).groupby('cl').sum().sort_values('n', ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It makes sense to drop the duplicates and only have a single row per `ascii_name` and `cl`."
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"df = df.drop_duplicates(['ascii_name', 'cl'])"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"16902"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Length Check\n",
"It's worth checking if the shortest and longest names make sense.\n",
"\n",
"They look reasonable."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" <th>len</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>3265</th>\n",
" <td>Vietnamese</td>\n",
" <td>An</td>\n",
" <td>an</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50</th>\n",
" <td>Korean</td>\n",
" <td>Oh</td>\n",
" <td>oh</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1150</th>\n",
" <td>Japanese</td>\n",
" <td>Ii</td>\n",
" <td>ii</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54</th>\n",
" <td>Korean</td>\n",
" <td>Ra</td>\n",
" <td>ra</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3891</th>\n",
" <td>Arabic</td>\n",
" <td>Ba</td>\n",
" <td>ba</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>57</th>\n",
" <td>Korean</td>\n",
" <td>Ri</td>\n",
" <td>ri</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>Korean</td>\n",
" <td>Si</td>\n",
" <td>si</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>71</th>\n",
" <td>Korean</td>\n",
" <td>So</td>\n",
" <td>so</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3311</th>\n",
" <td>Vietnamese</td>\n",
" <td>To</td>\n",
" <td>to</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>85</th>\n",
" <td>Korean</td>\n",
" <td>Yi</td>\n",
" <td>yi</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11475</th>\n",
" <td>Russian</td>\n",
" <td>Bakhtchivandzhi</td>\n",
" <td>bakhtchivandzhi</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10191</th>\n",
" <td>Russian</td>\n",
" <td>Abdulladzhanoff</td>\n",
" <td>abdulladzhanoff</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17299</th>\n",
" <td>Russian</td>\n",
" <td>Shakhnazaryants</td>\n",
" <td>shakhnazaryants</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11393</th>\n",
" <td>Russian</td>\n",
" <td>Baistryutchenko</td>\n",
" <td>baistryutchenko</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14965</th>\n",
" <td>Russian</td>\n",
" <td>Katzenellenbogen</td>\n",
" <td>katzenellenbogen</td>\n",
" <td>16</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2228</th>\n",
" <td>Dutch</td>\n",
" <td>Vandroogenbroeck</td>\n",
" <td>vandroogenbroeck</td>\n",
" <td>16</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14947</th>\n",
" <td>Russian</td>\n",
" <td>Katsenellenbogen</td>\n",
" <td>katsenellenbogen</td>\n",
" <td>16</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19552</th>\n",
" <td>Greek</td>\n",
" <td>Chrysanthopoulos</td>\n",
" <td>chrysanthopoulos</td>\n",
" <td>16</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2841</th>\n",
" <td>Irish</td>\n",
" <td>Maceachthighearna</td>\n",
" <td>maceachthighearna</td>\n",
" <td>17</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6380</th>\n",
" <td>German</td>\n",
" <td>Von grimmelshausen</td>\n",
" <td>von grimmelshausen</td>\n",
" <td>18</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>16902 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name len\n",
"3265 Vietnamese An an 2\n",
"50 Korean Oh oh 2\n",
"1150 Japanese Ii ii 2\n",
"54 Korean Ra ra 2\n",
"3891 Arabic Ba ba 2\n",
"57 Korean Ri ri 2\n",
"69 Korean Si si 2\n",
"71 Korean So so 2\n",
"3311 Vietnamese To to 2\n",
"85 Korean Yi yi 2\n",
"... ... ... ... ...\n",
"11475 Russian Bakhtchivandzhi bakhtchivandzhi 15\n",
"10191 Russian Abdulladzhanoff abdulladzhanoff 15\n",
"17299 Russian Shakhnazaryants shakhnazaryants 15\n",
"11393 Russian Baistryutchenko baistryutchenko 15\n",
"14965 Russian Katzenellenbogen katzenellenbogen 16\n",
"2228 Dutch Vandroogenbroeck vandroogenbroeck 16\n",
"14947 Russian Katsenellenbogen katsenellenbogen 16\n",
"19552 Greek Chrysanthopoulos chrysanthopoulos 16\n",
"2841 Irish Maceachthighearna maceachthighearna 17\n",
"6380 German Von grimmelshausen von grimmelshausen 18\n",
"\n",
"[16902 rows x 4 columns]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.assign(len=df.name.str.len()).sort_values('len')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Distribution by Language"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dataset is very unbalanced.\n",
"\n",
"I doubt there's enough data to tacke Portuguese (which will be close to Spanish) and Scottish (which will be close to English)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"cl\n",
"Russian 9262\n",
"English 3359\n",
"Japanese 982\n",
"Italian 660\n",
"German 578\n",
"Czech 464\n",
"Dutch 244\n",
"Spanish 212\n",
"French 210\n",
"Chinese 200\n",
"Greek 192\n",
"Irish 164\n",
"Polish 123\n",
"Arabic 103\n",
"Korean 61\n",
"Vietnamese 55\n",
"Portuguese 32\n",
"Scottish 1\n",
"Name: name, dtype: int64"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.groupby('cl').name.count().sort_values(ascending=False)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>3711</th>\n",
" <td>Scottish</td>\n",
" <td>Hay</td>\n",
" <td>hay</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name\n",
"3711 Scottish Hay hay"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[df.cl.isin(['Scottish'])]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's remove the rarest classes; we're not likely to have enough data to guess them."
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"df = df[~df.cl.isin(['Scottish', 'Portuguese'])]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note Russian contains variant transliterations to English like Abaimoff and Abaimov (which both correspond to Абаимов).\n",
"\n",
"But this doesn't quite explain it's high frequency: it seems a lot more Russian data was extracted.\n",
"\n",
"(Side note: [Chebyshev](https://en.wikipedia.org/wiki/Pafnuty_Chebyshev) can also be spelt e.g. Chebychev, Tchebycheff, Tschebyschef)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>10112</th>\n",
" <td>Russian</td>\n",
" <td>Ababko</td>\n",
" <td>ababko</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10113</th>\n",
" <td>Russian</td>\n",
" <td>Abaev</td>\n",
" <td>abaev</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10114</th>\n",
" <td>Russian</td>\n",
" <td>Abagyan</td>\n",
" <td>abagyan</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10115</th>\n",
" <td>Russian</td>\n",
" <td>Abaidulin</td>\n",
" <td>abaidulin</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10116</th>\n",
" <td>Russian</td>\n",
" <td>Abaidullin</td>\n",
" <td>abaidullin</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10117</th>\n",
" <td>Russian</td>\n",
" <td>Abaimoff</td>\n",
" <td>abaimoff</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10118</th>\n",
" <td>Russian</td>\n",
" <td>Abaimov</td>\n",
" <td>abaimov</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10119</th>\n",
" <td>Russian</td>\n",
" <td>Abakeliya</td>\n",
" <td>abakeliya</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10120</th>\n",
" <td>Russian</td>\n",
" <td>Abakovsky</td>\n",
" <td>abakovsky</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10121</th>\n",
" <td>Russian</td>\n",
" <td>Abakshin</td>\n",
" <td>abakshin</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19510</th>\n",
" <td>Russian</td>\n",
" <td>Zolotavin</td>\n",
" <td>zolotavin</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19511</th>\n",
" <td>Russian</td>\n",
" <td>Zolotdinov</td>\n",
" <td>zolotdinov</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19512</th>\n",
" <td>Russian</td>\n",
" <td>Zolotenkov</td>\n",
" <td>zolotenkov</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19513</th>\n",
" <td>Russian</td>\n",
" <td>Zolotilin</td>\n",
" <td>zolotilin</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19514</th>\n",
" <td>Russian</td>\n",
" <td>Zolotkov</td>\n",
" <td>zolotkov</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19515</th>\n",
" <td>Russian</td>\n",
" <td>Zolotnitsky</td>\n",
" <td>zolotnitsky</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19516</th>\n",
" <td>Russian</td>\n",
" <td>Zolotnitzky</td>\n",
" <td>zolotnitzky</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19517</th>\n",
" <td>Russian</td>\n",
" <td>Zozrov</td>\n",
" <td>zozrov</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19518</th>\n",
" <td>Russian</td>\n",
" <td>Zozulya</td>\n",
" <td>zozulya</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19519</th>\n",
" <td>Russian</td>\n",
" <td>Zukerman</td>\n",
" <td>zukerman</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>9262 rows × 3 columns</p>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name\n",
"10112 Russian Ababko ababko\n",
"10113 Russian Abaev abaev\n",
"10114 Russian Abagyan abagyan\n",
"10115 Russian Abaidulin abaidulin\n",
"10116 Russian Abaidullin abaidullin\n",
"10117 Russian Abaimoff abaimoff\n",
"10118 Russian Abaimov abaimov\n",
"10119 Russian Abakeliya abakeliya\n",
"10120 Russian Abakovsky abakovsky\n",
"10121 Russian Abakshin abakshin\n",
"... ... ... ...\n",
"19510 Russian Zolotavin zolotavin\n",
"19511 Russian Zolotdinov zolotdinov\n",
"19512 Russian Zolotenkov zolotenkov\n",
"19513 Russian Zolotilin zolotilin\n",
"19514 Russian Zolotkov zolotkov\n",
"19515 Russian Zolotnitsky zolotnitsky\n",
"19516 Russian Zolotnitzky zolotnitzky\n",
"19517 Russian Zozrov zozrov\n",
"19518 Russian Zozulya zozulya\n",
"19519 Russian Zukerman zukerman\n",
"\n",
"[9262 rows x 3 columns]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[df.cl == 'Russian']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create Validation and Training Sets\n",
"\n",
"We want our final model to work well on any language.\n",
"\n",
"But if we pick our validation set uniformly at random from the data we're likely to get many Russian names and not many Vietnamese names, which isn't a good test of this.\n",
"\n",
"So instead we'll take our validation set from an equal number from each subclass."
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Korean</td>\n",
" <td>Ahn</td>\n",
" <td>ahn</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Korean</td>\n",
" <td>Baik</td>\n",
" <td>baik</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Korean</td>\n",
" <td>Bang</td>\n",
" <td>bang</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Korean</td>\n",
" <td>Byon</td>\n",
" <td>byon</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Korean</td>\n",
" <td>Cha</td>\n",
" <td>cha</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Korean</td>\n",
" <td>Cho</td>\n",
" <td>cho</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>Korean</td>\n",
" <td>Choe</td>\n",
" <td>choe</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>Korean</td>\n",
" <td>Choi</td>\n",
" <td>choi</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>Korean</td>\n",
" <td>Chun</td>\n",
" <td>chun</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>Korean</td>\n",
" <td>Chweh</td>\n",
" <td>chweh</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16859</th>\n",
" <td>French</td>\n",
" <td>Travere</td>\n",
" <td>travere</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16860</th>\n",
" <td>French</td>\n",
" <td>Traverse</td>\n",
" <td>traverse</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16861</th>\n",
" <td>French</td>\n",
" <td>Travert</td>\n",
" <td>travert</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16862</th>\n",
" <td>French</td>\n",
" <td>Tremblay</td>\n",
" <td>tremblay</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16863</th>\n",
" <td>French</td>\n",
" <td>Tremble</td>\n",
" <td>tremble</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16864</th>\n",
" <td>French</td>\n",
" <td>Victors</td>\n",
" <td>victors</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16865</th>\n",
" <td>French</td>\n",
" <td>Villeneuve</td>\n",
" <td>villeneuve</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16866</th>\n",
" <td>French</td>\n",
" <td>Vipond</td>\n",
" <td>vipond</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16867</th>\n",
" <td>French</td>\n",
" <td>Voclain</td>\n",
" <td>voclain</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16868</th>\n",
" <td>French</td>\n",
" <td>Yount</td>\n",
" <td>yount</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>16869 rows × 3 columns</p>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name\n",
"0 Korean Ahn ahn\n",
"1 Korean Baik baik\n",
"2 Korean Bang bang\n",
"3 Korean Byon byon\n",
"4 Korean Cha cha\n",
"5 Korean Cho cho\n",
"6 Korean Choe choe\n",
"7 Korean Choi choi\n",
"8 Korean Chun chun\n",
"9 Korean Chweh chweh\n",
"... ... ... ...\n",
"16859 French Travere travere\n",
"16860 French Traverse traverse\n",
"16861 French Travert travert\n",
"16862 French Tremblay tremblay\n",
"16863 French Tremble tremble\n",
"16864 French Victors victors\n",
"16865 French Villeneuve villeneuve\n",
"16866 French Vipond vipond\n",
"16867 French Voclain voclain\n",
"16868 French Yount yount\n",
"\n",
"[16869 rows x 3 columns]"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = df.reset_index().drop('index', 1)\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"cl\n",
"Russian 9262\n",
"English 3359\n",
"Japanese 982\n",
"Italian 660\n",
"German 578\n",
"Czech 464\n",
"Dutch 244\n",
"Spanish 212\n",
"French 210\n",
"Chinese 200\n",
"Greek 192\n",
"Irish 164\n",
"Polish 123\n",
"Arabic 103\n",
"Korean 61\n",
"Vietnamese 55\n",
"Name: name, dtype: int64"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"counts = df.groupby('cl').name.count().sort_values(ascending=False)\n",
"counts"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"valid_size = 30 # We'll pick 30 at random from each subclass\n",
"train_size = 500 # For a balanced training set we'll pick 500 at random with replacement"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(6011)\n",
"valid_idx = []\n",
"for cl in counts.keys():\n",
" # Random sample of size \"valid_size\" for each class\n",
" valid_idx += list(df[df.cl == cl].sample(valid_size).index)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"df['valid'] = False\n",
"df.loc[valid_idx, 'valid'] = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's also create a balanced training set as an alternative to using everything not in validation"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(7012)\n",
"balanced_idx = []\n",
"for cl in counts.keys():\n",
" # Random sample of size \"train_size\" for each class from the data outside of the validation set\n",
" balanced_idx += list(df[(df.cl == cl) & ~df.valid].sample(train_size, replace=True).index)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note the balanced index contains all 25 (= 55 - 30) Vietnamese names outside of the training set, but only contains 486 of the Russian names (because we sampled randomly with replacement there will be a couple of double ups)."
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" <th>valid</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cl</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Russian</th>\n",
" <td>1</td>\n",
" <td>486</td>\n",
" <td>486</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>English</th>\n",
" <td>1</td>\n",
" <td>459</td>\n",
" <td>459</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Japanese</th>\n",
" <td>1</td>\n",
" <td>383</td>\n",
" <td>383</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Italian</th>\n",
" <td>1</td>\n",
" <td>357</td>\n",
" <td>357</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>German</th>\n",
" <td>1</td>\n",
" <td>330</td>\n",
" <td>330</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Czech</th>\n",
" <td>1</td>\n",
" <td>295</td>\n",
" <td>295</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Dutch</th>\n",
" <td>1</td>\n",
" <td>195</td>\n",
" <td>195</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>French</th>\n",
" <td>1</td>\n",
" <td>172</td>\n",
" <td>172</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Spanish</th>\n",
" <td>1</td>\n",
" <td>170</td>\n",
" <td>170</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Chinese</th>\n",
" <td>1</td>\n",
" <td>158</td>\n",
" <td>158</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Greek</th>\n",
" <td>1</td>\n",
" <td>153</td>\n",
" <td>153</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Irish</th>\n",
" <td>1</td>\n",
" <td>129</td>\n",
" <td>129</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Polish</th>\n",
" <td>1</td>\n",
" <td>93</td>\n",
" <td>93</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Arabic</th>\n",
" <td>1</td>\n",
" <td>73</td>\n",
" <td>73</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Korean</th>\n",
" <td>1</td>\n",
" <td>31</td>\n",
" <td>31</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Vietnamese</th>\n",
" <td>1</td>\n",
" <td>25</td>\n",
" <td>25</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name valid\n",
"cl \n",
"Russian 1 486 486 1\n",
"English 1 459 459 1\n",
"Japanese 1 383 383 1\n",
"Italian 1 357 357 1\n",
"German 1 330 330 1\n",
"Czech 1 295 295 1\n",
"Dutch 1 195 195 1\n",
"French 1 172 172 1\n",
"Spanish 1 170 170 1\n",
"Chinese 1 158 158 1\n",
"Greek 1 153 153 1\n",
"Irish 1 129 129 1\n",
"Polish 1 93 93 1\n",
"Arabic 1 73 73 1\n",
"Korean 1 31 31 1\n",
"Vietnamese 1 25 25 1"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.loc[balanced_idx].groupby('cl').nunique().sort_values('ascii_name', ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's record our balanced set in the dataframe: this will make it easy to reload at a later point."
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"df['bal'] = 0\n",
"for k, v in Counter(balanced_idx).items():\n",
" df.loc[k, 'bal'] += v"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" <th>valid</th>\n",
" <th>bal</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Korean</td>\n",
" <td>Ahn</td>\n",
" <td>ahn</td>\n",
" <td>False</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Korean</td>\n",
" <td>Baik</td>\n",
" <td>baik</td>\n",
" <td>True</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Korean</td>\n",
" <td>Bang</td>\n",
" <td>bang</td>\n",
" <td>False</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Korean</td>\n",
" <td>Byon</td>\n",
" <td>byon</td>\n",
" <td>False</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Korean</td>\n",
" <td>Cha</td>\n",
" <td>cha</td>\n",
" <td>True</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name valid bal\n",
"0 Korean Ahn ahn False 13\n",
"1 Korean Baik baik True 0\n",
"2 Korean Bang bang False 13\n",
"3 Korean Byon byon False 15\n",
"4 Korean Cha cha True 0"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can always retrieve the indexes from the dataframe"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idx = []\n",
"for k, v in zip(df.index, df.bal):\n",
" idx += [k]*v\n",
"sorted(balanced_idx) == idx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Save the Data"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"df.to_csv('names_clean.csv', index=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Benchmarks\n",
"\n",
"The first benchmark is random guessing/always guessing the same class.\n",
"\n",
"The expected return is 1/(number of classes) = 1/16 ~ 6.25%"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('names_clean.csv')\n",
"\n",
"valid_idx = df[df.valid].index\n",
"train_idx = df[~df.valid].index\n",
"\n",
"bal_idx = []\n",
"for k, v in zip(df.index, df.bal):\n",
" bal_idx += [k]*v"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sanity check out data\n",
"Check training/balanced training data doesn't contain any names in validation set"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0)"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_intersect_valid = sum(df.iloc[train_idx].ascii_name.isin(df.iloc[valid_idx].ascii_name)) \n",
"bal_interset_valid = sum(df.iloc[bal_idx].ascii_name.isin(df.iloc[valid_idx].ascii_name))\n",
"train_intersect_valid, bal_interset_valid"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Make sure the data looks right"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" <th>valid</th>\n",
" <th>bal</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cl</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Russian</th>\n",
" <td>1</td>\n",
" <td>9232</td>\n",
" <td>9232</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>English</th>\n",
" <td>1</td>\n",
" <td>3329</td>\n",
" <td>3329</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Japanese</th>\n",
" <td>1</td>\n",
" <td>952</td>\n",
" <td>952</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Italian</th>\n",
" <td>1</td>\n",
" <td>630</td>\n",
" <td>630</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>German</th>\n",
" <td>1</td>\n",
" <td>548</td>\n",
" <td>548</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Czech</th>\n",
" <td>1</td>\n",
" <td>434</td>\n",
" <td>434</td>\n",
" <td>1</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Dutch</th>\n",
" <td>1</td>\n",
" <td>214</td>\n",
" <td>214</td>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Spanish</th>\n",
" <td>1</td>\n",
" <td>182</td>\n",
" <td>182</td>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>French</th>\n",
" <td>1</td>\n",
" <td>180</td>\n",
" <td>180</td>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Chinese</th>\n",
" <td>1</td>\n",
" <td>170</td>\n",
" <td>170</td>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Greek</th>\n",
" <td>1</td>\n",
" <td>162</td>\n",
" <td>162</td>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Irish</th>\n",
" <td>1</td>\n",
" <td>134</td>\n",
" <td>134</td>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Polish</th>\n",
" <td>1</td>\n",
" <td>93</td>\n",
" <td>93</td>\n",
" <td>1</td>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Arabic</th>\n",
" <td>1</td>\n",
" <td>73</td>\n",
" <td>73</td>\n",
" <td>1</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Korean</th>\n",
" <td>1</td>\n",
" <td>31</td>\n",
" <td>31</td>\n",
" <td>1</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Vietnamese</th>\n",
" <td>1</td>\n",
" <td>25</td>\n",
" <td>25</td>\n",
" <td>1</td>\n",
" <td>16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name valid bal\n",
"cl \n",
"Russian 1 9232 9232 1 3\n",
"English 1 3329 3329 1 4\n",
"Japanese 1 952 952 1 5\n",
"Italian 1 630 630 1 5\n",
"German 1 548 548 1 5\n",
"Czech 1 434 434 1 6\n",
"Dutch 1 214 214 1 9\n",
"Spanish 1 182 182 1 10\n",
"French 1 180 180 1 9\n",
"Chinese 1 170 170 1 9\n",
"Greek 1 162 162 1 10\n",
"Irish 1 134 134 1 10\n",
"Polish 1 93 93 1 11\n",
"Arabic 1 73 73 1 13\n",
"Korean 1 31 31 1 13\n",
"Vietnamese 1 25 25 1 16"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.iloc[train_idx].groupby('cl').nunique().sort_values('ascii_name', ascending=False)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" <th>valid</th>\n",
" <th>bal</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cl</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Russian</th>\n",
" <td>1</td>\n",
" <td>486</td>\n",
" <td>486</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>English</th>\n",
" <td>1</td>\n",
" <td>459</td>\n",
" <td>459</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Japanese</th>\n",
" <td>1</td>\n",
" <td>383</td>\n",
" <td>383</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Italian</th>\n",
" <td>1</td>\n",
" <td>357</td>\n",
" <td>357</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>German</th>\n",
" <td>1</td>\n",
" <td>330</td>\n",
" <td>330</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Czech</th>\n",
" <td>1</td>\n",
" <td>295</td>\n",
" <td>295</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Dutch</th>\n",
" <td>1</td>\n",
" <td>195</td>\n",
" <td>195</td>\n",
" <td>1</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>French</th>\n",
" <td>1</td>\n",
" <td>172</td>\n",
" <td>172</td>\n",
" <td>1</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Spanish</th>\n",
" <td>1</td>\n",
" <td>170</td>\n",
" <td>170</td>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Chinese</th>\n",
" <td>1</td>\n",
" <td>158</td>\n",
" <td>158</td>\n",
" <td>1</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Greek</th>\n",
" <td>1</td>\n",
" <td>153</td>\n",
" <td>153</td>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Irish</th>\n",
" <td>1</td>\n",
" <td>129</td>\n",
" <td>129</td>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Polish</th>\n",
" <td>1</td>\n",
" <td>93</td>\n",
" <td>93</td>\n",
" <td>1</td>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Arabic</th>\n",
" <td>1</td>\n",
" <td>73</td>\n",
" <td>73</td>\n",
" <td>1</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Korean</th>\n",
" <td>1</td>\n",
" <td>31</td>\n",
" <td>31</td>\n",
" <td>1</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Vietnamese</th>\n",
" <td>1</td>\n",
" <td>25</td>\n",
" <td>25</td>\n",
" <td>1</td>\n",
" <td>16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name valid bal\n",
"cl \n",
"Russian 1 486 486 1 2\n",
"English 1 459 459 1 3\n",
"Japanese 1 383 383 1 4\n",
"Italian 1 357 357 1 4\n",
"German 1 330 330 1 4\n",
"Czech 1 295 295 1 5\n",
"Dutch 1 195 195 1 8\n",
"French 1 172 172 1 8\n",
"Spanish 1 170 170 1 9\n",
"Chinese 1 158 158 1 8\n",
"Greek 1 153 153 1 9\n",
"Irish 1 129 129 1 9\n",
"Polish 1 93 93 1 11\n",
"Arabic 1 73 73 1 13\n",
"Korean 1 31 31 1 13\n",
"Vietnamese 1 25 25 1 16"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.iloc[bal_idx].groupby('cl').nunique().sort_values('ascii_name', ascending=False)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" <th>valid</th>\n",
" <th>bal</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cl</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Arabic</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Chinese</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Czech</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Dutch</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>English</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>French</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>German</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Greek</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Irish</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Italian</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Japanese</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Korean</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Polish</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Russian</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Spanish</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Vietnamese</th>\n",
" <td>1</td>\n",
" <td>30</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name valid bal\n",
"cl \n",
"Arabic 1 30 30 1 1\n",
"Chinese 1 30 30 1 1\n",
"Czech 1 30 30 1 1\n",
"Dutch 1 30 30 1 1\n",
"English 1 30 30 1 1\n",
"French 1 30 30 1 1\n",
"German 1 30 30 1 1\n",
"Greek 1 30 30 1 1\n",
"Irish 1 30 30 1 1\n",
"Italian 1 30 30 1 1\n",
"Japanese 1 30 30 1 1\n",
"Korean 1 30 30 1 1\n",
"Polish 1 30 30 1 1\n",
"Russian 1 30 30 1 1\n",
"Spanish 1 30 30 1 1\n",
"Vietnamese 1 30 30 1 1"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.iloc[valid_idx].groupby('cl').nunique().sort_values('ascii_name', ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Picking any one class in validation will give 1/16 = 6.25%"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0625"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(df[df.valid] == 'Korean').cl.sum() / df.valid.sum()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## n-grams and naive Bayes\n",
"\n",
"A reasonable way to guess a language is by the frequency of characters and pairs of characters.\n",
"\n",
"For example 'cz' is very rare in English, but quite common in the slavic languages."
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"name = 'zozrov'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A function to count the occurances of sequences of one, two or three letters (in general these sequences are called \"n-grams\" particularly when referring to sequences of words)."
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Counter({'z': 2, 'o': 2, 'r': 1, 'v': 1}),\n",
" Counter({'zo': 1, 'oz': 1, 'zr': 1, 'ro': 1, 'ov': 1}),\n",
" Counter({'zoz': 1, 'ozr': 1, 'zro': 1, 'rov': 1}))"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def ngrams(s,n=1):\n",
" parts = [s[i:] for i in range(n)] # e.g. ['zozrov', 'ozrov', 'zrov']\n",
" return Counter(''.join(_) for _ in zip(*parts))\n",
"\n",
"ngrams(name, 1), ngrams(name, 2), ngrams(name, 3)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"df = df.assign(letters=df.ascii_name.apply(ngrams))\n",
"df = df.assign(bigrams=df.ascii_name.apply(ngrams, n=2))\n",
"df = df.assign(trigrams=df.ascii_name.apply(ngrams, n=3))"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" <th>valid</th>\n",
" <th>bal</th>\n",
" <th>letters</th>\n",
" <th>bigrams</th>\n",
" <th>trigrams</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Korean</td>\n",
" <td>Ahn</td>\n",
" <td>ahn</td>\n",
" <td>False</td>\n",
" <td>13</td>\n",
" <td>{'a': 1, 'h': 1, 'n': 1}</td>\n",
" <td>{'ah': 1, 'hn': 1}</td>\n",
" <td>{'ahn': 1}</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Korean</td>\n",
" <td>Baik</td>\n",
" <td>baik</td>\n",
" <td>True</td>\n",
" <td>0</td>\n",
" <td>{'b': 1, 'a': 1, 'i': 1, 'k': 1}</td>\n",
" <td>{'ba': 1, 'ai': 1, 'ik': 1}</td>\n",
" <td>{'bai': 1, 'aik': 1}</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Korean</td>\n",
" <td>Bang</td>\n",
" <td>bang</td>\n",
" <td>False</td>\n",
" <td>13</td>\n",
" <td>{'b': 1, 'a': 1, 'n': 1, 'g': 1}</td>\n",
" <td>{'ba': 1, 'an': 1, 'ng': 1}</td>\n",
" <td>{'ban': 1, 'ang': 1}</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Korean</td>\n",
" <td>Byon</td>\n",
" <td>byon</td>\n",
" <td>False</td>\n",
" <td>15</td>\n",
" <td>{'b': 1, 'y': 1, 'o': 1, 'n': 1}</td>\n",
" <td>{'by': 1, 'yo': 1, 'on': 1}</td>\n",
" <td>{'byo': 1, 'yon': 1}</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Korean</td>\n",
" <td>Cha</td>\n",
" <td>cha</td>\n",
" <td>True</td>\n",
" <td>0</td>\n",
" <td>{'c': 1, 'h': 1, 'a': 1}</td>\n",
" <td>{'ch': 1, 'ha': 1}</td>\n",
" <td>{'cha': 1}</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name valid bal letters \\\n",
"0 Korean Ahn ahn False 13 {'a': 1, 'h': 1, 'n': 1} \n",
"1 Korean Baik baik True 0 {'b': 1, 'a': 1, 'i': 1, 'k': 1} \n",
"2 Korean Bang bang False 13 {'b': 1, 'a': 1, 'n': 1, 'g': 1} \n",
"3 Korean Byon byon False 15 {'b': 1, 'y': 1, 'o': 1, 'n': 1} \n",
"4 Korean Cha cha True 0 {'c': 1, 'h': 1, 'a': 1} \n",
"\n",
" bigrams trigrams \n",
"0 {'ah': 1, 'hn': 1} {'ahn': 1} \n",
"1 {'ba': 1, 'ai': 1, 'ik': 1} {'bai': 1, 'aik': 1} \n",
"2 {'ba': 1, 'an': 1, 'ng': 1} {'ban': 1, 'ang': 1} \n",
"3 {'by': 1, 'yo': 1, 'on': 1} {'byo': 1, 'yon': 1} \n",
"4 {'ch': 1, 'ha': 1} {'cha': 1} "
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try to guess the name using [Naive Bayes](https://en.wikipedia.org/wiki/Naive_Bayes_classifier).\n",
"\n",
"TL;DR: This is a *really* simple model that works quite well and will give a good benchmark.\n",
"\n",
"This uses \"Bayes Rule\" which uses the data to answer questions like: \"given the name contains the bigram 'ah' what's the probability it's Korean?\".\n",
"\n",
"The \"Naive\" part means that that we assume all these probabilities are independent (knowing it contains 'ah' doesn't tell you anything about the fact it contains 'hn'). Even though this definitely isn't true, it's often a reasonable approximation.\n",
"\n",
"This makes it really fast and simple to fit a model and often works well."
]
},
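{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the idea concrete, here is a minimal from-scratch sketch of multinomial Naive Bayes over bigram counts. It is only an illustration under stated assumptions, not the code scikit-learn runs: the toy names and the helpers `bigram_counts`, `fit_nb` and `predict_nb` are made up for this example.\n",
"\n",
"```python\n",
"from collections import Counter, defaultdict\n",
"import math\n",
"\n",
"def bigram_counts(s):\n",
"    # Count overlapping character pairs, e.g. 'ahn' -> {'ah': 1, 'hn': 1}\n",
"    return Counter(a + b for a, b in zip(s, s[1:]))\n",
"\n",
"def fit_nb(names, labels, alpha=1.0):\n",
"    # Collect bigram totals per class and how often each class occurs\n",
"    class_totals = defaultdict(Counter)\n",
"    class_freq = Counter(labels)\n",
"    for name, label in zip(names, labels):\n",
"        class_totals[label].update(bigram_counts(name))\n",
"    vocab = sorted({bg for counts in class_totals.values() for bg in counts})\n",
"    model = {}\n",
"    for label, counts in class_totals.items():\n",
"        total = sum(counts.values()) + alpha * len(vocab)\n",
"        log_prior = math.log(class_freq[label] / len(labels))\n",
"        # Smoothed log P(bigram | class); alpha avoids log(0) for unseen pairs\n",
"        log_like = {bg: math.log((counts[bg] + alpha) / total) for bg in vocab}\n",
"        model[label] = (log_prior, log_like)\n",
"    return model\n",
"\n",
"def predict_nb(model, name):\n",
"    # Naive assumption: each bigram contributes independently, so log-probabilities add\n",
"    scores = {}\n",
"    for label, (log_prior, log_like) in model.items():\n",
"        score = log_prior\n",
"        for bg, n in bigram_counts(name).items():\n",
"            if bg in log_like:  # skip bigrams never seen in training\n",
"                score += n * log_like[bg]\n",
"        scores[label] = score\n",
"    return max(scores, key=scores.get)\n",
"\n",
"toy_names = ['kowalski', 'nowakowski', 'papadopoulos', 'georgopoulos']\n",
"toy_labels = ['Polish', 'Polish', 'Greek', 'Greek']\n",
"toy_model = fit_nb(toy_names, toy_labels)\n",
"print(predict_nb(toy_model, 'sokolowski'), predict_nb(toy_model, 'antonopoulos'))\n",
"```\n",
"\n",
"On this toy data the two test names come out as Polish and Greek respectively. The `alpha=1.0` smoothing mirrors the default of `MultinomialNB`."
]
},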
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.naive_bayes import MultinomialNB\n",
"from sklearn.feature_extraction import DictVectorizer\n",
"vd1 = DictVectorizer(sparse=False)\n",
"vd2 = DictVectorizer(sparse=False)\n",
"vd3 = DictVectorizer(sparse=False)"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"y = df.cl"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"letters = vd1.fit_transform(df.letters)\n",
"bigrams = vd2.fit_transform(df.bigrams)\n",
"trigrams = vd3.fit_transform(df.trigrams)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The letters matrix contains the number of times each of the 28 letters occurs (e.g. number of spaces, number of apostrophes, number of 'a', ...)."
]
},
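{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small illustration of what `DictVectorizer` does with the per-name counts, the sketch below vectorizes three toy names. The `toy_` names are assumptions made up for this example and are not part of the analysis above.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"from sklearn.feature_extraction import DictVectorizer\n",
"\n",
"# One Counter of characters per name, like the `letters` column\n",
"toy_counts = [Counter('ahn'), Counter('baik'), Counter('de leon')]\n",
"\n",
"toy_vd = DictVectorizer(sparse=False)\n",
"toy_matrix = toy_vd.fit_transform(toy_counts)\n",
"\n",
"print(toy_vd.vocabulary_)  # maps each character (including the space) to a column index\n",
"print(toy_matrix.shape)    # (number of names, number of distinct characters)\n",
"print(toy_matrix)          # entry [i, j] is how often character j occurs in name i\n",
"```\n",
"\n",
"Each row is one name and each column one character, so a character that never occurs in a name simply gets a zero in that row."
]
},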
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[' ', \"'\", 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vd1.get_feature_names()[:10]"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 0., 1., 0., ..., 0., 0., 0., 0.],\n",
" [0., 0., 1., 1., ..., 0., 0., 0., 0.],\n",
" [0., 0., 1., 1., ..., 0., 0., 0., 0.],\n",
" [0., 0., 0., 1., ..., 0., 0., 1., 0.],\n",
" ...,\n",
" [0., 0., 0., 0., ..., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., ..., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., ..., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., ..., 0., 0., 1., 0.]])"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"letters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly bigrams and trigrams contains the number of times each sequence of 2 or 3 letters occurs"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([' a', ' b', ' c', ' e', ' f'], ['zu', 'zv', 'zw', 'zy', 'zz'])"
]
},
"execution_count": 94,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vd2.get_feature_names()[:5], vd2.get_feature_names()[-5:]"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((16869, 28), (16869, 623), (16869, 5794), (16869,))"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"letters.shape, bigrams.shape, trigrams.shape, y.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How good a model can we get looking at individual letters (e.g. saying 'z' occurs much more frequently in Chinese than in English names)."
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"letter_nb = MultinomialNB()\n",
"letter_nb.fit(letters[train_idx],y[train_idx])\n",
"\n",
"bal_letter_nb = MultinomialNB()\n",
"bal_letter_nb.fit(letters[bal_idx],y[bal_idx])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The balanced set does mut better than random; around 33%"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.14791666666666667, 0.33541666666666664)"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"letter_pred = letter_nb.predict(letters[valid_idx])\n",
"bal_letter_pred = bal_letter_nb.predict(letters[valid_idx])\n",
"(letter_pred == y[valid_idx]).mean(), (bal_letter_pred == y[valid_idx]).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's write a function to test the Naive Bayes on any dataset; fitting on the whole dataset and the balanced dataset separately."
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"def nb(x):\n",
" model = MultinomialNB()\n",
" model.fit(x[train_idx], y[train_idx])\n",
" preds = model.predict(x[valid_idx])\n",
" acc_train = (preds == y[valid_idx]).mean()\n",
" \n",
" model = MultinomialNB()\n",
" model.fit(x[bal_idx], y[bal_idx])\n",
" preds = model.predict(x[valid_idx])\n",
" acc_bal = (preds == y[valid_idx]).mean()\n",
" \n",
" return acc_train, acc_bal"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.14791666666666667, 0.33541666666666664)"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nb(letters)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using bigrams and a balanced training set gives a much better prediction performance 53% (up from the baseline of 6.25%)."
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.35833333333333334, 0.5291666666666667)"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nb(bigrams)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding letters doesn't make much difference (which isn't surprising "
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.3854166666666667, 0.5166666666666667)"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nb(np.concatenate((letters, bigrams), axis=1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Trigrams alone also performs worse"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.33958333333333335, 0.4895833333333333)"
]
},
"execution_count": 102,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nb(trigrams)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try every combination with trigrams:"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.24375, 0.5083333333333333)"
]
},
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nb(np.concatenate((letters, trigrams), axis=1))"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.36875, 0.5416666666666666)"
]
},
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nb(np.concatenate((bigrams, trigrams), axis=1))"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.32916666666666666, 0.55625)"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nb(np.concatenate((letters, bigrams, trigrams), axis=1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"None of them significantly outperform the simple bigram model (with 623 parameters; we could probably remove some of the uncommon ones without too many problems."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Examining the Bigram Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's remove the bigrams that only occur once as they have practically no value (and there's 100 of them)."
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"503"
]
},
"execution_count": 195,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"common_bigrams = (bigrams[bal_idx].sum(axis=0)) >= 2\n",
"common_bigrams.sum()"
]
},
{
"cell_type": "code",
"execution_count": 196,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(16869, 503)"
]
},
"execution_count": 196,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"common_bigram_index = [i for i, t in enumerate(common_bigrams) if t]\n",
"bigrams_min = bigrams[:, common_bigram_index]\n",
"bigrams_min.shape"
]
},
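{
"cell_type": "markdown",
"metadata": {},
"source": [
"The filtering step above can be sketched on a toy count matrix. This is only an illustration; the `toy_` names are assumptions made up for this example. It also shows that a boolean mask can select columns directly, as an alternative to building a list of indices.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Rows are names, columns are bigrams, entries are counts\n",
"toy_bigrams = np.array([[1, 0, 2, 0],\n",
"                        [0, 1, 1, 0],\n",
"                        [1, 0, 0, 0]])\n",
"\n",
"keep = toy_bigrams.sum(axis=0) >= 2  # column totals are [2, 1, 3, 0]\n",
"toy_min = toy_bigrams[:, keep]       # keep only the columns seen at least twice\n",
"\n",
"print(keep.sum(), toy_min.shape)\n",
"```\n",
"\n",
"Here two of the four columns survive, so the filtered matrix has shape `(3, 2)`."
]
},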
{
"cell_type": "code",
"execution_count": 197,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)"
]
},
"execution_count": 197,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bigram_model = MultinomialNB()\n",
"bigram_model.fit(bigrams_min[bal_idx], y[bal_idx])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We get around 53% accuracy."
]
},
{
"cell_type": "code",
"execution_count": 203,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5291666666666667"
]
},
"execution_count": 203,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bigram_pred = bigram_model.predict(bigrams_min[valid_idx])\n",
"(bigram_pred == y[valid_idx]).mean()"
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([244, 9, 20, 18, ..., 86, 422, 143, 431])"
]
},
"execution_count": 213,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bigram_prob = bigram_model.predict_proba(bigrams_min[valid_idx])\n",
"bigram_prob.max(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 217,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>cl</th>\n",
" <th>pred</th>\n",
" <th>prob</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>16557</th>\n",
" <td>Kotsiopoulos</td>\n",
" <td>Greek</td>\n",
" <td>Greek</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2012</th>\n",
" <td>Rooijakker</td>\n",
" <td>Dutch</td>\n",
" <td>Dutch</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16470</th>\n",
" <td>Akrivopoulos</td>\n",
" <td>Greek</td>\n",
" <td>Greek</td>\n",
" <td>0.999999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>826</th>\n",
" <td>Warszawski</td>\n",
" <td>Polish</td>\n",
" <td>Polish</td>\n",
" <td>0.999998</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1997</th>\n",
" <td>Romeijnders</td>\n",
" <td>Dutch</td>\n",
" <td>Dutch</td>\n",
" <td>0.999997</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16478</th>\n",
" <td>Antonopoulos</td>\n",
" <td>Greek</td>\n",
" <td>Greek</td>\n",
" <td>0.999995</td>\n",
" </tr>\n",
" <tr>\n",
" <th>813</th>\n",
" <td>Sokolowski</td>\n",
" <td>Polish</td>\n",
" <td>Polish</td>\n",
" <td>0.999994</td>\n",
" </tr>\n",
" <tr>\n",
" <th>839</th>\n",
" <td>Zdunowski</td>\n",
" <td>Polish</td>\n",
" <td>Polish</td>\n",
" <td>0.999950</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1996</th>\n",
" <td>Romeijn</td>\n",
" <td>Dutch</td>\n",
" <td>Dutch</td>\n",
" <td>0.999917</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16497</th>\n",
" <td>Chrysanthopoulos</td>\n",
" <td>Greek</td>\n",
" <td>Greek</td>\n",
" <td>0.999895</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2053</th>\n",
" <td>Sneijers</td>\n",
" <td>Dutch</td>\n",
" <td>Dutch</td>\n",
" <td>0.999792</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2031</th>\n",
" <td>Schwarzenberg</td>\n",
" <td>Dutch</td>\n",
" <td>German</td>\n",
" <td>0.999774</td>\n",
" </tr>\n",
" <tr>\n",
" <th>795</th>\n",
" <td>Rudawski</td>\n",
" <td>Polish</td>\n",
" <td>Polish</td>\n",
" <td>0.999751</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16715</th>\n",
" <td>De sauveterre</td>\n",
" <td>French</td>\n",
" <td>French</td>\n",
" <td>0.999604</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1160</th>\n",
" <td>Kawagichi</td>\n",
" <td>Japanese</td>\n",
" <td>Japanese</td>\n",
" <td>0.999600</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name cl pred prob\n",
"16557 Kotsiopoulos Greek Greek 1.000000\n",
"2012 Rooijakker Dutch Dutch 1.000000\n",
"16470 Akrivopoulos Greek Greek 0.999999\n",
"826 Warszawski Polish Polish 0.999998\n",
"1997 Romeijnders Dutch Dutch 0.999997\n",
"16478 Antonopoulos Greek Greek 0.999995\n",
"813 Sokolowski Polish Polish 0.999994\n",
"839 Zdunowski Polish Polish 0.999950\n",
"1996 Romeijn Dutch Dutch 0.999917\n",
"16497 Chrysanthopoulos Greek Greek 0.999895\n",
"2053 Sneijers Dutch Dutch 0.999792\n",
"2031 Schwarzenberg Dutch German 0.999774\n",
"795 Rudawski Polish Polish 0.999751\n",
"16715 De sauveterre French French 0.999604\n",
"1160 Kawagichi Japanese Japanese 0.999600"
]
},
"execution_count": 217,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bigram_preds = (df\n",
" .iloc[valid_idx]\n",
" .assign(pred = bigram_pred)[['name', 'cl', 'pred']]\n",
" .assign(prob = bigram_prob.max(axis=1)))\n",
"bigram_preds.sort_values('prob', ascending=False).head(15)"
]
},
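{
"cell_type": "markdown",
"metadata": {},
"source": [
"The confidence column used here is the largest entry in each row of `predict_proba`. A minimal sketch on toy data, with made-up `toy_` names that are not part of the analysis above:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.naive_bayes import MultinomialNB\n",
"\n",
"# Two count features, two classes\n",
"toy_x = np.array([[3, 0], [2, 1], [0, 3], [1, 2]])\n",
"toy_y = np.array(['A', 'A', 'B', 'B'])\n",
"toy_nb = MultinomialNB().fit(toy_x, toy_y)\n",
"\n",
"toy_prob = toy_nb.predict_proba(np.array([[4, 0], [1, 2]]))\n",
"confidence = toy_prob.max(axis=1)                  # probability of the winning class\n",
"winner = toy_nb.classes_[toy_prob.argmax(axis=1)]  # the class that attains it\n",
"\n",
"print(winner, confidence)\n",
"```\n",
"\n",
"Each row of `toy_prob` sums to 1 and its columns follow the order of `classes_`, so sorting by `confidence` ranks predictions from most to least certain. Naive Bayes probabilities tend to be pushed towards 0 or 1 when there are many features, so they are better read as a ranking than as calibrated probabilities."
]
},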
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The names it's least confident with: they typically seem to be quite short"
]
},
{
"cell_type": "code",
"execution_count": 218,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>cl</th>\n",
" <th>pred</th>\n",
" <th>prob</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2907</th>\n",
" <td>Do</td>\n",
" <td>Vietnamese</td>\n",
" <td>Irish</td>\n",
" <td>0.176679</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>Mo</td>\n",
" <td>Korean</td>\n",
" <td>Japanese</td>\n",
" <td>0.179534</td>\n",
" </tr>\n",
" <tr>\n",
" <th>47</th>\n",
" <td>So</td>\n",
" <td>Korean</td>\n",
" <td>Korean</td>\n",
" <td>0.188088</td>\n",
" </tr>\n",
" <tr>\n",
" <th>45</th>\n",
" <td>Si</td>\n",
" <td>Korean</td>\n",
" <td>Greek</td>\n",
" <td>0.190236</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13775</th>\n",
" <td>Prigojin</td>\n",
" <td>Russian</td>\n",
" <td>Italian</td>\n",
" <td>0.191639</td>\n",
" </tr>\n",
" <tr>\n",
" <th>41</th>\n",
" <td>Seok</td>\n",
" <td>Korean</td>\n",
" <td>French</td>\n",
" <td>0.197154</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Cho</td>\n",
" <td>Korean</td>\n",
" <td>German</td>\n",
" <td>0.202991</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1091</th>\n",
" <td>Isobe</td>\n",
" <td>Japanese</td>\n",
" <td>English</td>\n",
" <td>0.206442</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46</th>\n",
" <td>Sin</td>\n",
" <td>Korean</td>\n",
" <td>Italian</td>\n",
" <td>0.218300</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5332</th>\n",
" <td>Ingram</td>\n",
" <td>English</td>\n",
" <td>Spanish</td>\n",
" <td>0.220205</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2935</th>\n",
" <td>Ta</td>\n",
" <td>Vietnamese</td>\n",
" <td>Japanese</td>\n",
" <td>0.226875</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2700</th>\n",
" <td>Ban</td>\n",
" <td>Chinese</td>\n",
" <td>Vietnamese</td>\n",
" <td>0.228022</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1697</th>\n",
" <td>Togo</td>\n",
" <td>Japanese</td>\n",
" <td>Japanese</td>\n",
" <td>0.236658</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Cha</td>\n",
" <td>Korean</td>\n",
" <td>Irish</td>\n",
" <td>0.239172</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3445</th>\n",
" <td>Graner</td>\n",
" <td>German</td>\n",
" <td>Spanish</td>\n",
" <td>0.240844</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name cl pred prob\n",
"2907 Do Vietnamese Irish 0.176679\n",
"24 Mo Korean Japanese 0.179534\n",
"47 So Korean Korean 0.188088\n",
"45 Si Korean Greek 0.190236\n",
"13775 Prigojin Russian Italian 0.191639\n",
"41 Seok Korean French 0.197154\n",
"5 Cho Korean German 0.202991\n",
"1091 Isobe Japanese English 0.206442\n",
"46 Sin Korean Italian 0.218300\n",
"5332 Ingram English Spanish 0.220205\n",
"2935 Ta Vietnamese Japanese 0.226875\n",
"2700 Ban Chinese Vietnamese 0.228022\n",
"1697 Togo Japanese Japanese 0.236658\n",
"4 Cha Korean Irish 0.239172\n",
"3445 Graner German Spanish 0.240844"
]
},
"execution_count": 218,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bigram_preds.sort_values('prob', ascending=True).head(15)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The names it's most confidently wrong with:"
]
},
{
"cell_type": "code",
"execution_count": 222,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>cl</th>\n",
" <th>pred</th>\n",
" <th>prob</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2031</th>\n",
" <td>Schwarzenberg</td>\n",
" <td>Dutch</td>\n",
" <td>German</td>\n",
" <td>0.999774</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16578</th>\n",
" <td>Malihoudis</td>\n",
" <td>Greek</td>\n",
" <td>Arabic</td>\n",
" <td>0.992311</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16576</th>\n",
" <td>Louverdis</td>\n",
" <td>Greek</td>\n",
" <td>French</td>\n",
" <td>0.990256</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4758</th>\n",
" <td>Fairbrace</td>\n",
" <td>English</td>\n",
" <td>Irish</td>\n",
" <td>0.987143</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3743</th>\n",
" <td>Spellmeyer</td>\n",
" <td>German</td>\n",
" <td>English</td>\n",
" <td>0.976530</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16468</th>\n",
" <td>Adamou</td>\n",
" <td>Greek</td>\n",
" <td>Arabic</td>\n",
" <td>0.973496</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3009</th>\n",
" <td>De la fuente</td>\n",
" <td>Spanish</td>\n",
" <td>French</td>\n",
" <td>0.969431</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3011</th>\n",
" <td>De leon</td>\n",
" <td>Spanish</td>\n",
" <td>French</td>\n",
" <td>0.964697</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3263</th>\n",
" <td>Boulos</td>\n",
" <td>Arabic</td>\n",
" <td>Greek</td>\n",
" <td>0.962321</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16513</th>\n",
" <td>Egonidis</td>\n",
" <td>Greek</td>\n",
" <td>Italian</td>\n",
" <td>0.954264</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2478</th>\n",
" <td>Suchanka</td>\n",
" <td>Czech</td>\n",
" <td>Japanese</td>\n",
" <td>0.949000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2515</th>\n",
" <td>Weichert</td>\n",
" <td>Czech</td>\n",
" <td>German</td>\n",
" <td>0.946457</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5476</th>\n",
" <td>Keene</td>\n",
" <td>English</td>\n",
" <td>Dutch</td>\n",
" <td>0.944270</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3511</th>\n",
" <td>Jaeger</td>\n",
" <td>German</td>\n",
" <td>Dutch</td>\n",
" <td>0.938891</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3174</th>\n",
" <td>Attia</td>\n",
" <td>Arabic</td>\n",
" <td>Italian</td>\n",
" <td>0.935905</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name cl pred prob\n",
"2031 Schwarzenberg Dutch German 0.999774\n",
"16578 Malihoudis Greek Arabic 0.992311\n",
"16576 Louverdis Greek French 0.990256\n",
"4758 Fairbrace English Irish 0.987143\n",
"3743 Spellmeyer German English 0.976530\n",
"16468 Adamou Greek Arabic 0.973496\n",
"3009 De la fuente Spanish French 0.969431\n",
"3011 De leon Spanish French 0.964697\n",
"3263 Boulos Arabic Greek 0.962321\n",
"16513 Egonidis Greek Italian 0.954264\n",
"2478 Suchanka Czech Japanese 0.949000\n",
"2515 Weichert Czech German 0.946457\n",
"5476 Keene English Dutch 0.944270\n",
"3511 Jaeger German Dutch 0.938891\n",
"3174 Attia Arabic Italian 0.935905"
]
},
"execution_count": 222,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bigram_preds[bigram_preds.cl != bigram_preds.pred].sort_values('prob', ascending=False).head(15)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our very simple system does well on Japanese (87%) and Russian (73%), but relatively poorly on Vietnamese (23%), where our data is most sparse. Even that is much better than the 1-in-16 (about 6%) chance level."
]
},
{
"cell_type": "code",
"execution_count": 223,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"cl\n",
"Japanese 0.866667\n",
"Russian 0.733333\n",
"Polish 0.666667\n",
"Irish 0.666667\n",
"Dutch 0.633333\n",
"Italian 0.600000\n",
"Greek 0.533333\n",
"German 0.500000\n",
"English 0.500000\n",
"Spanish 0.466667\n",
"French 0.466667\n",
"Arabic 0.433333\n",
"Czech 0.400000\n",
"Chinese 0.400000\n",
"Korean 0.366667\n",
"Vietnamese 0.233333\n",
"Name: yes, dtype: float64"
]
},
"execution_count": 223,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(bigram_preds\n",
" .assign(yes=bigram_preds.cl == bigram_preds.pred)\n",
" .groupby('cl')\n",
" .yes\n",
" .mean()\n",
" .sort_values(ascending=False))"
]
},
{
"cell_type": "code",
"execution_count": 227,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix"
]
},
{
"cell_type": "code",
"execution_count": 236,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['Arabic', 'Irish', 'German', 'Dutch', ..., 'Russian', 'Polish', 'Irish', 'Italian'], dtype='<U10')"
]
},
"execution_count": 236,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bigram_pred"
]
},
{
"cell_type": "code",
"execution_count": 246,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[11, 1, 0, 6, ..., 0, 0, 2, 1],\n",
" [ 0, 18, 0, 1, ..., 2, 1, 2, 1],\n",
" [ 0, 1, 20, 1, ..., 0, 0, 0, 0],\n",
" [ 1, 0, 0, 26, ..., 1, 0, 0, 0],\n",
" ...,\n",
" [ 1, 2, 0, 1, ..., 15, 0, 1, 1],\n",
" [ 0, 2, 1, 1, ..., 1, 22, 0, 0],\n",
" [ 0, 2, 1, 1, ..., 0, 1, 16, 3],\n",
" [ 0, 3, 1, 0, ..., 1, 1, 4, 14]])"
]
},
"execution_count": 246,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cm = confusion_matrix(y[valid_idx], bigram_pred, labels=y.unique())\n",
"cm"
]
},
{
"cell_type": "code",
"execution_count": 201,
"metadata": {},
"outputs": [],
"source": [
"import itertools  # used below to visit every (row, column) cell of the matrix\n",
"\n",
"def plot_confusion_matrix(cm, classes,\n",
" normalize=False,\n",
" title='Confusion matrix',\n",
" cmap=plt.cm.Blues):\n",
" \"\"\"\n",
" This function prints and plots the confusion matrix.\n",
" Normalization can be applied by setting `normalize=True`.\n",
" \"\"\"\n",
" if normalize:\n",
" cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
" print(\"Normalized confusion matrix\")\n",
" else:\n",
" print('Confusion matrix, without normalization')\n",
"\n",
" plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
" plt.title(title)\n",
" plt.colorbar()\n",
" tick_marks = np.arange(len(classes))\n",
" plt.xticks(tick_marks, classes, rotation=90)\n",
" plt.yticks(tick_marks, classes)\n",
"\n",
" fmt = '.2f' if normalize else 'd'\n",
" thresh = cm.max() / 2.\n",
" for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
" plt.text(j, i, format(cm[i, j], fmt),\n",
" horizontalalignment=\"center\",\n",
" color=\"white\" if cm[i, j] > thresh else \"black\")\n",
"\n",
" plt.ylabel('True label')\n",
" plt.xlabel('Predicted label')\n",
" plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Vietnamese is often mistaken for Chinese (which makes sense) and Irish (which doesn't).\n",
"Korean is often mistaken for Japanese.\n",
"Spanish is often mistaken for Italian."
]
},
{
"cell_type": "code",
"execution_count": 250,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Confusion matrix, without normalization\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x864 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(12,12))\n",
"plot_confusion_matrix(cm, y.unique())"
]
},
{
"cell_type": "code",
"execution_count": 256,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>cl</th>\n",
" <th>pred</th>\n",
" <th>prob</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2907</th>\n",
" <td>Do</td>\n",
" <td>Vietnamese</td>\n",
" <td>Irish</td>\n",
" <td>0.176679</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2935</th>\n",
" <td>Ta</td>\n",
" <td>Vietnamese</td>\n",
" <td>Japanese</td>\n",
" <td>0.226875</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2924</th>\n",
" <td>Luc</td>\n",
" <td>Vietnamese</td>\n",
" <td>Vietnamese</td>\n",
" <td>0.253900</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2944</th>\n",
" <td>Ton</td>\n",
" <td>Vietnamese</td>\n",
" <td>Vietnamese</td>\n",
" <td>0.282859</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2930</th>\n",
" <td>Pho</td>\n",
" <td>Vietnamese</td>\n",
" <td>Dutch</td>\n",
" <td>0.296166</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2910</th>\n",
" <td>Ly</td>\n",
" <td>Vietnamese</td>\n",
" <td>Russian</td>\n",
" <td>0.298872</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2915</th>\n",
" <td>Doan</td>\n",
" <td>Vietnamese</td>\n",
" <td>Chinese</td>\n",
" <td>0.307318</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2916</th>\n",
" <td>Dam</td>\n",
" <td>Vietnamese</td>\n",
" <td>Arabic</td>\n",
" <td>0.325199</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2906</th>\n",
" <td>Dang</td>\n",
" <td>Vietnamese</td>\n",
" <td>Chinese</td>\n",
" <td>0.388393</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2900</th>\n",
" <td>Pham</td>\n",
" <td>Vietnamese</td>\n",
" <td>Arabic</td>\n",
" <td>0.396346</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2928</th>\n",
" <td>Nghiem</td>\n",
" <td>Vietnamese</td>\n",
" <td>English</td>\n",
" <td>0.428147</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2932</th>\n",
" <td>Quach</td>\n",
" <td>Vietnamese</td>\n",
" <td>Vietnamese</td>\n",
" <td>0.450314</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2922</th>\n",
" <td>Lac</td>\n",
" <td>Vietnamese</td>\n",
" <td>Irish</td>\n",
" <td>0.450851</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2938</th>\n",
" <td>Thi</td>\n",
" <td>Vietnamese</td>\n",
" <td>Chinese</td>\n",
" <td>0.455845</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2902</th>\n",
" <td>Hoang</td>\n",
" <td>Vietnamese</td>\n",
" <td>Korean</td>\n",
" <td>0.487704</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2927</th>\n",
" <td>Mach</td>\n",
" <td>Vietnamese</td>\n",
" <td>Irish</td>\n",
" <td>0.505350</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2917</th>\n",
" <td>Dao</td>\n",
" <td>Vietnamese</td>\n",
" <td>Irish</td>\n",
" <td>0.512267</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2923</th>\n",
" <td>Lieu</td>\n",
" <td>Vietnamese</td>\n",
" <td>French</td>\n",
" <td>0.515336</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2939</th>\n",
" <td>Than</td>\n",
" <td>Vietnamese</td>\n",
" <td>Chinese</td>\n",
" <td>0.530805</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2937</th>\n",
" <td>Thai</td>\n",
" <td>Vietnamese</td>\n",
" <td>Irish</td>\n",
" <td>0.580569</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name cl pred prob\n",
"2907 Do Vietnamese Irish 0.176679\n",
"2935 Ta Vietnamese Japanese 0.226875\n",
"2924 Luc Vietnamese Vietnamese 0.253900\n",
"2944 Ton Vietnamese Vietnamese 0.282859\n",
"2930 Pho Vietnamese Dutch 0.296166\n",
"2910 Ly Vietnamese Russian 0.298872\n",
"2915 Doan Vietnamese Chinese 0.307318\n",
"2916 Dam Vietnamese Arabic 0.325199\n",
"2906 Dang Vietnamese Chinese 0.388393\n",
"2900 Pham Vietnamese Arabic 0.396346\n",
"2928 Nghiem Vietnamese English 0.428147\n",
"2932 Quach Vietnamese Vietnamese 0.450314\n",
"2922 Lac Vietnamese Irish 0.450851\n",
"2938 Thi Vietnamese Chinese 0.455845\n",
"2902 Hoang Vietnamese Korean 0.487704\n",
"2927 Mach Vietnamese Irish 0.505350\n",
"2917 Dao Vietnamese Irish 0.512267\n",
"2923 Lieu Vietnamese French 0.515336\n",
"2939 Than Vietnamese Chinese 0.530805\n",
"2937 Thai Vietnamese Irish 0.580569"
]
},
"execution_count": 256,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bigram_preds[bigram_preds.cl == 'Vietnamese'].sort_values('prob').head(20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So our bigram baseline reaches about 53% accuracy on the balanced validation set. Let's see if we can do better with deep learning."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deep Learning\n",
"## Build a Fastai Data Loader"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the dataframe and extract the indexes for the training set, the validation set, and a balanced training set (each row index repeated `bal` times)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('names_clean.csv')\n",
"\n",
"valid_idx = df[df.valid].index\n",
"train_idx = df[~df.valid].index\n",
"\n",
"bal_idx = []\n",
"for k, v in zip(df.index, df.bal):\n",
" bal_idx += [k]*v"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As of December 2018, fastai only has word-level tokenizers, so we'll have to create our own letter tokenizer.\n",
"\n",
"The fastai library injects a `BOS` marker (`xxbos`) at the start of every string; the tokenizer has to emit it as a single token rather than splitting it into letters."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class LetterTokenizer(BaseTokenizer):\n",
" \"Character level tokenizer function.\"\n",
" def __init__(self, lang): pass\n",
" def tokenizer(self, t:str) -> List[str]:\n",
" out = []\n",
" i = 0\n",
" while i < len(t):\n",
" if t[i:].startswith(BOS):\n",
" out.append(BOS)\n",
" i += len(BOS)\n",
" else:\n",
" out.append(t[i])\n",
" i += 1\n",
" return out\n",
" \n",
" def add_special_cases(self, toks:Collection[str]): pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We create a vocab of the lowercase ASCII letters plus space, hyphen and apostrophe (along with the `UNK` and `BOS` tokens), and a character tokenizer with no pre- or post-processing rules."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"itos = [UNK, BOS] + list(string.ascii_lowercase + \" -'\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"vocab=Vocab(itos)\n",
"tokenizer=Tokenizer(LetterTokenizer, pre_rules=[], post_rules=[])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can create a data pipeline using the `TextClasDataBunch.from_df` constructor.\n",
"\n",
"`mark_fields` puts an extra `xxfld` marker between each field of text. Since we only have one field, this is unnecessary."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"train_df = df.iloc[train_idx, [0,2]]\n",
"valid_df = df.iloc[valid_idx, [0,2]]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>ascii_name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Korean</td>\n",
" <td>ahn</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Korean</td>\n",
" <td>bang</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Korean</td>\n",
" <td>byon</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>Korean</td>\n",
" <td>gil</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>Korean</td>\n",
" <td>gu</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl ascii_name\n",
"0 Korean ahn\n",
"2 Korean bang\n",
"3 Korean byon\n",
"10 Korean gil\n",
"11 Korean gu"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"data = TextClasDataBunch.from_df(path='.', train_df=train_df, valid_df=valid_df,\n",
" tokenizer=tokenizer, vocab=vocab,\n",
" mark_fields=False)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table> <col width='90%'> <col width='10%'> <tr>\n",
" <th>text</th>\n",
" <th>target</th>\n",
" </tr>\n",
" <tr>\n",
" <th> v o n g r i m m e l s h a u s e n</th>\n",
" <th>German</th>\n",
" </tr>\n",
" <tr>\n",
" <th> m a c e a c h t h i g h e a r n a</th>\n",
" <th>Irish</th>\n",
" </tr>\n",
" <tr>\n",
" <th> c h k h a r t i s h v i l i</th>\n",
" <th>Russian</th>\n",
" </tr>\n",
" <tr>\n",
" <th> t z e h m i s t r e n k o</th>\n",
" <th>Russian</th>\n",
" </tr>\n",
" <tr>\n",
" <th> c h e p t y g m a s h e v</th>\n",
" <th>Russian</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.show_batch()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, we can create it using the data block API.\n",
"This uses the `processors` to tokenize and numericalize the input."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"processors = [TokenizeProcessor(tokenizer=tokenizer, mark_fields=False),\n",
" NumericalizeProcessor(vocab=vocab)]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"data = (TextList\n",
" .from_df(df, \n",
" cols=[2], \n",
" processor=processors)\n",
" .split_by_idxs(train_idx=train_idx, valid_idx=valid_idx)\n",
" .label_from_df(cols=0)\n",
" .databunch(bs=32))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table> <col width='90%'> <col width='10%'> <tr>\n",
" <th>text</th>\n",
" <th>target</th>\n",
" </tr>\n",
" <tr>\n",
" <th> v o n g r i m m e l s h a u s e n</th>\n",
" <th>German</th>\n",
" </tr>\n",
" <tr>\n",
" <th> p a r a s k e v o p o u l o s</th>\n",
" <th>Greek</th>\n",
" </tr>\n",
" <tr>\n",
" <th> d z h a v a h i s h v i l i</th>\n",
" <th>Russian</th>\n",
" </tr>\n",
" <tr>\n",
" <th> s h a h n a z a r y a n t s</th>\n",
" <th>Russian</th>\n",
" </tr>\n",
" <tr>\n",
" <th> m o g i l n i c h e n k o</th>\n",
" <th>Russian</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.show_batch()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sanity Checking"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Counter({'Korean': 30,\n",
" 'Italian': 30,\n",
" 'Polish': 30,\n",
" 'Japanese': 30,\n",
" 'Dutch': 30,\n",
" 'Czech': 30,\n",
" 'Irish': 30,\n",
" 'Chinese': 30,\n",
" 'Vietnamese': 30,\n",
" 'Spanish': 30,\n",
" 'Arabic': 30,\n",
" 'German': 30,\n",
" 'English': 30,\n",
" 'Russian': 30,\n",
" 'Greek': 30,\n",
" 'French': 30})"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Counter(_.obj for _ in data.valid_ds.y)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Russian', 9232),\n",
" ('English', 3329),\n",
" ('Japanese', 952),\n",
" ('Italian', 630),\n",
" ('German', 548),\n",
" ('Czech', 434),\n",
" ('Dutch', 214),\n",
" ('Spanish', 182),\n",
" ('French', 180),\n",
" ('Chinese', 170),\n",
" ('Greek', 162),\n",
" ('Irish', 134),\n",
" ('Polish', 93),\n",
" ('Arabic', 73),\n",
" ('Korean', 31),\n",
" ('Vietnamese', 25)]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Counter(_.obj for _ in data.train_ds.y).most_common()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check that no text is in both the validation and training sets."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"valid_set = set(_.text for _ in data.valid_ds.x)\n",
"for _ in data.train_ds.x:\n",
" assert _.text not in valid_set, _.text"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Examine a minibatch"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"trainiter = iter(data.train_dl)\n",
"batch, cl = next(trainiter)\n",
"batch2, cl2 = next(trainiter)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 6, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13],\n",
" device='cuda:0'), 16)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cl, len(cl)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([20, 16])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each name runs *down* a column of the batch, 20 tokens long here, and shorter names are padded at the front with `BOS`; we have 16 names across.\n",
"\n",
"There is also an extra space at the beginning of each name that wasn't in the input data, probably because fastai joins the `BOS` marker to the text with a space and our character tokenizer keeps it as a token.\n",
"\n",
"(Note this is different from what the fastai wrappers will give you; they concatenate the data and split it into 16 chunks.)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>10</th>\n",
" <th>11</th>\n",
" <th>12</th>\n",
" <th>13</th>\n",
" <th>14</th>\n",
" <th>15</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td></td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>v</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>o</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>n</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>g</td>\n",
" <td>t</td>\n",
" <td>l</td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>r</td>\n",
" <td>c</td>\n",
" <td>e</td>\n",
" <td>t</td>\n",
" <td>m</td>\n",
" <td>b</td>\n",
" <td>c</td>\n",
" <td>z</td>\n",
" <td>b</td>\n",
" <td>p</td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>i</td>\n",
" <td>h</td>\n",
" <td>i</td>\n",
" <td>c</td>\n",
" <td>i</td>\n",
" <td>a</td>\n",
" <td>h</td>\n",
" <td>h</td>\n",
" <td>a</td>\n",
" <td>a</td>\n",
" <td>g</td>\n",
" <td>s</td>\n",
" <td>a</td>\n",
" <td>b</td>\n",
" <td>v</td>\n",
" <td>v</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>m</td>\n",
" <td>a</td>\n",
" <td>h</td>\n",
" <td>h</td>\n",
" <td>n</td>\n",
" <td>k</td>\n",
" <td>a</td>\n",
" <td>e</td>\n",
" <td>h</td>\n",
" <td>t</td>\n",
" <td>r</td>\n",
" <td>h</td>\n",
" <td>w</td>\n",
" <td>a</td>\n",
" <td>y</td>\n",
" <td>i</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>e</td>\n",
" <td>t</td>\n",
" <td>e</td>\n",
" <td>h</td>\n",
" <td>a</td>\n",
" <td>h</td>\n",
" <td>t</td>\n",
" <td>o</td>\n",
" <td>h</td>\n",
" <td>i</td>\n",
" <td>s</td>\n",
" <td>n</td>\n",
" <td>o</td>\n",
" <td>h</td>\n",
" <td>c</td>\n",
" <td>c</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>l</td>\n",
" <td>o</td>\n",
" <td>n</td>\n",
" <td>l</td>\n",
" <td>z</td>\n",
" <td>t</td>\n",
" <td>o</td>\n",
" <td>k</td>\n",
" <td>i</td>\n",
" <td>o</td>\n",
" <td>h</td>\n",
" <td>d</td>\n",
" <td>r</td>\n",
" <td>t</td>\n",
" <td>h</td>\n",
" <td>h</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>s</td>\n",
" <td>r</td>\n",
" <td>b</td>\n",
" <td>a</td>\n",
" <td>e</td>\n",
" <td>a</td>\n",
" <td>r</td>\n",
" <td>h</td>\n",
" <td>v</td>\n",
" <td>r</td>\n",
" <td>e</td>\n",
" <td>e</td>\n",
" <td>k</td>\n",
" <td>i</td>\n",
" <td>e</td>\n",
" <td>e</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>h</td>\n",
" <td>i</td>\n",
" <td>e</td>\n",
" <td>k</td>\n",
" <td>t</td>\n",
" <td>n</td>\n",
" <td>i</td>\n",
" <td>o</td>\n",
" <td>a</td>\n",
" <td>k</td>\n",
" <td>l</td>\n",
" <td>r</td>\n",
" <td>h</td>\n",
" <td>g</td>\n",
" <td>s</td>\n",
" <td>p</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>a</td>\n",
" <td>z</td>\n",
" <td>r</td>\n",
" <td>o</td>\n",
" <td>d</td>\n",
" <td>o</td>\n",
" <td>z</td>\n",
" <td>v</td>\n",
" <td>n</td>\n",
" <td>o</td>\n",
" <td>e</td>\n",
" <td>o</td>\n",
" <td>a</td>\n",
" <td>a</td>\n",
" <td>l</td>\n",
" <td>o</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>u</td>\n",
" <td>h</td>\n",
" <td>g</td>\n",
" <td>v</td>\n",
" <td>i</td>\n",
" <td>w</td>\n",
" <td>h</td>\n",
" <td>t</td>\n",
" <td>d</td>\n",
" <td>v</td>\n",
" <td>v</td>\n",
" <td>v</td>\n",
" <td>n</td>\n",
" <td>r</td>\n",
" <td>a</td>\n",
" <td>l</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>s</td>\n",
" <td>s</td>\n",
" <td>s</td>\n",
" <td>s</td>\n",
" <td>n</td>\n",
" <td>s</td>\n",
" <td>s</td>\n",
" <td>s</td>\n",
" <td>z</td>\n",
" <td>s</td>\n",
" <td>s</td>\n",
" <td>i</td>\n",
" <td>o</td>\n",
" <td>e</td>\n",
" <td>v</td>\n",
" <td>s</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>e</td>\n",
" <td>k</td>\n",
" <td>k</td>\n",
" <td>k</td>\n",
" <td>o</td>\n",
" <td>k</td>\n",
" <td>k</td>\n",
" <td>e</td>\n",
" <td>h</td>\n",
" <td>k</td>\n",
" <td>k</td>\n",
" <td>c</td>\n",
" <td>f</td>\n",
" <td>e</td>\n",
" <td>o</td>\n",
" <td>k</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>n</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>v</td>\n",
" <td>i</td>\n",
" <td>y</td>\n",
" <td>v</td>\n",
" <td>i</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>h</td>\n",
" <td>f</td>\n",
" <td>v</td>\n",
" <td>v</td>\n",
" <td>y</td>\n",
" </tr>\n",
" <tr>\n",
" <th>category</th>\n",
" <td>German</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" <td>Russian</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>21 rows × 16 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 \\\n",
"0 xxbos xxbos xxbos xxbos xxbos xxbos xxbos \n",
"1 xxbos xxbos xxbos xxbos xxbos xxbos \n",
"2 v xxbos xxbos xxbos xxbos xxbos xxbos \n",
"3 o xxbos xxbos xxbos xxbos xxbos xxbos \n",
"4 n xxbos xxbos xxbos xxbos xxbos xxbos \n",
"5 xxbos xxbos xxbos xxbos \n",
"6 g t l \n",
"7 r c e t m b c \n",
"8 i h i c i a h \n",
"9 m a h h n k a \n",
"... ... ... ... ... ... ... ... \n",
"11 e t e h a h t \n",
"12 l o n l z t o \n",
"13 s r b a e a r \n",
"14 h i e k t n i \n",
"15 a z r o d o z \n",
"16 u h g v i w h \n",
"17 s s s s n s s \n",
"18 e k k k o k k \n",
"19 n y y y v i y \n",
"category German Russian Russian Russian Russian Russian Russian \n",
"\n",
" 7 8 9 10 11 12 13 \\\n",
"0 xxbos xxbos xxbos xxbos xxbos xxbos xxbos \n",
"1 xxbos xxbos xxbos xxbos xxbos xxbos xxbos \n",
"2 xxbos xxbos xxbos xxbos xxbos xxbos xxbos \n",
"3 xxbos xxbos xxbos xxbos xxbos xxbos xxbos \n",
"4 xxbos xxbos xxbos xxbos xxbos xxbos xxbos \n",
"5 xxbos xxbos xxbos xxbos xxbos xxbos xxbos \n",
"6 xxbos xxbos xxbos xxbos \n",
"7 z b p \n",
"8 h a a g s a b \n",
"9 e h t r h w a \n",
"... ... ... ... ... ... ... ... \n",
"11 o h i s n o h \n",
"12 k i o h d r t \n",
"13 h v r e e k i \n",
"14 o a k l r h g \n",
"15 v n o e o a a \n",
"16 t d v v v n r \n",
"17 s z s s i o e \n",
"18 e h k k c f e \n",
"19 v i y y h f v \n",
"category Russian Russian Russian Russian Russian Russian Russian \n",
"\n",
" 14 15 \n",
"0 xxbos xxbos \n",
"1 xxbos xxbos \n",
"2 xxbos xxbos \n",
"3 xxbos xxbos \n",
"4 xxbos xxbos \n",
"5 xxbos xxbos \n",
"6 xxbos xxbos \n",
"7 \n",
"8 v v \n",
"9 y i \n",
"... ... ... \n",
"11 c c \n",
"12 h h \n",
"13 e e \n",
"14 s p \n",
"15 l o \n",
"16 a l \n",
"17 v s \n",
"18 o k \n",
"19 v y \n",
"category Russian Russian \n",
"\n",
"[21 rows x 16 columns]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.options.display.max_columns = 100\n",
"(pd\n",
" .DataFrame([[vocab.itos[y] for y in x] for x in batch])\n",
" .T\n",
" .assign(category=[data.classes[_] for _ in cl])\n",
" .T)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['xxbos', ' ', 'a', 'h', 'n']"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[vocab.itos[_] for _ in data.train_ds[0][0].data]"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['A', 'h', 'n']"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(df.iloc[0,1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note the length of strings varies between batches."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>10</th>\n",
" <th>11</th>\n",
" <th>12</th>\n",
" <th>13</th>\n",
" <th>14</th>\n",
" <th>15</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" <td>xxbos</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>b</td>\n",
" <td>m</td>\n",
" <td>r</td>\n",
" <td>f</td>\n",
" <td>m</td>\n",
" <td>b</td>\n",
" <td>d</td>\n",
" <td>j</td>\n",
" <td>m</td>\n",
" <td>b</td>\n",
" <td>k</td>\n",
" <td>u</td>\n",
" <td>m</td>\n",
" <td>m</td>\n",
" <td>a</td>\n",
" <td>b</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>e</td>\n",
" <td>c</td>\n",
" <td>i</td>\n",
" <td>e</td>\n",
" <td>a</td>\n",
" <td>a</td>\n",
" <td>e</td>\n",
" <td>e</td>\n",
" <td>o</td>\n",
" <td>a</td>\n",
" <td>i</td>\n",
" <td>f</td>\n",
" <td>a</td>\n",
" <td>e</td>\n",
" <td>n</td>\n",
" <td>a</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>n</td>\n",
" <td>g</td>\n",
" <td>d</td>\n",
" <td>r</td>\n",
" <td>s</td>\n",
" <td>j</td>\n",
" <td>m</td>\n",
" <td>f</td>\n",
" <td>r</td>\n",
" <td>l</td>\n",
" <td>n</td>\n",
" <td>i</td>\n",
" <td>k</td>\n",
" <td>a</td>\n",
" <td>s</td>\n",
" <td>b</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>i</td>\n",
" <td>r</td>\n",
" <td>g</td>\n",
" <td>r</td>\n",
" <td>a</td>\n",
" <td>e</td>\n",
" <td>a</td>\n",
" <td>f</td>\n",
" <td>a</td>\n",
" <td>b</td>\n",
" <td>c</td>\n",
" <td>m</td>\n",
" <td>i</td>\n",
" <td>d</td>\n",
" <td>e</td>\n",
" <td>u</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>t</td>\n",
" <td>o</td>\n",
" <td>w</td>\n",
" <td>a</td>\n",
" <td>o</td>\n",
" <td>n</td>\n",
" <td>k</td>\n",
" <td>e</td>\n",
" <td>n</td>\n",
" <td>o</td>\n",
" <td>h</td>\n",
" <td>k</td>\n",
" <td>o</td>\n",
" <td>h</td>\n",
" <td>l</td>\n",
" <td>r</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>e</td>\n",
" <td>r</td>\n",
" <td>a</td>\n",
" <td>r</td>\n",
" <td>k</td>\n",
" <td>o</td>\n",
" <td>i</td>\n",
" <td>r</td>\n",
" <td>d</td>\n",
" <td>n</td>\n",
" <td>i</td>\n",
" <td>i</td>\n",
" <td>k</td>\n",
" <td>r</td>\n",
" <td>m</td>\n",
" <td>i</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>z</td>\n",
" <td>y</td>\n",
" <td>y</td>\n",
" <td>o</td>\n",
" <td>a</td>\n",
" <td>v</td>\n",
" <td>s</td>\n",
" <td>s</td>\n",
" <td>i</td>\n",
" <td>i</td>\n",
" <td>n</td>\n",
" <td>n</td>\n",
" <td>a</td>\n",
" <td>a</td>\n",
" <td>i</td>\n",
" <td>n</td>\n",
" </tr>\n",
" <tr>\n",
" <th>category</th>\n",
" <td>Spanish</td>\n",
" <td>English</td>\n",
" <td>English</td>\n",
" <td>Italian</td>\n",
" <td>Japanese</td>\n",
" <td>Russian</td>\n",
" <td>Greek</td>\n",
" <td>English</td>\n",
" <td>Italian</td>\n",
" <td>Italian</td>\n",
" <td>English</td>\n",
" <td>Russian</td>\n",
" <td>Japanese</td>\n",
" <td>Irish</td>\n",
" <td>Italian</td>\n",
" <td>Russian</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 \\\n",
"0 xxbos xxbos xxbos xxbos xxbos xxbos xxbos \n",
"1 \n",
"2 b m r f m b d \n",
"3 e c i e a a e \n",
"4 n g d r s j m \n",
"5 i r g r a e a \n",
"6 t o w a o n k \n",
"7 e r a r k o i \n",
"8 z y y o a v s \n",
"category Spanish English English Italian Japanese Russian Greek \n",
"\n",
" 7 8 9 10 11 12 13 \\\n",
"0 xxbos xxbos xxbos xxbos xxbos xxbos xxbos \n",
"1 \n",
"2 j m b k u m m \n",
"3 e o a i f a e \n",
"4 f r l n i k a \n",
"5 f a b c m i d \n",
"6 e n o h k o h \n",
"7 r d n i i k r \n",
"8 s i i n n a a \n",
"category English Italian Italian English Russian Japanese Irish \n",
"\n",
" 14 15 \n",
"0 xxbos xxbos \n",
"1 \n",
"2 a b \n",
"3 n a \n",
"4 s b \n",
"5 e u \n",
"6 l r \n",
"7 m i \n",
"8 i n \n",
"category Italian Russian "
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(pd\n",
" .DataFrame([[vocab.itos[y] for y in x] for x in batch2])\n",
" .T\n",
" .assign(category=[data.classes[_] for _ in cl2])\n",
" .T)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'xxbos b e n i t e z'"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab.textify(batch2[:,0])"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table> <col width='90%'> <col width='10%'> <tr>\n",
" <th>text</th>\n",
" <th>target</th>\n",
" </tr>\n",
" <tr>\n",
" <th> c h r y s a n t h o p o u l o s</th>\n",
" <th>Greek</th>\n",
" </tr>\n",
" <tr>\n",
" <th> v o n i n g e r s l e b e n</th>\n",
" <th>German</th>\n",
" </tr>\n",
" <tr>\n",
" <th> s c h w a r z e n b e r g</th>\n",
" <th>Dutch</th>\n",
" </tr>\n",
" <tr>\n",
" <th> d e s a u v e t e r r e</th>\n",
" <th>French</th>\n",
" </tr>\n",
" <tr>\n",
" <th> a r e c h a v a l e t a</th>\n",
" <th>Spanish</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.show_batch(ds_type=DatasetType.Valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## One Hot Encoding\n",
"\n",
"`nn.RNN` expects a float tensor of shape `(seq_len, batch, input_size)`, so we one-hot encode each character index."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"one_hot = torch.eye(len(vocab.itos))"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],\n",
"\n",
" [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"one_hot[batch][:2]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([20, 16, 31])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"one_hot[batch].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's how we could do it without storing the one_hot matrix in memory."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"def one_hot_fly(y, length=len(vocab.itos)):\n",
"    # Build the one-hot tensor element by element (slow reference version)\n",
"    shape = list(y.shape)\n",
"    assert len(shape) == 2\n",
"    tensor = torch.zeros(shape + [length])\n",
"    for i, row in enumerate(y):\n",
"        for j, val in enumerate(row):\n",
"            tensor[i][j][val] = 1.\n",
"    return tensor"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1, dtype=torch.uint8)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(one_hot[batch] == one_hot_fly(batch)).all()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Indexing into the identity matrix is ~250 times faster at this size than the double for loop."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"36.1 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"8.91 ms ± 210 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit one_hot[batch]\n",
"%timeit one_hot_fly(batch)\n",
"None"
]
},
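{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hedged aside, the double loop in `one_hot_fly` can probably be replaced by a single `scatter_` call, which builds the one-hot tensor on the fly without keeping the identity matrix around. This is a minimal sketch and not part of the original experiment; `one_hot_scatter` and the toy `idx` tensor are hypothetical names introduced only for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def one_hot_scatter(y, length):\n",
"    # y: LongTensor of character indices, shape (seq_len, batch)\n",
"    # returns a float tensor of shape (seq_len, batch, length)\n",
"    out = torch.zeros(*y.shape, length)\n",
"    # write a 1. at position y[i, j] along the last dimension\n",
"    return out.scatter_(-1, y.unsqueeze(-1), 1.)\n",
"\n",
"idx = torch.tensor([[1, 0, 3], [2, 2, 4]])  # toy stand-in for a batch\n",
"assert torch.equal(one_hot_scatter(idx, 5), torch.eye(5)[idx])"
]
},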
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fitting a model"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(31, 16)"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_letters = len(vocab.itos)\n",
"n_hidden = 128\n",
"n_output = df.cl.nunique()\n",
"n_letters, n_output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use an RNN to read the sequence of letters and compute a hidden state at each step."
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"rnn = nn.RNN(input_size=n_letters,\n",
" hidden_size=n_hidden,\n",
" num_layers=1,\n",
" nonlinearity='relu',\n",
" dropout=0.)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([20, 16, 128]), torch.Size([1, 16, 128]))"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output, hidden = rnn(one_hot[batch])\n",
"output.shape, hidden.shape"
]
},
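{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a single-layer, unidirectional `nn.RNN`, the last time step of `output` should equal the final hidden state, which is why reading off the last time step is enough for classification. The check below is a small self-contained sketch with made-up sizes; `toy_rnn` and `x` are assumptions for illustration, not the notebook's data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"toy_rnn = nn.RNN(input_size=5, hidden_size=8, num_layers=1, nonlinearity='relu')\n",
"x = torch.randn(7, 3, 5)  # (seq_len, batch, input_size)\n",
"output, hidden = toy_rnn(x)\n",
"# output holds the hidden state at every time step; hidden only the last one\n",
"assert output.shape == (7, 3, 8) and hidden.shape == (1, 3, 8)\n",
"assert torch.allclose(output[-1], hidden[0])"
]
},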
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"lo = nn.Linear(n_hidden, n_output)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"preds = lo(output)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([20, 16, 16])"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds.shape"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 6, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13],\n",
" device='cuda:0')"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cl"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nn.functional.softmax(preds[-1], dim=1).argmax(dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"one_hot = torch.eye(len(vocab.itos))"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"class MyLetterRNN(nn.Module):\n",
"    def __init__(self, dropout=0., n_layers=1, n_input=n_letters, n_hidden=n_hidden, n_output=n_output):\n",
"        super().__init__()\n",
"        # Identity matrix used as a one-hot lookup table (sized by n_input, not the global n_letters)\n",
"        self.one_hot = torch.eye(n_input).cuda()\n",
"        self.rnn = nn.RNN(input_size=n_input,\n",
"                          hidden_size=n_hidden,\n",
"                          num_layers=n_layers,\n",
"                          nonlinearity='relu',\n",
"                          dropout=dropout)\n",
"        self.lo = nn.Linear(n_hidden, n_output)\n",
"\n",
"    def forward(self, input):\n",
"        # input: (seq_len, batch) tensor of character indices\n",
"        rnn, _ = self.rnn(self.one_hot[input])\n",
"        out = self.lo(rnn)\n",
"        # classify from the last time step only\n",
"        return out[-1]"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"rnn = MyLetterRNN().cuda()"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([6, 6, 6, 6, 1, 6, 6, 1, 6, 6, 6, 1, 6, 1, 1, 6], device='cuda:0'),\n",
" tensor([ 6, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13],\n",
" device='cuda:0'))"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out = rnn(batch)\n",
"out.argmax(dim=1), cl"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fit the model"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(2.7832, device='cuda:0', grad_fn=<NllLossBackward>)"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.cross_entropy(out, cl)"
]
},
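{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, a model that spreads its probability evenly over the 16 classes has a cross-entropy of ln(16), roughly 2.77, so a loss of about 2.78 from the untrained network is what we would expect. The cell below is a minimal sketch of that arithmetic; the all-zero `logits` and the `targets` are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"n_classes = 16\n",
"uniform_loss = math.log(n_classes)  # cross-entropy of a uniform guess\n",
"# all-zero logits give a uniform softmax, so the loss matches ln(16)\n",
"logits = torch.zeros(4, n_classes)\n",
"targets = torch.tensor([0, 5, 10, 15])\n",
"assert abs(F.cross_entropy(logits, targets).item() - uniform_loss) < 1e-5\n",
"round(uniform_loss, 4)"
]
},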
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, rnn, loss_func=F.cross_entropy, metrics=[accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot()"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:35 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>0.811252</th>\n",
" <th>2.653636</th>\n",
" <th>0.260417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>0.928508</th>\n",
" <th>3.329767</th>\n",
" <th>0.216667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>0.830531</th>\n",
" <th>3.436638</th>\n",
" <th>0.218750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>0.947136</th>\n",
" <th>3.056552</th>\n",
" <th>0.202083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>0.878935</th>\n",
" <th>3.361734</th>\n",
" <th>0.210417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>0.818984</th>\n",
" <th>3.208372</th>\n",
" <th>0.214583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>0.811538</th>\n",
" <th>2.896590</th>\n",
" <th>0.252083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>0.745542</th>\n",
" <th>3.237130</th>\n",
" <th>0.283333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>0.753505</th>\n",
" <th>2.819807</th>\n",
" <th>0.302083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>0.763112</th>\n",
" <th>2.878011</th>\n",
" <th>0.297917</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(10, max_lr=3e-2)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()\n",
"learn.recorder.plot()"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"learn.save('char_rnn_1')"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:17 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>0.696038</th>\n",
" <th>2.910773</th>\n",
" <th>0.304167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>0.734545</th>\n",
" <th>2.814250</th>\n",
" <th>0.306250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>0.634951</th>\n",
" <th>2.827829</th>\n",
" <th>0.295833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>0.636780</th>\n",
" <th>2.758662</th>\n",
" <th>0.312500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>0.696148</th>\n",
" <th>2.838843</th>\n",
" <th>0.312500</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(5, 3e-3)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"learn.save('char_rnn_1_final')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is abysmal: 31% validation accuracy is much worse than the 52% from the simple Naive Bayes bigram model.\n",
"\n",
"Does it improve if we add another layer?"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, MyLetterRNN(n_layers=2), loss_func=F.cross_entropy, metrics=[accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()\n",
"learn.recorder.plot()"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"Total time: 01:34 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>0.929248</th>\n",
" <th>3.101529</th>\n",
" <th>0.189583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>0.695901</th>\n",
" <th>2.869615</th>\n",
" <th>0.250000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>0.745567</th>\n",
" <th>2.520683</th>\n",
" <th>0.316667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>0.620927</th>\n",
" <th>3.530135</th>\n",
" <th>0.262500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>0.742575</th>\n",
" <th>2.512531</th>\n",
" <th>0.318750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>0.723677</th>\n",
" <th>2.616584</th>\n",
" <th>0.343750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>0.839355</th>\n",
" <th>2.454891</th>\n",
" <th>0.335417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>0.817186</th>\n",
" <th>2.794391</th>\n",
" <th>0.291667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>0.653000</th>\n",
" <th>2.695168</th>\n",
" <th>0.302083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>0.683367</th>\n",
" <th>2.637764</th>\n",
" <th>0.358333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>0.611877</th>\n",
" <th>2.308675</th>\n",
" <th>0.333333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>0.586979</th>\n",
" <th>2.296229</th>\n",
" <th>0.352083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>0.611386</th>\n",
" <th>2.224956</th>\n",
" <th>0.381250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>0.580687</th>\n",
" <th>2.247524</th>\n",
" <th>0.383333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>0.512482</th>\n",
" <th>2.244857</th>\n",
" <th>0.387500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <th>0.516693</th>\n",
" <th>2.303736</th>\n",
" <th>0.412500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <th>0.409016</th>\n",
" <th>2.413911</th>\n",
" <th>0.412500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <th>0.435291</th>\n",
" <th>2.442951</th>\n",
" <th>0.422917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <th>0.386392</th>\n",
" <th>2.507006</th>\n",
" <th>0.425000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <th>0.352908</th>\n",
" <th>2.518786</th>\n",
" <th>0.433333</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, max_lr=1e-2)"
]
},
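{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fit seems to have levelled off at about 43% accuracy (the validation loss bottoms out around epoch 13 while the training loss keeps falling), again a much worse result than our Naive Bayes bigrams.\n",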
"\n",
"But that was trained using a balanced dataset; maybe that will help with RNNs too."
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_losses()"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"learn.save('char_rnn_2_p0')"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('English', 141),\n",
" ('Russian', 97),\n",
" ('Chinese', 54),\n",
" ('Italian', 43),\n",
" ('Japanese', 38),\n",
" ('Greek', 25),\n",
" ('German', 22),\n",
" ('Czech', 12),\n",
" ('French', 11),\n",
" ('Dutch', 11),\n",
" ('Spanish', 10),\n",
" ('Polish', 10),\n",
" ('Korean', 4),\n",
" ('Vietnamese', 2)]"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob, targ = learn.get_preds()\n",
"Counter(data.classes[_.item()] for _ in prob.argmax(dim=1)).most_common()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Rebalancing\n",
"### Less is more\n",
"\n",
"Even though the balanced set is drawn entirely from the training set (and throws away a lot of data from the larger classes), the model performs much better on the balanced validation set when trained on it.\n",
"\n",
"This is because on the whole training set, heuristics like \"when in doubt, guess Russian/English\" and \"it's almost never Vietnamese\" work well, but they are terrible on our balanced validation set."
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"data = (TextList\n",
" .from_df(df, \n",
" cols=[2], \n",
" processor=processors)\n",
" .split_by_idxs(train_idx=bal_idx, valid_idx=valid_idx)\n",
" .label_from_df(cols=0)\n",
" .databunch(bs=1024))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sanity Checking"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Counter({'Korean': 30,\n",
" 'Italian': 30,\n",
" 'Polish': 30,\n",
" 'Japanese': 30,\n",
" 'Dutch': 30,\n",
" 'Czech': 30,\n",
" 'Irish': 30,\n",
" 'Chinese': 30,\n",
" 'Vietnamese': 30,\n",
" 'Spanish': 30,\n",
" 'Arabic': 30,\n",
" 'German': 30,\n",
" 'English': 30,\n",
" 'Russian': 30,\n",
" 'Greek': 30,\n",
" 'French': 30})"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Counter(_.obj for _ in data.valid_ds.y)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Korean', 500),\n",
" ('Italian', 500),\n",
" ('Polish', 500),\n",
" ('Japanese', 500),\n",
" ('Dutch', 500),\n",
" ('Czech', 500),\n",
" ('Irish', 500),\n",
" ('Chinese', 500),\n",
" ('Vietnamese', 500),\n",
" ('Spanish', 500),\n",
" ('Arabic', 500),\n",
" ('German', 500),\n",
" ('English', 500),\n",
" ('Russian', 500),\n",
" ('Greek', 500),\n",
" ('French', 500)]"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Counter(_.obj for _ in data.train_ds.y).most_common()"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" </tr>\n",
" <tr>\n",
" <th>y</th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Russian</th>\n",
" <td>486</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>English</th>\n",
" <td>459</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Japanese</th>\n",
" <td>383</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Italian</th>\n",
" <td>357</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>German</th>\n",
" <td>330</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Czech</th>\n",
" <td>295</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Dutch</th>\n",
" <td>195</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>French</th>\n",
" <td>172</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Spanish</th>\n",
" <td>170</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Chinese</th>\n",
" <td>158</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Greek</th>\n",
" <td>153</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Irish</th>\n",
" <td>129</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Polish</th>\n",
" <td>93</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Arabic</th>\n",
" <td>73</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Korean</th>\n",
" <td>31</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Vietnamese</th>\n",
" <td>25</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" x y\n",
"y \n",
"Russian 486 1\n",
"English 459 1\n",
"Japanese 383 1\n",
"Italian 357 1\n",
"German 330 1\n",
"Czech 295 1\n",
"Dutch 195 1\n",
"French 172 1\n",
"Spanish 170 1\n",
"Chinese 158 1\n",
"Greek 153 1\n",
"Irish 129 1\n",
"Polish 93 1\n",
"Arabic 73 1\n",
"Korean 31 1\n",
"Vietnamese 25 1"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(pd.DataFrame({'x': [_.text for _ in data.train_ds.x], 'y': [_.obj for _ in data.train_ds.y]})\n",
" .groupby('y')\n",
" .nunique()\n",
" .sort_values('x', ascending=False))"
]
},
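{
"cell_type": "markdown",
"metadata": {},
"source": [
"The counts above (500 rows per class, but only 25 distinct Vietnamese names) suggest that the balanced index was drawn with replacement. `bal_idx` itself is built earlier in the notebook; the cell below is only a sketch of one way such an index could be produced. `balanced_index` and the `toy` series are hypothetical names used for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"def balanced_index(labels, n_per_class, seed=0):\n",
"    # labels: pandas Series of class labels; returns row positions\n",
"    rng = np.random.RandomState(seed)\n",
"    picks = []\n",
"    for _, pos in pd.Series(np.arange(len(labels))).groupby(labels.values):\n",
"        # sample with replacement so rare classes can reach n_per_class\n",
"        picks.append(rng.choice(pos.values, size=n_per_class, replace=True))\n",
"    return np.concatenate(picks)\n",
"\n",
"toy = pd.Series(['English'] * 50 + ['Vietnamese'] * 3)\n",
"idx = balanced_index(toy, n_per_class=10)\n",
"toy.iloc[idx].value_counts().to_dict()"
]
},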
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"valid_set = set(_.text for _ in data.valid_ds.x)\n",
"for _ in data.train_ds.x:\n",
" assert _.text not in valid_set, _.text"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fitting"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, MyLetterRNN(n_layers=2), loss_func=F.cross_entropy, metrics=[accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()\n",
"learn.recorder.plot()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that our balanced dataset is about half the size of our training dataset. This is useful to keep in mind when comparing the number of epochs and the runtime."
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.048625"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_idx) / len(bal_idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We only reach about 51% accuracy on the balanced validation set (similar to the Naive Bayes model)."
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:12 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.773524</th>\n",
" <th>2.755105</th>\n",
" <th>0.256250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.756307</th>\n",
" <th>2.651738</th>\n",
" <th>0.252083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.653947</th>\n",
" <th>2.520690</th>\n",
" <th>0.222917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.586744</th>\n",
" <th>2.220565</th>\n",
" <th>0.277083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>2.418116</th>\n",
" <th>2.046903</th>\n",
" <th>0.358333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>2.238451</th>\n",
" <th>2.456228</th>\n",
" <th>0.300000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>2.139630</th>\n",
" <th>1.903009</th>\n",
" <th>0.404167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>1.966441</th>\n",
" <th>1.928596</th>\n",
" <th>0.441667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.810611</th>\n",
" <th>1.840151</th>\n",
" <th>0.439583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.651867</th>\n",
" <th>1.802120</th>\n",
" <th>0.445833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>1.560050</th>\n",
" <th>1.871820</th>\n",
" <th>0.450000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>1.477005</th>\n",
" <th>1.901750</th>\n",
" <th>0.479167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>1.461073</th>\n",
" <th>2.092149</th>\n",
" <th>0.406250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>1.403493</th>\n",
" <th>2.019508</th>\n",
" <th>0.462500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>1.279949</th>\n",
" <th>1.945947</th>\n",
" <th>0.470833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <th>1.140172</th>\n",
" <th>2.048897</th>\n",
" <th>0.504167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <th>0.996752</th>\n",
" <th>2.194067</th>\n",
" <th>0.472917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <th>0.863288</th>\n",
" <th>2.319874</th>\n",
" <th>0.500000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <th>0.765463</th>\n",
" <th>2.229700</th>\n",
" <th>0.491667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <th>0.672631</th>\n",
" <th>2.348388</th>\n",
" <th>0.516667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <th>0.577968</th>\n",
" <th>2.426682</th>\n",
" <th>0.506250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <th>0.488059</th>\n",
" <th>2.629519</th>\n",
" <th>0.508333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <th>0.406027</th>\n",
" <th>2.711037</th>\n",
" <th>0.520833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <th>0.335120</th>\n",
" <th>2.838989</th>\n",
" <th>0.512500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <th>0.274594</th>\n",
" <th>2.929183</th>\n",
" <th>0.510417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <th>0.225410</th>\n",
" <th>3.002161</th>\n",
" <th>0.506250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <th>0.185401</th>\n",
" <th>3.071880</th>\n",
" <th>0.506250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <th>0.154014</th>\n",
" <th>3.097990</th>\n",
" <th>0.506250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <th>0.130493</th>\n",
" <th>3.111022</th>\n",
" <th>0.510417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <th>0.113012</th>\n",
" <th>3.112732</th>\n",
" <th>0.510417</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(30, max_lr=3e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's starting to overfit: the training loss keeps falling while the validation loss climbs from about 1.8 to 3.1, so it could perhaps do with some regularization."
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_losses()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This model is a little worse in accuracy than the Naive Bayes Bigram model.\n",
"\n",
"But our Neural Network is much more computationally intense and has about 4 times as many parameters!"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1056"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
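],
"source": [
"# Caution: len(p) is only the size of each tensor's first dimension, so this undercounts.\n",
"# sum(p.numel() for p in learn.model.parameters()) would count every element.\n",
"sum(len(_) for _ in learn.model.parameters())"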
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regularisation: Dropout\n",
"\n",
"Adding 50% dropout raises our accuracy to about 55%, a little above what we got with Naive Bayes."
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, MyLetterRNN(n_layers=2, dropout=0.5), loss_func=F.cross_entropy, metrics=[accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:12 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.770683</th>\n",
" <th>2.746132</th>\n",
" <th>0.206250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.737838</th>\n",
" <th>2.597004</th>\n",
" <th>0.302083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.625530</th>\n",
" <th>2.253647</th>\n",
" <th>0.272917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.453825</th>\n",
" <th>2.077854</th>\n",
" <th>0.350000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>2.315888</th>\n",
" <th>1.959774</th>\n",
" <th>0.372917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>2.161286</th>\n",
" <th>3.241577</th>\n",
" <th>0.212500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>2.172942</th>\n",
" <th>1.954330</th>\n",
" <th>0.412500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>2.082301</th>\n",
" <th>1.910599</th>\n",
" <th>0.429167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.960596</th>\n",
" <th>1.690655</th>\n",
" <th>0.493750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.844915</th>\n",
" <th>1.704401</th>\n",
" <th>0.495833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>1.749084</th>\n",
" <th>1.702470</th>\n",
" <th>0.462500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>1.649408</th>\n",
" <th>1.623091</th>\n",
" <th>0.491667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>1.574916</th>\n",
" <th>1.608889</th>\n",
" <th>0.506250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>1.477580</th>\n",
" <th>1.634902</th>\n",
" <th>0.529167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>1.378102</th>\n",
" <th>1.681566</th>\n",
" <th>0.497917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <th>1.285141</th>\n",
" <th>1.700663</th>\n",
" <th>0.506250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <th>1.198917</th>\n",
" <th>1.739364</th>\n",
" <th>0.510417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <th>1.111936</th>\n",
" <th>1.807743</th>\n",
" <th>0.520833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <th>1.044053</th>\n",
" <th>1.796086</th>\n",
" <th>0.495833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <th>0.981998</th>\n",
" <th>1.776384</th>\n",
" <th>0.522917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <th>0.910434</th>\n",
" <th>1.867425</th>\n",
" <th>0.522917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <th>0.834056</th>\n",
" <th>1.867015</th>\n",
" <th>0.522917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <th>0.770921</th>\n",
" <th>1.860593</th>\n",
" <th>0.537500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <th>0.708352</th>\n",
" <th>1.918845</th>\n",
" <th>0.543750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <th>0.653929</th>\n",
" <th>1.950722</th>\n",
" <th>0.522917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <th>0.603225</th>\n",
" <th>1.974899</th>\n",
" <th>0.531250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <th>0.564611</th>\n",
" <th>1.989453</th>\n",
" <th>0.535417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <th>0.534487</th>\n",
" <th>2.007739</th>\n",
" <th>0.539583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <th>0.512126</th>\n",
" <th>2.006823</th>\n",
" <th>0.545833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <th>0.493982</th>\n",
" <th>2.006345</th>\n",
" <th>0.547917</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(30, max_lr=3e-2)"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_losses()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Changing the dimension of hidden layers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using our default hidden size of 128 gets about 54%."
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:06 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.769207</th>\n",
" <th>2.736875</th>\n",
" <th>0.191667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.700270</th>\n",
" <th>2.432153</th>\n",
" <th>0.291667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.527224</th>\n",
" <th>2.165822</th>\n",
" <th>0.352083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.411032</th>\n",
" <th>2.005153</th>\n",
" <th>0.375000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>2.233243</th>\n",
" <th>1.797880</th>\n",
" <th>0.427083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>2.073015</th>\n",
" <th>1.878045</th>\n",
" <th>0.414583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>1.945470</th>\n",
" <th>1.800441</th>\n",
" <th>0.425000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>1.808628</th>\n",
" <th>1.668809</th>\n",
" <th>0.466667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.670717</th>\n",
" <th>1.669584</th>\n",
" <th>0.456250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.536347</th>\n",
" <th>1.619124</th>\n",
" <th>0.514583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>1.387313</th>\n",
" <th>1.561667</th>\n",
" <th>0.512500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>1.237852</th>\n",
" <th>1.506838</th>\n",
" <th>0.543750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>1.109924</th>\n",
" <th>1.549464</th>\n",
" <th>0.541667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>1.008990</th>\n",
" <th>1.554212</th>\n",
" <th>0.543750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>0.929697</th>\n",
" <th>1.554622</th>\n",
" <th>0.543750</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(data, MyLetterRNN(n_layers=2, dropout=0.5, n_hidden=128), loss_func=F.cross_entropy, metrics=[accuracy])\n",
"learn.fit_one_cycle(15, max_lr=3e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Doubling the hidden size to 256 doesn't noticeably change performance."
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:06 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.778366</th>\n",
" <th>2.746171</th>\n",
" <th>0.233333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.692814</th>\n",
" <th>2.430204</th>\n",
" <th>0.266667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.557370</th>\n",
" <th>2.121804</th>\n",
" <th>0.335417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.400152</th>\n",
" <th>1.958430</th>\n",
" <th>0.383333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>2.266069</th>\n",
" <th>1.900910</th>\n",
" <th>0.377083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>2.139609</th>\n",
" <th>1.769911</th>\n",
" <th>0.437500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>1.997191</th>\n",
" <th>1.776130</th>\n",
" <th>0.443750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>1.849745</th>\n",
" <th>1.595855</th>\n",
" <th>0.481250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.695713</th>\n",
" <th>1.665297</th>\n",
" <th>0.462500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.545231</th>\n",
" <th>1.611652</th>\n",
" <th>0.491667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>1.396923</th>\n",
" <th>1.523717</th>\n",
" <th>0.518750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>1.269320</th>\n",
" <th>1.593659</th>\n",
" <th>0.531250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>1.159536</th>\n",
" <th>1.635850</th>\n",
" <th>0.518750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>1.062906</th>\n",
" <th>1.653824</th>\n",
" <th>0.529167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>0.984885</th>\n",
" <th>1.646750</th>\n",
" <th>0.531250</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(data, MyLetterRNN(n_layers=2, dropout=0.5, n_hidden=256), loss_func=F.cross_entropy, metrics=[accuracy])\n",
"learn.fit_one_cycle(15, max_lr=3e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Halving to 64 definitely hurts, dropping accuracy to about 50%; 128 does seem to be a sweet spot."
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:06 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.788145</th>\n",
" <th>2.761199</th>\n",
" <th>0.089583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.748204</th>\n",
" <th>2.577398</th>\n",
" <th>0.239583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.619068</th>\n",
" <th>2.174743</th>\n",
" <th>0.316667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.437004</th>\n",
" <th>2.133431</th>\n",
" <th>0.362500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>2.297863</th>\n",
" <th>1.931483</th>\n",
" <th>0.372917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>2.146514</th>\n",
" <th>1.836011</th>\n",
" <th>0.410417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>2.041750</th>\n",
" <th>1.737024</th>\n",
" <th>0.437500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>1.896687</th>\n",
" <th>1.610484</th>\n",
" <th>0.477083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.764543</th>\n",
" <th>1.659021</th>\n",
" <th>0.462500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.635953</th>\n",
" <th>1.572796</th>\n",
" <th>0.493750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>1.525941</th>\n",
" <th>1.614364</th>\n",
" <th>0.489583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>1.413064</th>\n",
" <th>1.567258</th>\n",
" <th>0.497917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>1.319074</th>\n",
" <th>1.581343</th>\n",
" <th>0.483333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>1.237835</th>\n",
" <th>1.610361</th>\n",
" <th>0.502083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>1.175101</th>\n",
" <th>1.607811</th>\n",
" <th>0.502083</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(data, MyLetterRNN(n_layers=2, dropout=0.5, n_hidden=64), loss_func=F.cross_entropy, metrics=[accuracy])\n",
"learn.fit_one_cycle(15, max_lr=3e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, 3 layers also gets a worse result."
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:14 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.777945</th>\n",
" <th>2.767582</th>\n",
" <th>0.108333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.763745</th>\n",
" <th>2.700912</th>\n",
" <th>0.177083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.724584</th>\n",
" <th>2.610098</th>\n",
" <th>0.170833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.610290</th>\n",
" <th>2.289089</th>\n",
" <th>0.293750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>2.476536</th>\n",
" <th>2.088696</th>\n",
" <th>0.347917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>2.339663</th>\n",
" <th>1.914408</th>\n",
" <th>0.404167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>2.205508</th>\n",
" <th>1.885391</th>\n",
" <th>0.425000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>2.120229</th>\n",
" <th>1.957143</th>\n",
" <th>0.341667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>2.053982</th>\n",
" <th>1.699707</th>\n",
" <th>0.397917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.987532</th>\n",
" <th>1.688885</th>\n",
" <th>0.441667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>1.908170</th>\n",
" <th>1.695390</th>\n",
" <th>0.462500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>1.848755</th>\n",
" <th>1.654914</th>\n",
" <th>0.456250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>1.779561</th>\n",
" <th>1.624753</th>\n",
" <th>0.460417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>1.697022</th>\n",
" <th>1.597392</th>\n",
" <th>0.475000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>1.646311</th>\n",
" <th>1.599763</th>\n",
" <th>0.475000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <th>1.582875</th>\n",
" <th>1.559585</th>\n",
" <th>0.477083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <th>1.531765</th>\n",
" <th>1.559109</th>\n",
" <th>0.475000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <th>1.491278</th>\n",
" <th>1.601305</th>\n",
" <th>0.462500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <th>1.446864</th>\n",
" <th>1.503486</th>\n",
" <th>0.483333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <th>1.387920</th>\n",
" <th>1.531969</th>\n",
" <th>0.508333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <th>1.325619</th>\n",
" <th>1.495371</th>\n",
" <th>0.520833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <th>1.257018</th>\n",
" <th>1.581387</th>\n",
" <th>0.531250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <th>1.193253</th>\n",
" <th>1.517283</th>\n",
" <th>0.537500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <th>1.138225</th>\n",
" <th>1.559087</th>\n",
" <th>0.529167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <th>1.086946</th>\n",
" <th>1.572238</th>\n",
" <th>0.539583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <th>1.040665</th>\n",
" <th>1.561826</th>\n",
" <th>0.525000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <th>1.001137</th>\n",
" <th>1.584307</th>\n",
" <th>0.520833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <th>0.970200</th>\n",
" <th>1.590697</th>\n",
" <th>0.516667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <th>0.947225</th>\n",
" <th>1.590535</th>\n",
" <th>0.516667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <th>0.927288</th>\n",
" <th>1.590184</th>\n",
" <th>0.516667</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(data, MyLetterRNN(n_layers=3, dropout=0.5), loss_func=F.cross_entropy, metrics=[accuracy])\n",
"learn.fit_one_cycle(30, max_lr=3e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RNN From Scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build our own RNN; instead of one-hot encoding we'll use an `nn.Embedding`."
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
"data = (TextList\n",
" .from_df(df, \n",
" cols=[2], \n",
" processor=processors)\n",
" .split_by_idxs(train_idx=bal_idx, valid_idx=valid_idx)\n",
" .label_from_df(cols=0)\n",
" .databunch(bs=1024))"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
"valid_data_set = set(tuple(_[0].data) for _ in data.valid_ds)\n",
"for datum in data.train_ds:\n",
" assert tuple(datum[0].data) not in valid_data_set, datum"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([19, 512]), torch.Size([512]))"
]
},
"execution_count": 133,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x, y = next(iter(data.train_dl))\n",
"x.shape, y.shape"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"512"
]
},
"execution_count": 134,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.shape[-1]"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
"    def __init__(self, n_input, n_hidden, n_output, bn=False):\n",
"        super().__init__()\n",
"        self.n_hidden = n_hidden  # stored so reset() does not depend on a global\n",
"        self.i_h = nn.Embedding(n_input, n_hidden)  # character index -> hidden\n",
"        self.bn = nn.BatchNorm1d(n_hidden) if bn else None\n",
"        self.o_h = nn.Linear(n_hidden, n_output)  # hidden -> class scores\n",
"        self.h_h = nn.Linear(n_hidden, n_hidden)  # hidden -> hidden\n",
"        self.reset()\n",
"\n",
"    def forward(self, x):\n",
"        # x has shape (sequence length, batch size).\n",
"        # I'm not quite sure why the batch size seems to change to 720 in validation...\n",
"        if self.h.shape[0] != x.shape[1]:\n",
"            self.reset(x.shape[1])\n",
"        h = self.h\n",
"        x = self.i_h(x)\n",
"        # Step through the sequence one character at a time.\n",
"        for xi in x:\n",
"            h += xi\n",
"            h = self.h_h(h)\n",
"            h = F.relu(h)\n",
"            if self.bn is not None:\n",
"                h = self.bn(h)\n",
"        # Keep the final hidden state, detached so gradients stop here.\n",
"        self.h = h.detach()\n",
"        o = self.o_h(h)\n",
"        return o\n",
"\n",
"    def reset(self, size=None):\n",
"        size = size or 1\n",
"        self.h = torch.zeros(size, self.n_hidden).cuda()"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [],
"source": [
"model = Model(n_letters, n_hidden, n_output).cuda()"
]
},
{
"cell_type": "code",
"execution_count": 137,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, model, loss_func=F.cross_entropy, metrics=[accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 138,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This simple RNN seems to work *better* than the one we built using `nn.RNN`, and we're only using one layer and haven't implemented dropout.\n",
"\n",
"The big difference is that we're using an embedding layer instead of one-hot encoding. This gives us an extra bunch of parameters we can fit."
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:08 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.743902</th>\n",
" <th>2.623438</th>\n",
" <th>0.239583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.569348</th>\n",
" <th>2.255966</th>\n",
" <th>0.333333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.323031</th>\n",
" <th>1.929247</th>\n",
" <th>0.397917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.109766</th>\n",
" <th>1.880335</th>\n",
" <th>0.406250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>1.942279</th>\n",
" <th>1.689045</th>\n",
" <th>0.460417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>1.745093</th>\n",
" <th>1.724750</th>\n",
" <th>0.452083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>1.561630</th>\n",
" <th>1.589873</th>\n",
" <th>0.508333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>1.374822</th>\n",
" <th>1.634860</th>\n",
" <th>0.525000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.204314</th>\n",
" <th>1.628457</th>\n",
" <th>0.518750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.037950</th>\n",
" <th>1.674552</th>\n",
" <th>0.552083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>0.891771</th>\n",
" <th>1.778297</th>\n",
" <th>0.552083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>0.772682</th>\n",
" <th>1.856418</th>\n",
" <th>0.541667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>0.666088</th>\n",
" <th>1.887347</th>\n",
" <th>0.581250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>0.572181</th>\n",
" <th>1.933961</th>\n",
" <th>0.545833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>0.487959</th>\n",
" <th>2.034596</th>\n",
" <th>0.568750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <th>0.415125</th>\n",
" <th>2.025807</th>\n",
" <th>0.556250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <th>0.356673</th>\n",
" <th>2.072717</th>\n",
" <th>0.558333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <th>0.310153</th>\n",
" <th>2.125560</th>\n",
" <th>0.560417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <th>0.271913</th>\n",
" <th>2.125447</th>\n",
" <th>0.560417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <th>0.244666</th>\n",
" <th>2.124661</th>\n",
" <th>0.558333</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, 7e-3)"
]
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {},
"outputs": [],
"source": [
"learn.save('rnn-bal-1')"
]
},
{
"cell_type": "code",
"execution_count": 142,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Arabic',\n",
" 'Chinese',\n",
" 'Czech',\n",
" 'Dutch',\n",
" 'English',\n",
" 'French',\n",
" 'German',\n",
" 'Greek',\n",
" 'Irish',\n",
" 'Italian',\n",
" 'Japanese',\n",
" 'Korean',\n",
" 'Polish',\n",
" 'Russian',\n",
" 'Spanish',\n",
" 'Vietnamese']"
]
},
"execution_count": 142,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.classes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's save the data classes; this will be useful if we want to make predictions."
]
},
{
"cell_type": "code",
"execution_count": 143,
"metadata": {},
"outputs": [],
"source": [
"with open('data.classes', 'wb') as f:\n",
" pickle.dump(data.classes, f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's save the model data directly."
]
},
{
"cell_type": "code",
"execution_count": 144,
"metadata": {},
"outputs": [],
"source": [
"with open('models/rnn-bal-1.model', 'wb') as f:\n",
" pickle.dump(model.state_dict(), f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And read it back in."
]
},
{
"cell_type": "code",
"execution_count": 145,
"metadata": {},
"outputs": [],
"source": [
"with open('models/rnn-bal-1.model', 'rb') as f:\n",
" state = pickle.load(f)\n",
" model.load_state_dict(state)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Batchnorm"
]
},
{
"cell_type": "code",
"execution_count": 146,
"metadata": {},
"outputs": [],
"source": [
"model = Model(n_letters, n_hidden, n_output, bn=True).cuda()"
]
},
{
"cell_type": "code",
"execution_count": 147,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, model, loss_func=F.cross_entropy, metrics=[accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 148,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Batch norm makes the loss curve in the learning rate finder plot much smoother."
]
},
{
"cell_type": "code",
"execution_count": 149,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot()"
]
},
{
"cell_type": "code",
"execution_count": 150,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:09 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.507374</th>\n",
" <th>2.190547</th>\n",
" <th>0.354167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.235302</th>\n",
" <th>1.959421</th>\n",
" <th>0.387500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.003024</th>\n",
" <th>1.930297</th>\n",
" <th>0.445833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>1.824707</th>\n",
" <th>1.977895</th>\n",
" <th>0.454167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>1.707330</th>\n",
" <th>2.065223</th>\n",
" <th>0.470833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>1.562633</th>\n",
" <th>2.085219</th>\n",
" <th>0.481250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>1.450853</th>\n",
" <th>2.300199</th>\n",
" <th>0.481250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>1.306937</th>\n",
" <th>2.536774</th>\n",
" <th>0.475000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.191327</th>\n",
" <th>2.522710</th>\n",
" <th>0.483333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.058434</th>\n",
" <th>2.229035</th>\n",
" <th>0.533333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>0.928646</th>\n",
" <th>2.419605</th>\n",
" <th>0.533333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>0.800899</th>\n",
" <th>2.562785</th>\n",
" <th>0.514583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>0.701665</th>\n",
" <th>2.489831</th>\n",
" <th>0.533333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>0.613110</th>\n",
" <th>2.690704</th>\n",
" <th>0.522917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>0.528683</th>\n",
" <th>2.681259</th>\n",
" <th>0.537500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <th>0.458904</th>\n",
" <th>2.672714</th>\n",
" <th>0.539583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <th>0.393891</th>\n",
" <th>2.499606</th>\n",
" <th>0.541667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <th>0.341363</th>\n",
" <th>2.683054</th>\n",
" <th>0.545833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <th>0.300728</th>\n",
" <th>2.611176</th>\n",
" <th>0.547917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <th>0.272150</th>\n",
" <th>2.567088</th>\n",
" <th>0.550000</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, 3e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case we actually get a very similar fit."
]
},
{
"cell_type": "code",
"execution_count": 151,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_losses()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding a little regularisation using weight decay seems to help; we get a 59%.\n",
"\n",
"Maybe dropout could help more."
]
},
{
"cell_type": "code",
"execution_count": 152,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:09 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.683858</th>\n",
" <th>2.500692</th>\n",
" <th>0.277083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.455149</th>\n",
" <th>2.113877</th>\n",
" <th>0.352083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.241700</th>\n",
" <th>1.903266</th>\n",
" <th>0.406250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.017066</th>\n",
" <th>1.779788</th>\n",
" <th>0.454167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>1.849892</th>\n",
" <th>1.776859</th>\n",
" <th>0.493750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>1.691735</th>\n",
" <th>1.798783</th>\n",
" <th>0.479167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>1.523797</th>\n",
" <th>1.714493</th>\n",
" <th>0.520833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>1.337385</th>\n",
" <th>1.671328</th>\n",
" <th>0.541667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.208459</th>\n",
" <th>1.819937</th>\n",
" <th>0.535417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.074747</th>\n",
" <th>1.773041</th>\n",
" <th>0.535417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>0.935993</th>\n",
" <th>1.755780</th>\n",
" <th>0.568750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>0.815807</th>\n",
" <th>1.749104</th>\n",
" <th>0.562500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>0.713215</th>\n",
" <th>1.812801</th>\n",
" <th>0.566667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>0.620390</th>\n",
" <th>1.839610</th>\n",
" <th>0.575000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>0.550001</th>\n",
" <th>1.766414</th>\n",
" <th>0.591667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <th>0.485581</th>\n",
" <th>1.801695</th>\n",
" <th>0.597917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <th>0.429470</th>\n",
" <th>1.961521</th>\n",
" <th>0.587500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <th>0.386420</th>\n",
" <th>1.903309</th>\n",
" <th>0.597917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <th>0.346661</th>\n",
" <th>1.830627</th>\n",
" <th>0.593750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <th>0.319924</th>\n",
" <th>1.856341</th>\n",
" <th>0.593750</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = Model(n_letters, n_hidden, n_output, bn=True).cuda()\n",
"learn = Learner(data, model, loss_func=F.cross_entropy, metrics=[accuracy])\n",
"learn.fit_one_cycle(20, 1e-2, wd=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## fastai Builtin\n",
"How does fastai's built in learner compare?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How long are the names?"
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 16869.000000\n",
"mean 7.409509\n",
"std 2.050366\n",
"min 2.000000\n",
"25% 6.000000\n",
"50% 7.000000\n",
"75% 9.000000\n",
"max 18.000000\n",
"Name: ascii_name, dtype: float64"
]
},
"execution_count": 153,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.ascii_name.str.len().describe()"
]
},
{
"cell_type": "code",
"execution_count": 154,
"metadata": {},
"outputs": [],
"source": [
"learn = text_classifier_learner(data, bptt=30)"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 159,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot(skip_end=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note we don't necessarily *expect* this to do great because the parameters are tuned to processing medium sized documents a word at a time.\n",
"\n",
"However it gets 67% way outperforms our RNN model without *any* parameter tuning."
]
},
{
"cell_type": "code",
"execution_count": 160,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 01:07 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.760123</th>\n",
" <th>2.778821</th>\n",
" <th>0.062500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.630543</th>\n",
" <th>2.745926</th>\n",
" <th>0.075000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.513042</th>\n",
" <th>2.573858</th>\n",
" <th>0.160417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.355199</th>\n",
" <th>2.121467</th>\n",
" <th>0.300000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>2.195587</th>\n",
" <th>1.810203</th>\n",
" <th>0.387500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>2.012244</th>\n",
" <th>1.566222</th>\n",
" <th>0.472917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>1.909551</th>\n",
" <th>1.693872</th>\n",
" <th>0.460417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>1.752974</th>\n",
" <th>1.615144</th>\n",
" <th>0.527083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.620254</th>\n",
" <th>1.322627</th>\n",
" <th>0.581250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.416614</th>\n",
" <th>1.251798</th>\n",
" <th>0.625000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>1.226631</th>\n",
" <th>1.297575</th>\n",
" <th>0.610417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>1.056211</th>\n",
" <th>1.230383</th>\n",
" <th>0.641667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>0.922565</th>\n",
" <th>1.201090</th>\n",
" <th>0.664583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>0.791524</th>\n",
" <th>1.235106</th>\n",
" <th>0.656250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>0.657713</th>\n",
" <th>1.220895</th>\n",
" <th>0.683333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <th>0.552160</th>\n",
" <th>1.252036</th>\n",
" <th>0.666667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <th>0.460102</th>\n",
" <th>1.207947</th>\n",
" <th>0.666667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <th>0.411827</th>\n",
" <th>1.196497</th>\n",
" <th>0.670833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <th>0.370949</th>\n",
" <th>1.197413</th>\n",
" <th>0.668750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <th>0.330069</th>\n",
" <th>1.188975</th>\n",
" <th>0.677083</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, max_lr=7e-3)"
]
},
{
"cell_type": "code",
"execution_count": 161,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_losses()"
]
},
{
"cell_type": "code",
"execution_count": 167,
"metadata": {},
"outputs": [],
"source": [
"learn.save('fastai_bal')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pretraining the Encoder\n",
"\n",
"From the IMDB example we know for word level data pretraining the encoder gives much better results (albeit on *much* bigger datasets). Let's see if it improves things here."
]
},
{
"cell_type": "code",
"execution_count": 168,
"metadata": {},
"outputs": [],
"source": [
"data_lm = (TextList\n",
" .from_df(df, cols=[2], processor=processors)\n",
" .random_split_by_pct(0.1)\n",
" .label_for_lm()\n",
" .databunch(bs=32))"
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table> <col width='5%'> <col width='95%'> <tr>\n",
" <th>idx</th>\n",
" <th>text</th>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <th> z h i h a r e v i t c h xxbos s m o l a k xxbos n o s c h e n k o xxbos c r o w n xxbos t o k a e v xxbos o r i o l xxbos d j a n i b e</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th> p i s k o t i n xxbos o ' c a l l a g h a n n xxbos e o g h a n xxbos e n o k i xxbos s h a n a u r i n xxbos c h k h a r t i s h v i l</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>j i m a xxbos t z a g o l o v xxbos l i c h m a n xxbos c o w l e y xxbos b a g d a s a r o f f xxbos w a t e r f i e l d xxbos n e l l i xxbos</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th> t a l m u d xxbos m a r t z e n k o xxbos r i p l e y xxbos z a v o r i n xxbos g e i g e r xxbos v r a z e l xxbos r e y e r xxbos r o</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>o v s a e v xxbos g r e a v e s xxbos r e k u n xxbos y u z v i s h i n xxbos t c h e k m a s o v xxbos s o n e xxbos g r u s h e t s k y xxbos</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data_lm.show_batch()"
]
},
{
"cell_type": "code",
"execution_count": 175,
"metadata": {},
"outputs": [],
"source": [
"learn = language_model_learner(data_lm, drop_mult=0.5)"
]
},
{
"cell_type": "code",
"execution_count": 171,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 173,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot(skip_end=10)"
]
},
{
"cell_type": "code",
"execution_count": 176,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:34 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.655349</th>\n",
" <th>2.249993</th>\n",
" <th>0.325342</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.187761</th>\n",
" <th>1.975275</th>\n",
" <th>0.401822</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.007302</th>\n",
" <th>1.886618</th>\n",
" <th>0.426303</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>1.868937</th>\n",
" <th>1.847373</th>\n",
" <th>0.435982</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(4, max_lr=1e-2)"
]
},
{
"cell_type": "code",
"execution_count": 177,
"metadata": {},
"outputs": [],
"source": [
"learn.save('letter_lang')\n",
"learn.save_encoder('letter_enc')"
]
},
{
"cell_type": "code",
"execution_count": 178,
"metadata": {},
"outputs": [],
"source": [
"TEXT = \"ho\"\n",
"N_WORDS = 4\n",
"N_SENTENCES = 5"
]
},
{
"cell_type": "code",
"execution_count": 179,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ho u s i m\n",
"ho n o b a\n",
"ho n e r e\n",
"ho n a r a\n",
"ho v a b e\n"
]
}
],
"source": [
"print(\"\\n\".join(learn.predict(TEXT, N_WORDS, temperature=0.75) for _ in range(N_SENTENCES)))"
]
},
{
"cell_type": "code",
"execution_count": 180,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tr i n o r\n",
"tr e r o e\n",
"tr e n e n\n",
"tr i r i s\n",
"tr u p u t\n"
]
}
],
"source": [
"TEXT = \"tr\"\n",
"print(\"\\n\".join(learn.predict(TEXT, N_WORDS, temperature=0.75) for _ in range(N_SENTENCES)))"
]
},
{
"cell_type": "code",
"execution_count": 181,
"metadata": {},
"outputs": [],
"source": [
"learn = text_classifier_learner(data, bptt=30)\n",
"learn.load_encoder('letter_enc')"
]
},
{
"cell_type": "code",
"execution_count": 182,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()\n",
"learn.recorder.plot()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case pretraining the encoder gives a *worse* result.\n",
"\n",
"Maybe it's because the language model was on the entire (unbalanced) dataset? Or wasn't well trained enough?"
]
},
{
"cell_type": "code",
"execution_count": 183,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:29 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.623937</th>\n",
" <th>2.702015</th>\n",
" <th>0.302083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.410758</th>\n",
" <th>2.412743</th>\n",
" <th>0.337500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.200685</th>\n",
" <th>1.927255</th>\n",
" <th>0.393750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.046425</th>\n",
" <th>2.312107</th>\n",
" <th>0.316667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>1.865316</th>\n",
" <th>2.074282</th>\n",
" <th>0.393750</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>1.777486</th>\n",
" <th>2.281461</th>\n",
" <th>0.360417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>1.689011</th>\n",
" <th>2.557259</th>\n",
" <th>0.302083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>1.596411</th>\n",
" <th>2.346404</th>\n",
" <th>0.370833</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.566185</th>\n",
" <th>2.441514</th>\n",
" <th>0.341667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.473553</th>\n",
" <th>1.770901</th>\n",
" <th>0.429167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>1.411398</th>\n",
" <th>1.677306</th>\n",
" <th>0.502083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>1.352370</th>\n",
" <th>1.966482</th>\n",
" <th>0.404167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>1.323129</th>\n",
" <th>3.021722</th>\n",
" <th>0.310417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>1.271584</th>\n",
" <th>2.182109</th>\n",
" <th>0.389583</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>1.224382</th>\n",
" <th>1.864778</th>\n",
" <th>0.450000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <th>1.196533</th>\n",
" <th>1.896098</th>\n",
" <th>0.447917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <th>1.156429</th>\n",
" <th>1.960691</th>\n",
" <th>0.435417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <th>1.108229</th>\n",
" <th>1.840390</th>\n",
" <th>0.456250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <th>1.136375</th>\n",
" <th>2.003758</th>\n",
" <th>0.429167</th>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <th>1.125715</th>\n",
" <th>1.961390</th>\n",
" <th>0.443750</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, max_lr=2e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## fastai: Hyperparameter Tuning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With a bit of tuning we can make a much smaller model that trains faster and is almost as good"
]
},
{
"cell_type": "code",
"execution_count": 187,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:10 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>2.748441</th>\n",
" <th>2.779718</th>\n",
" <th>0.062500</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <th>2.560826</th>\n",
" <th>2.709974</th>\n",
" <th>0.122917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <th>2.413826</th>\n",
" <th>2.418473</th>\n",
" <th>0.325000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>2.248409</th>\n",
" <th>1.827642</th>\n",
" <th>0.397917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <th>2.131097</th>\n",
" <th>1.928447</th>\n",
" <th>0.385417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <th>2.005232</th>\n",
" <th>1.580826</th>\n",
" <th>0.510417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <th>1.831701</th>\n",
" <th>1.488690</th>\n",
" <th>0.510417</th>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <th>1.711539</th>\n",
" <th>1.291139</th>\n",
" <th>0.600000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <th>1.528764</th>\n",
" <th>1.403310</th>\n",
" <th>0.572917</th>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <th>1.353717</th>\n",
" <th>1.199514</th>\n",
" <th>0.627083</th>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <th>1.200141</th>\n",
" <th>1.201450</th>\n",
" <th>0.658333</th>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <th>1.080066</th>\n",
" <th>1.182234</th>\n",
" <th>0.641667</th>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <th>0.959931</th>\n",
" <th>1.155729</th>\n",
" <th>0.650000</th>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <th>0.873939</th>\n",
" <th>1.152237</th>\n",
" <th>0.656250</th>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <th>0.815623</th>\n",
" <th>1.167501</th>\n",
" <th>0.654167</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = text_classifier_learner(data, bptt=30, emb_sz=200, nh=300, nl=2)\n",
"learn.fit_one_cycle(15, max_lr=1e-2, moms=(0.2, 0.1))"
]
},
{
"cell_type": "code",
"execution_count": 188,
"metadata": {},
"outputs": [],
"source": [
"learn.save('fastai_min')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Analysing the results"
]
},
{
"cell_type": "code",
"execution_count": 189,
"metadata": {},
"outputs": [],
"source": [
"learn.load('fastai_min')\n",
"None"
]
},
{
"cell_type": "code",
"execution_count": 190,
"metadata": {},
"outputs": [],
"source": [
"prob, target, losses = learn.get_preds(with_loss=True)\n",
"pred = np.array([data.classes[_] for _ in prob.argmax(dim=1)])\n",
"target = np.array([data.classes[_] for _ in target])"
]
},
{
"cell_type": "code",
"execution_count": 191,
"metadata": {},
"outputs": [],
"source": [
"x, y = list(learn.data.valid_dl)[0]\n",
"y = np.array([data.classes[_] for _ in y])"
]
},
{
"cell_type": "code",
"execution_count": 192,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(480, 480)"
]
},
"execution_count": 192,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(y), len(prob)"
]
},
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [],
"source": [
"names = np.array([''.join([vocab.itos[x] for x in l if x != 1][1:]) for l in zip(*x)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I certainly think we *could* do better, but let's call it good enough."
]
},
{
"cell_type": "code",
"execution_count": 196,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"[('simonis', 'Greek', 'Dutch', tensor(6.7631)),\n",
" ('dam', 'Korean', 'Vietnamese', tensor(6.4027)),\n",
" ('jelen', 'Dutch', 'Polish', tensor(6.1876)),\n",
" ('cha', 'Vietnamese', 'Korean', tensor(6.1856)),\n",
" ('hayden', 'Dutch', 'Irish', tensor(6.1785)),\n",
" ('chmiel', 'French', 'Polish', tensor(5.8600)),\n",
" ('blanxart', 'English', 'Spanish', tensor(5.8351)),\n",
" ('chicken', 'Dutch', 'Czech', tensor(5.6187)),\n",
" ('attia', 'Spanish', 'Arabic', tensor(5.4960)),\n",
" ('ton', 'Korean', 'Vietnamese', tensor(5.4933))]"
]
},
"execution_count": 196,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss_val, idx = losses.topk(10)\n",
"list(zip(names[idx], pred[idx], target[idx], loss_val))"
]
},
{
"cell_type": "code",
"execution_count": 203,
"metadata": {},
"outputs": [],
"source": [
"confuse = sklearn.metrics.confusion_matrix(target, pred, labels=data.classes)"
]
},
{
"cell_type": "code",
"execution_count": 204,
"metadata": {},
"outputs": [],
"source": [
"def most_confused(n):\n",
" top = []\n",
" for i, row in enumerate(confuse):\n",
" for j, cell in enumerate(row):\n",
" if i == j:\n",
" continue\n",
" if cell >= n:\n",
" top.append([data.classes[i],data.classes[j], cell])\n",
" return sorted(top, key=lambda x: x[2], reverse=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Most of the confusion is between similar language families:\n",
"- Vietnamese and Korean and Chinese\n",
"- Czech and Polish\n",
"- Spanish and Italian\n",
"\n",
"This is a good sign"
]
},
{
"cell_type": "code",
"execution_count": 205,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[['Vietnamese', 'Chinese', 13],\n",
" ['Korean', 'Chinese', 11],\n",
" ['Spanish', 'Italian', 8],\n",
" ['Polish', 'Czech', 6],\n",
" ['Chinese', 'Korean', 4],\n",
" ['Czech', 'German', 4],\n",
" ['English', 'French', 4],\n",
" ['English', 'German', 4],\n",
" ['German', 'English', 4],\n",
" ['Vietnamese', 'Korean', 4],\n",
" ['Arabic', 'German', 3],\n",
" ['Czech', 'Spanish', 3],\n",
" ['Dutch', 'English', 3],\n",
" ['English', 'Dutch', 3],\n",
" ['German', 'Dutch', 3],\n",
" ['Irish', 'Chinese', 3],\n",
" ['Korean', 'Vietnamese', 3],\n",
" ['Polish', 'German', 3],\n",
" ['Spanish', 'French', 3],\n",
" ['Vietnamese', 'Irish', 3]]"
]
},
"execution_count": 205,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"most_confused(3)"
]
},
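{
"cell_type": "markdown",
"metadata": {},
"source": [
"The off-diagonal counts show which pairs get mixed up; the diagonal shows how often each class is recovered. Here is a small sketch of per-class recall (diagonal divided by row total). It uses a made-up 3x3 matrix and hypothetical class names rather than `confuse` and `data.classes`, and it assumes rows are true labels and columns are predictions, as in `sklearn.metrics.confusion_matrix`. With the notebook's objects it would be called as `per_class_recall(confuse, data.classes)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def per_class_recall(matrix, class_names):\n",
"    'Return (class, recall) pairs sorted from hardest to easiest class.'\n",
"    matrix = np.asarray(matrix, dtype=float)\n",
"    row_totals = matrix.sum(axis=1)\n",
"    # Avoid dividing by zero for classes with no examples.\n",
"    recall = np.divide(np.diag(matrix), row_totals, out=np.zeros_like(row_totals), where=row_totals > 0)\n",
"    return sorted(zip(class_names, recall.tolist()), key=lambda pair: pair[1])\n",
"\n",
"# Toy data for illustration only.\n",
"toy_confuse = [[8, 2, 0], [1, 9, 0], [3, 2, 5]]\n",
"toy_classes = ['A', 'B', 'C']\n",
"per_class_recall(toy_confuse, toy_classes)"
]
},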
{
"cell_type": "code",
"execution_count": 210,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Confusion matrix, without normalization\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x432 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(6,6))\n",
"plot_confusion_matrix(confuse, data.classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predictions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's set up everything from scratch so we could set it up in an external app"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from fastai import *\n",
"from fastai.text import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from unidecode import unidecode\n",
"import string"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Arabic',\n",
" 'Chinese',\n",
" 'Czech',\n",
" 'Dutch',\n",
" 'English',\n",
" 'French',\n",
" 'German',\n",
" 'Greek',\n",
" 'Irish',\n",
" 'Italian',\n",
" 'Japanese',\n",
" 'Korean',\n",
" 'Polish',\n",
" 'Russian',\n",
" 'Spanish',\n",
" 'Vietnamese']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with open('data.classes', 'rb') as f:\n",
" classes = pickle.load(f)\n",
"classes"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class LetterTokenizer(BaseTokenizer):\n",
" \"Character level tokenizer function.\"\n",
" def __init__(self, lang): pass\n",
" def tokenizer(self, t:str) -> List[str]:\n",
" t = unidecode(t).lower() ## Decode in tokenizer (ideally would be a separate preprocessor)\n",
" out = []\n",
" i = 0\n",
" while i < len(t):\n",
" if t[i:].startswith(BOS):\n",
" out.append(BOS)\n",
" i += len(BOS)\n",
" else:\n",
" out.append(t[i])\n",
" i += 1\n",
" return out\n",
" \n",
" def add_special_cases(self, toks:Collection[str]): pass"
]
},
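{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the tokenizer produces without loading fastai, here is a standalone sketch of the same splitting logic. `BOS_TOKEN = 'xxbos'` is an assumption matching the marker visible in the language-model batches earlier, `char_tokens` is a hypothetical helper name, and the `unidecode` transliteration step is left out."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"BOS_TOKEN = 'xxbos'  # assumed value of fastai's beginning-of-stream marker\n",
"\n",
"def char_tokens(text, special=BOS_TOKEN):\n",
"    'Split text into single characters, keeping the special marker whole.'\n",
"    text = text.lower()\n",
"    out = []\n",
"    i = 0\n",
"    while i < len(text):\n",
"        if text.startswith(special, i):\n",
"            # Emit the marker as one token and skip past it.\n",
"            out.append(special)\n",
"            i += len(special)\n",
"        else:\n",
"            out.append(text[i])\n",
"            i += 1\n",
"    return out\n",
"\n",
"char_tokens('xxbosWu')"
]
},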
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"itos = [UNK, BOS] + list(string.ascii_lowercase + \" -'\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"vocab=Vocab(itos)\n",
"tokenizer=Tokenizer(LetterTokenizer, pre_rules=[], post_rules=[])"
]
},
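{
"cell_type": "markdown",
"metadata": {},
"source": [
"The vocabulary is just a list of tokens whose positions become the model's input indices. Here is a sketch of that mapping in plain Python, assuming the marker strings `'xxunk'` and `'xxbos'`; `itos_demo`, `stoi_demo` and `numericalize` are hypothetical names, not fastai's API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import string\n",
"\n",
"UNK_TOKEN, BOS_TOKEN = 'xxunk', 'xxbos'  # assumed marker strings\n",
"itos_demo = [UNK_TOKEN, BOS_TOKEN] + list(string.ascii_lowercase + \" -'\")\n",
"stoi_demo = {tok: idx for idx, tok in enumerate(itos_demo)}\n",
"\n",
"def numericalize(tokens):\n",
"    'Map tokens to indices, sending anything unseen to the unknown token.'\n",
"    return [stoi_demo.get(tok, stoi_demo[UNK_TOKEN]) for tok in tokens]\n",
"\n",
"numericalize(['xxbos', 'w', 'u', '!'])"
]
},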
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>cl</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td></td>\n",
" <td>Arabic</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td></td>\n",
" <td>Chinese</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td></td>\n",
" <td>Czech</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td></td>\n",
" <td>Dutch</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td></td>\n",
" <td>English</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td></td>\n",
" <td>French</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td></td>\n",
" <td>German</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td></td>\n",
" <td>Greek</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td></td>\n",
" <td>Irish</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td></td>\n",
" <td>Italian</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td></td>\n",
" <td>Japanese</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td></td>\n",
" <td>Korean</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td></td>\n",
" <td>Polish</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td></td>\n",
" <td>Russian</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td></td>\n",
" <td>Spanish</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td></td>\n",
" <td>Vietnamese</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" text cl\n",
"0 Arabic\n",
"1 Chinese\n",
"2 Czech\n",
"3 Dutch\n",
"4 English\n",
"5 French\n",
"6 German\n",
"7 Greek\n",
"8 Irish\n",
"9 Italian\n",
"10 Japanese\n",
"11 Korean\n",
"12 Polish\n",
"13 Russian\n",
"14 Spanish\n",
"15 Vietnamese"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"empty = pd.DataFrame({'text':'', 'cl':classes})\n",
"empty"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"processors = [TokenizeProcessor(tokenizer=tokenizer, mark_fields=False),\n",
" NumericalizeProcessor(vocab=vocab)]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"data = TextList.from_df(empty, processor=processors).no_split().label_from_df(cols='cl').databunch(bs=2)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"learn = text_classifier_learner(data, bptt=30, emb_sz=200, nh=300, nl=2)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"learn.load('fastai_min')\n",
"None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check it's not in the training set"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"!grep -ir '^Wu' data/names"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Category Chinese,\n",
" tensor(1),\n",
" tensor([1.5628e-04, 8.2566e-01, 1.5067e-04, 7.8611e-04, 5.4162e-03, 6.0296e-06,\n",
" 1.6915e-03, 7.5540e-05, 1.2069e-03, 2.6427e-05, 2.3748e-03, 1.0973e-01,\n",
" 4.8861e-02, 1.4435e-04, 4.1366e-05, 3.6734e-03]))"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.predict('Wu') # Chinese"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def predictions(name):\n",
" return sorted(zip(classes, (_.item() for _ in learn.predict(name)[2])), key=lambda x: x[1], reverse=True)"
]
},
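{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same ranking can be written against a plain list of probabilities, which is convenient when testing an external app without a loaded learner. `top_k` is a hypothetical helper and the probabilities below are made up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def top_k(class_names, probs, k=5):\n",
"    'Return the k (class, probability) pairs with the highest probability.'\n",
"    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)\n",
"    return ranked[:k]\n",
"\n",
"# Made-up probabilities for illustration only.\n",
"top_k(['Arabic', 'Chinese', 'Czech'], [0.1, 0.7, 0.2], k=2)"
]
},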
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How does it do in practice?"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Polish', 0.9237067699432373),\n",
" ('Czech', 0.0443655289709568),\n",
" ('Vietnamese', 0.00837066862732172),\n",
" ('Spanish', 0.006321582943201065),\n",
" ('Chinese', 0.0038041360676288605)]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions(\"Wojtyła\")[:5] # Polish"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Czech', 0.866197407245636),\n",
" ('Polish', 0.10485668480396271),\n",
" ('Russian', 0.026793140918016434),\n",
" ('Korean', 0.00042316922917962074),\n",
" ('Japanese', 0.00041837719618342817)]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions(\"Dvořák\")[:5] # Czech"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Italian', 0.8921791911125183),\n",
" ('Russian', 0.0403725765645504),\n",
" ('Japanese', 0.025237590074539185),\n",
" ('Spanish', 0.015365079045295715),\n",
" ('Arabic', 0.01298986654728651)]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions(\"Gaddafi\")[:5] # Arabic"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Dutch', 0.35977253317832947),\n",
" ('German', 0.21167784929275513),\n",
" ('French', 0.18858234584331512),\n",
" ('English', 0.1542830914258957),\n",
" ('Irish', 0.02631647326052189)]"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions('Goethe')[:5] # German"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sometimes it does bad even if it's in the source data (it may not have ended up in training)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data/names/Korean.txt:Kim\r\n",
"data/names/Japanese.txt:Kimio\r\n",
"data/names/Japanese.txt:Kimiyama\r\n",
"data/names/Japanese.txt:Kimura\r\n",
"data/names/Vietnamese.txt:Pham\r\n",
"data/names/Vietnamese.txt:Kim\r\n",
"data/names/English.txt:Kimber\r\n",
"data/names/English.txt:Kimble\r\n",
"data/names/French.txt:Pascal\r\n"
]
}
],
"source": [
"!grep -Er 'Pascal|Pham' data/names"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Spanish', 0.9103199243545532),\n",
" ('Italian', 0.07529251277446747),\n",
" ('Polish', 0.005148978438228369),\n",
" ('Greek', 0.0036756773479282856),\n",
" ('Czech', 0.0034040361642837524)]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions(\"Pascal\")[:5] # French"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Dutch', 0.2937869727611542),\n",
" ('Vietnamese', 0.13536876440048218),\n",
" ('English', 0.11861108988523483),\n",
" ('French', 0.07520488649606705),\n",
" ('Irish', 0.0743907243013382)]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions(\"Pham\")[:5] # Vietnamese"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But sometimes it gets it right"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"!grep -ir '^Meijer' data/names"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Dutch', 0.9885940551757812),\n",
" ('German', 0.008121304214000702),\n",
" ('Czech', 0.0013009845279157162),\n",
" ('Korean', 0.00039360582013614476),\n",
" ('English', 0.0003091120161116123)]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions(\"Meijer\")[:5] # Dutch"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Polish', 0.9900423288345337),\n",
" ('Czech', 0.004603931214660406),\n",
" ('Chinese', 0.002563303103670478),\n",
" ('Korean', 0.0009222071967087686),\n",
" ('Dutch', 0.0005194866680540144)]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions('Wójcik')[:5] # Polish"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This model is not bad; but definitely sub-human."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What does it think about our ambiguous \"Michel\"?"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Czech', 0.6636908054351807),\n",
" ('German', 0.11761205643415451),\n",
" ('English', 0.061124321073293686),\n",
" ('Irish', 0.04792550206184387),\n",
" ('Polish', 0.027553826570510864),\n",
" ('French', 0.026092179119586945),\n",
" ('Dutch', 0.020014718174934387)]"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions('Michel')[:7]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predicting from a pretrained custom model"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, n_input, n_hidden, n_output, bn=False, use_cuda=False):\n",
" super().__init__()\n",
" self.i_h = nn.Embedding(n_input,n_hidden)\n",
" self.bn = nn.BatchNorm1d(n_hidden) if bn else None\n",
" self.o_h = nn.Linear(n_hidden, n_output)\n",
" self.h_h = nn.Linear(n_hidden, n_hidden)\n",
" self.use_cuda = use_cuda\n",
" self.reset()\n",
" \n",
" def forward(self, x):\n",
" # I'm not quite sure why the batch size seems to change to 720 in validation...\n",
" if self.h.shape[0] != x.shape[1]:\n",
" self.reset(x.shape[1])\n",
" h = self.h\n",
" x = self.i_h(x)\n",
" for xi in x:\n",
" h += xi\n",
" h = self.h_h(h)\n",
" h = F.relu(h)\n",
" if self.bn:\n",
" h = self.bn(h)\n",
" self.h = h.detach()\n",
" o = self.o_h(h)\n",
" return o\n",
" \n",
" def reset(self, size=None):\n",
" size = size or 1\n",
" self.h = torch.zeros(size, n_hidden)\n",
" if self.use_cuda:\n",
" self.h = self.h.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"n_letters = len(vocab.itos)\n",
"n_hidden = 128\n",
"n_output = len(classes)\n",
"model = Model(n_letters, n_hidden, n_output)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"with open('models/rnn-bal-1.model', 'rb') as f:\n",
" state = pickle.load(f)\n",
" model.load_state_dict(state)\n",
"model = model.cpu()\n",
"model = model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"for param in model.parameters():\n",
" param.requires_grad = False"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"name = 'Wójcik' # Polish"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'xxbosWojcik'"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"decode = BOS + unidecode(name)\n",
"decode"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['xxbos', 'w', 'o', 'j', 'c', 'i', 'k']"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens = tokenizer.process_all([decode])[0]\n",
"tokens"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1, 24, 16, 11, 4, 10, 12]"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nums = vocab.numericalize(tokens)\n",
"nums"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1],\n",
" [24],\n",
" [16],\n",
" [11],\n",
" [ 4],\n",
" [10],\n",
" [12]])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.tensor([nums]).transpose(1,0)\n",
"x"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ -8.0967, -6.4112, 7.0235, 4.0332, 1.4341, -15.9867, -0.3718,\n",
" -9.9789, -17.9649, -7.5314, -2.7700, -0.1450, -3.1552, 1.9744,\n",
" -14.5163, -15.3465]])"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = model(x).detach()\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([2.5545e-07, 1.3781e-06, 9.4171e-01, 4.7340e-02, 3.5193e-03, 9.5652e-11,\n",
" 5.7831e-04, 3.8891e-08, 1.3231e-11, 4.4956e-07, 5.2559e-05, 7.2556e-04,\n",
" 3.5759e-05, 6.0412e-03, 4.1617e-10, 1.8145e-10])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"probs = F.softmax(result[0], dim=0)\n",
"probs"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Czech: Probability 94.17%\n",
"Dutch: Probability 4.73%\n",
"Russian: Probability 0.60%\n"
]
}
],
"source": [
"for prob, idx in zip(*probs.topk(3)):\n",
" print(f'{classes[idx]}: Probability {prob:0.2%}')"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"def get_probs(name):\n",
" decode = BOS + unidecode(name)\n",
" tokens = tokenizer.process_all([decode])[0]\n",
" nums = vocab.numericalize(tokens)\n",
" x = torch.tensor([nums]).transpose(1,0)\n",
" model.reset()\n",
" result = model(x).detach()\n",
" probs = F.softmax(result[0], dim=0)\n",
" return probs"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"def print_top_probs(name, n=3):\n",
" probs = get_probs(name)\n",
" for prob, idx in zip(*probs.topk(n)):\n",
" print(f'{classes[idx]}: Probability {prob:0.2%}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In reality the model doesn't do great by human standards"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Irish: Probability 72.69%\n",
"English: Probability 16.42%\n",
"Japanese: Probability 3.84%\n"
]
}
],
"source": [
"print_top_probs('Goethe') # German"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"German: Probability 51.32%\n",
"English: Probability 34.13%\n",
"Chinese: Probability 9.00%\n"
]
}
],
"source": [
"print_top_probs('Jinping') # Chinese"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Korean: Probability 41.02%\n",
"Russian: Probability 22.39%\n",
"Dutch: Probability 16.08%\n"
]
}
],
"source": [
"print_top_probs('Kim') # Korean"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Korean: Probability 41.59%\n",
"Chinese: Probability 34.24%\n",
"Vietnamese: Probability 14.89%\n"
]
}
],
"source": [
"print_top_probs('Đặng') # Vietnamese"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arabic: Probability 91.50%\n",
"Russian: Probability 4.34%\n",
"Czech: Probability 2.74%\n"
]
}
],
"source": [
"print_top_probs('Zahir') # Arabic"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's also possible to use `learn.load` to load in the model, if you make some fake data.\n",
"\n",
"We need at least 2 rows or it will complain."
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td></td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td></td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0\n",
"0 \n",
"1 "
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"empty = pd.DataFrame([[' ']]*2)\n",
"empty"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"processors = [TokenizeProcessor(tokenizer=tokenizer, mark_fields=False),\n",
" NumericalizeProcessor(vocab=vocab)]"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"data = TextList.from_df(empty, processor=processors).no_split().label_const().databunch(bs=2)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"model = Model(n_letters, n_hidden, n_output)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, model)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"learn = learn.load('rnn-bal-1')"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"learn.model = learn.model.eval().cpu()\n",
"for param in learn.model.parameters():\n",
" param.requires_grad = False"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"x, _ = data.one_item('Dvořák') # Czech"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:2: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" \n"
]
},
{
"data": {
"text/plain": [
"tensor([[1.2497e-05, 3.4860e-06, 5.5647e-01, 1.3278e-04, 1.5334e-03, 4.7975e-04,\n",
" 7.6774e-05, 6.1160e-06, 4.5827e-08, 2.0946e-05, 1.2422e-04, 5.2584e-06,\n",
" 4.1456e-01, 2.5337e-02, 1.2437e-03, 2.2771e-11]])"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.model.reset()\n",
"probs = F.softmax(learn.model(x.cpu()))\n",
"probs"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Czech: Probability 55.65%\n",
"Polish: Probability 41.46%\n",
"Russian: Probability 2.53%\n"
]
}
],
"source": [
"for prob, idx in zip(*probs[0].topk(3)):\n",
" print(f'{classes[idx]}: Probability {prob:0.2%}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using the model to find similar names\n",
"\n",
"The idea is to dig into the representation in the 50 dimensional activation and use this to compare names.\n",
"\n",
"Two names are similar if they are close together in this embedding space.\n",
"It's not totally obvious that the RMS distance is appropriate for this, but it's what we'll use."
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
"from fastai.callbacks.hooks import *"
]
},
{
"cell_type": "code",
"execution_count": 215,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" <th>valid</th>\n",
" <th>bal</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Korean</td>\n",
" <td>Ahn</td>\n",
" <td>ahn</td>\n",
" <td>False</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Korean</td>\n",
" <td>Baik</td>\n",
" <td>baik</td>\n",
" <td>True</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Korean</td>\n",
" <td>Bang</td>\n",
" <td>bang</td>\n",
" <td>False</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Korean</td>\n",
" <td>Byon</td>\n",
" <td>byon</td>\n",
" <td>False</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Korean</td>\n",
" <td>Cha</td>\n",
" <td>cha</td>\n",
" <td>True</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name valid bal\n",
"0 Korean Ahn ahn False 13\n",
"1 Korean Baik baik True 0\n",
"2 Korean Bang bang False 13\n",
"3 Korean Byon byon False 15\n",
"4 Korean Cha cha True 0"
]
},
"execution_count": 215,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv('names_clean.csv')\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 216,
"metadata": {},
"outputs": [],
"source": [
"data = TextList.from_df(df, cols='ascii_name', processor=processors).no_split().label_from_df('cl').databunch(bs=1024)"
]
},
{
"cell_type": "code",
"execution_count": 217,
"metadata": {},
"outputs": [],
"source": [
"# model = Model(n_letters, n_hidden, n_output).cuda()\n",
"# learn = Learner(data, model)\n",
"# learn = learn.load('rnn-bal-1')\n",
"learn = text_classifier_learner(data, bptt=30, emb_sz=200, nh=300, nl=2)\n",
"learn.load('fastai_min')\n",
"None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the structure of our model"
]
},
{
"cell_type": "code",
"execution_count": 218,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('0', MultiBatchRNNCore(\n",
" (encoder): Embedding(31, 200, padding_idx=1)\n",
" (encoder_dp): EmbeddingDropout(\n",
" (emb): Embedding(31, 200, padding_idx=1)\n",
" )\n",
" (rnns): ModuleList(\n",
" (0): WeightDropout(\n",
" (module): LSTM(200, 300)\n",
" )\n",
" (1): WeightDropout(\n",
" (module): LSTM(300, 200)\n",
" )\n",
" )\n",
" (input_dp): RNNDropout()\n",
" (hidden_dps): ModuleList(\n",
" (0): RNNDropout()\n",
" (1): RNNDropout()\n",
" )\n",
" )), ('1', PoolingLinearClassifier(\n",
" (layers): Sequential(\n",
" (0): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): Dropout(p=0.4)\n",
" (2): Linear(in_features=600, out_features=50, bias=True)\n",
" (3): ReLU(inplace)\n",
" (4): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): Dropout(p=0.1)\n",
" (6): Linear(in_features=50, out_features=16, bias=True)\n",
" )\n",
" ))]"
]
},
"execution_count": 218,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(learn.model.named_children())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's capture the output of the 50 dimensional embedding near the end"
]
},
{
"cell_type": "code",
"execution_count": 221,
"metadata": {},
"outputs": [],
"source": [
"layer = 17"
]
},
{
"cell_type": "code",
"execution_count": 222,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Linear(in_features=600, out_features=50, bias=True)"
]
},
"execution_count": 222,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(learn.model.modules())[layer]"
]
},
{
"cell_type": "code",
"execution_count": 223,
"metadata": {},
"outputs": [],
"source": [
"def embed(x):\n",
" #with hook_output(list(learn.model.children())[-1]) as hook_a: \n",
" with hook_output(list(learn.model.modules())[layer]) as hook_a:\n",
" preds = learn.predict(x)\n",
" return hook_a.stored[0]"
]
},
{
"cell_type": "code",
"execution_count": 224,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 32.2 s, sys: 2.22 s, total: 34.4 s\n",
"Wall time: 34.4 s\n"
]
}
],
"source": [
"%time df = df.assign(embed = df.name.apply(embed))"
]
},
{
"cell_type": "code",
"execution_count": 226,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cl</th>\n",
" <th>name</th>\n",
" <th>ascii_name</th>\n",
" <th>valid</th>\n",
" <th>bal</th>\n",
" <th>embed</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Korean</td>\n",
" <td>Ahn</td>\n",
" <td>ahn</td>\n",
" <td>False</td>\n",
" <td>13</td>\n",
" <td>[tensor(0.7824, device='cuda:0'), tensor(2.431...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Korean</td>\n",
" <td>Baik</td>\n",
" <td>baik</td>\n",
" <td>True</td>\n",
" <td>0</td>\n",
" <td>[tensor(0., device='cuda:0'), tensor(0., devic...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Korean</td>\n",
" <td>Bang</td>\n",
" <td>bang</td>\n",
" <td>False</td>\n",
" <td>13</td>\n",
" <td>[tensor(0., device='cuda:0'), tensor(1.9753, d...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Korean</td>\n",
" <td>Byon</td>\n",
" <td>byon</td>\n",
" <td>False</td>\n",
" <td>15</td>\n",
" <td>[tensor(0., device='cuda:0'), tensor(4.9263, d...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Korean</td>\n",
" <td>Cha</td>\n",
" <td>cha</td>\n",
" <td>True</td>\n",
" <td>0</td>\n",
" <td>[tensor(0., device='cuda:0'), tensor(0., devic...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cl name ascii_name valid bal \\\n",
"0 Korean Ahn ahn False 13 \n",
"1 Korean Baik baik True 0 \n",
"2 Korean Bang bang False 13 \n",
"3 Korean Byon byon False 15 \n",
"4 Korean Cha cha True 0 \n",
"\n",
" embed \n",
"0 [tensor(0.7824, device='cuda:0'), tensor(2.431... \n",
"1 [tensor(0., device='cuda:0'), tensor(0., devic... \n",
"2 [tensor(0., device='cuda:0'), tensor(1.9753, d... \n",
"3 [tensor(0., device='cuda:0'), tensor(4.9263, d... \n",
"4 [tensor(0., device='cuda:0'), tensor(0., devic... "
]
},
"execution_count": 226,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 227,
"metadata": {},
"outputs": [],
"source": [
"def closest(name, n=10):\n",
" e = embed(name)\n",
" dist = [d(e, _) for _ in df.embed]\n",
" for idx in np.argsort(dist)[:10]:\n",
" print(f'{df.name.iloc[idx.item()]} ({df.cl.iloc[idx.item()]}): {dist[idx]}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's not immediately clear in what sense these are similar; but it doesn't seem random to me"
]
},
{
"cell_type": "code",
"execution_count": 229,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ahn (Korean): 0.0\n",
"Hor (Chinese): 74.33894348144531\n",
"Gul (Russian): 78.02725219726562\n",
"Noh (Korean): 88.23501586914062\n",
"Hon (Russian): 89.40960693359375\n",
"Ryu (Korean): 90.15037536621094\n",
"Byon (Korean): 92.82701110839844\n",
"Jermy (English): 93.4688720703125\n",
"Bishop (English): 96.69548034667969\n",
"Heron (English): 97.36801147460938\n",
"CPU times: user 1.46 s, sys: 188 ms, total: 1.65 s\n",
"Wall time: 1.65 s\n"
]
}
],
"source": [
"%time closest('Ahn')"
]
},
{
"cell_type": "code",
"execution_count": 232,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Turner (English): 34.640953063964844\n",
"Raizer (Russian): 41.138641357421875\n",
"Reuter (German): 45.08790588378906\n",
"Gunter (English): 46.127220153808594\n",
"Mendel (German): 48.00371551513672\n",
"Render (English): 48.202239990234375\n",
"Raeburn (English): 52.79540252685547\n",
"Rosenberger (German): 53.13862228393555\n",
"Rebinder (Russian): 53.502662658691406\n",
"Rosser (English): 54.7049446105957\n"
]
}
],
"source": [
"closest('Ruder')"
]
},
{
"cell_type": "code",
"execution_count": 233,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Oelberg (German): 28.07500457763672\n",
"Mordberg (Russian): 29.14569854736328\n",
"Gramberg (Russian): 30.82413101196289\n",
"Engman (Russian): 31.429853439331055\n",
"Burman (English): 33.723426818847656\n",
"Bumgarner (German): 33.76372528076172\n",
"Egger (German): 35.34405517578125\n",
"Großer (German): 35.70815658569336\n",
"Ranger (English): 35.80017852783203\n",
"Grainger (English): 36.01886749267578\n"
]
}
],
"source": [
"closest('Gugger')"
]
},
{
"cell_type": "code",
"execution_count": 234,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Manus (Irish): 49.601768493652344\n",
"Jemaitis (Russian): 56.250274658203125\n",
"Horos (Russian): 63.35132598876953\n",
"Klimes (Czech): 73.13825225830078\n",
"Bertsimas (Greek): 73.65045166015625\n",
"Tsogas (Greek): 79.87809753417969\n",
"Simonis (Dutch): 85.69441223144531\n",
"Honjas (Greek): 86.7238998413086\n",
"Mihelyus (Russian): 87.06715393066406\n",
"Grotus (Russian): 88.79036712646484\n"
]
}
],
"source": [
"closest('Thomas')"
]
},
{
"cell_type": "code",
"execution_count": 201,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"East (English): 59.85662841796875\n",
"Gammer (English): 60.93962097167969\n",
"Gale (English): 61.32642364501953\n",
"Gass (German): 65.68402099609375\n",
"Abrams (English): 68.60626983642578\n",
"Groer (Russian): 72.51294708251953\n",
"Bannister (English): 72.60062408447266\n",
"Glencross (English): 73.78807830810547\n",
"Moss (English): 74.06442260742188\n",
"Gander (English): 75.61122131347656\n"
]
}
],
"source": [
"closest('Ross')"
]
},
{
"cell_type": "code",
"execution_count": 202,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wan (Chinese): 55.373172760009766\n",
"Wei (Chinese): 63.00840759277344\n",
"Won (Chinese): 96.11245727539062\n",
"Gwang (Korean): 103.33255004882812\n",
"Gwock (Chinese): 118.81124114990234\n",
"Weng (Chinese): 131.87939453125\n",
"Wane (English): 141.58985900878906\n",
"Twigg (English): 147.58078002929688\n",
"Wain (English): 153.88088989257812\n",
"Gowing (English): 156.93133544921875\n"
]
}
],
"source": [
"closest('Wu')"
]
},
{
"cell_type": "code",
"execution_count": 214,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cheryshev (Russian): 4.010871410369873\n",
"Dobryshev (Russian): 23.908416748046875\n",
"Chalyshev (Russian): 23.938615798950195\n",
"Chanyshev (Russian): 28.119571685791016\n",
"Cherushov (Russian): 28.66720962524414\n",
"Tchanyshev (Russian): 30.515729904174805\n",
"Chehov (Russian): 30.542831420898438\n",
"Tchalyshev (Russian): 34.40039825439453\n",
"Yachmentsev (Russian): 36.298770904541016\n",
"Jerebyatiev (Russian): 36.888675689697266\n"
]
}
],
"source": [
"closest('Chebyshev')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
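
A note on the distance used by `closest`: the function calls a helper `d(e, _)` whose definition does not appear in the cells above. The sketch below is one plausible version, assuming plain Euclidean distance (the RMS distance is this value divided by the square root of the dimension, so both give the same ranking). The names `d` and `closest_names` and the toy 2-dimensional vectors are illustrative assumptions rather than the notebook's own code; with the real embeddings each tensor would first need converting, for example with `.cpu().tolist()`.

```python
import math

def d(a, b):
    """Euclidean distance between two equal-length vectors (assumed metric)."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def closest_names(query, names, embeddings, n=3):
    """Return the n (name, distance) pairs nearest to `query`, closest first."""
    dist = [d(query, e) for e in embeddings]
    # Indices of the embeddings sorted by increasing distance
    order = sorted(range(len(dist)), key=dist.__getitem__)[:n]
    return [(names[i], dist[i]) for i in order]

# Toy 2-dimensional vectors standing in for the 50-dimensional activations
names = ['Ahn', 'Noh', 'Turner']
embeddings = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
result = closest_names([0.0, 0.0], names, embeddings, n=2)
print(result)
```

Returning a plain Python float from `d` keeps the distance list easy to sort, which matches the float values printed by `closest` above.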