@joshfp
Last active October 10, 2020 16:02
Fast.ai p1v1: class 4
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ULMFiT: Train Spanish LM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai import *\n",
"from fastai.text import *"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"PATH = Path('~/datasets/wikimedia').expanduser()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Raw data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"wiki_file = 'wiki.es.txt'"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"row_list = []\n",
"with open(PATH/wiki_file, 'r', encoding='utf-8') as f:\n",
"    for line in f:\n",
"        # keep only lines with more than 150 words\n",
"        if len(line.split()) > 150:\n",
"            row_list.append(line)\n",
"df = pd.DataFrame(row_list, columns=['text'])"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>el usuario ha decidido pedir el bloqueo indef...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>en regresó a las tablas con openheart el trián...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>tradicionalmente , hampstead , como otros muni...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>aniplex comenzó la distribución de mercancía o...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>por fin , en , el perú y ecuador , mediante el...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" text\n",
"0 el usuario ha decidido pedir el bloqueo indef...\n",
"1 en regresó a las tablas con openheart el trián...\n",
"2 tradicionalmente , hampstead , como otros muni...\n",
"3 aniplex comenzó la distribución de mercancía o...\n",
"4 por fin , en , el perú y ecuador , mediante el..."
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"131465995"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# total number of whitespace-separated tokens\n",
"df.text.str.split().str.len().sum()"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"val_sz = int(0.1 * len(df))\n",
"shuffled = np.random.permutation(len(df))\n",
"trn_df = df.iloc[shuffled][val_sz:]\n",
"val_df = df.iloc[shuffled][:val_sz]"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"trn_df.to_csv(PATH/'train.csv', index=False)\n",
"val_df.to_csv(PATH/'valid.csv', index=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"trn_df = pd.read_csv(PATH/'train.csv')\n",
"val_df = pd.read_csv(PATH/'valid.csv')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"data_lm = TextLMDataBunch.from_df(PATH, trn_df, val_df, text_cols='text',\n",
"                                  tokenizer=Tokenizer(lang='es'), bs=48)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"with open('itos.pkl', 'wb') as f:\n",
"    pickle.dump(data_lm.train_ds.vocab.itos, f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"learn = language_model_learner(data_lm, drop_mult=0.3)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot(skip_end=12)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total time: 1:46:30\n",
"epoch train_loss valid_loss accuracy\n",
"1 4.168079 4.087315 0.286290 (1:46:30)\n",
"\n"
]
}
],
"source": [
"learn.fit_one_cycle(1, 2e-3, moms=(0.8, 0.7))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"learn.save('weights-1')"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total time: 7:06:30\n",
"epoch train_loss valid_loss accuracy\n",
"1 4.175116 4.097802 0.281919 (1:46:26)\n",
"2 4.071382 4.027313 0.287236 (1:46:50)\n",
"3 3.975520 3.912925 0.299030 (1:46:36)\n",
"4 3.921592 3.863933 0.305236 (1:46:37)\n",
"\n"
]
}
],
"source": [
"learn.fit_one_cycle(4, 2e-3, moms=(0.8, 0.7))"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"learn.save('weights-5')"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot(0,30)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total time: 1:46:36\n",
"epoch train_loss valid_loss accuracy\n",
"1 3.898500 3.861971 0.305301 (1:46:36)\n",
"\n"
]
}
],
"source": [
"learn.fit_one_cycle(1, 1e-6, moms=(0.8, 0.7))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"learn.save('weights-6')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:fastai]",
"language": "python",
"name": "conda-env-fastai-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
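The raw-data cells in the notebook keep only lines with more than 150 words and hold out 10% of them for validation. The following is a minimal, library-free sketch of that step, assuming a toy in-memory corpus in place of `wiki.es.txt`. The function name `filter_and_split` and its `seed` parameter are hypothetical and do not appear in the notebook, which uses `np.random.permutation` on a DataFrame instead.

```python
import random

def filter_and_split(lines, min_words=150, valid_frac=0.1, seed=42):
    """Keep lines with more than `min_words` words, then split them
    into disjoint training and validation lists."""
    kept = [ln for ln in lines if len(ln.split()) > min_words]
    idxs = list(range(len(kept)))
    random.Random(seed).shuffle(idxs)      # reproducible shuffle (assumed seed)
    val_sz = int(valid_frac * len(kept))   # same 10% rule as the notebook
    valid = [kept[i] for i in idxs[:val_sz]]
    train = [kept[i] for i in idxs[val_sz:]]
    return train, valid

# Toy corpus: 20 long "articles" plus 5 short lines that are filtered out.
corpus = [f"doc{i} " + "palabra " * 200 for i in range(20)] + ["corta"] * 5
train, valid = filter_and_split(corpus)
print(len(train), len(valid))
```

Checking that the two lists share no rows is a cheap guard against writing the training rows to both `train.csv` and `valid.csv`.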
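The `valid_loss` values in the training logs can be easier to compare when converted to perplexity. The sketch below assumes the reported loss is the mean per-token cross-entropy in nats (natural log), which is the usual convention for PyTorch-based training but is not stated in the notebook. The helper name `perplexity` is hypothetical, and the losses are copied from the logs printed before each `learn.save(...)` call.

```python
import math

def perplexity(cross_entropy_nats):
    """Perplexity is exp(loss) when the loss is mean cross-entropy in nats."""
    return math.exp(cross_entropy_nats)

# Last valid_loss logged before each checkpoint was saved.
valid_losses = {
    "weights-1": 4.087315,
    "weights-5": 3.863933,
    "weights-6": 3.861971,
}

for name, loss in valid_losses.items():
    print(f"{name}: valid_loss={loss:.4f}  perplexity={perplexity(loss):.1f}")
```

Because the exponential is monotonic, a lower loss always means a lower perplexity; the conversion only changes the scale on which the improvement is read.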
@joshfp
Author

joshfp commented Jun 29, 2019

make gist public

@kontrabas380

kontrabas380 commented Jan 29, 2020

Hello! You've done nice job and I've got a question. When you finish training this model, how can you predict one example? It's not working with .predict(example).
I've done this with two AWD_LSTM networks, but in the end I've met an issue with this error while making prediction:
AttributeError: 'ConcatDataset' object has no attribute 'set_item'

Best regards

@ascientist

> Hello! You've done nice job and I've got a question. When you finish training this model, how can you predict one example? It's not working with .predict(example).
> I've done this with two AWD_LSTM networks, but in the end I've met an issue with this error while making prediction:
> AttributeError: 'ConcatDataset' object has no attribute 'set_item'
>
> Best regards

Same problem here
