Skip to content

Instantly share code, notes, and snippets.

@xapss
Created March 19, 2018 22:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xapss/23ec632a91370976175d6ce9526816a2 to your computer and use it in GitHub Desktop.
Save xapss/23ec632a91370976175d6ce9526816a2 to your computer and use it in GitHub Desktop.
Combining FastAI's lessons
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Goal:\n",
" \n",
"Predict a review's star rating (`overall`) from its text, additional embeddings, and structured data."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Disclaimer: ignore the training results below; this notebook is only meant to illustrate the workflow."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load dependencies"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"\n",
"import sys\n",
"sys.path.append('/home/ubuntu/fastai/')\n",
"\n",
"from fastai.imports import *\n",
"from fastai.torch_imports import *\n",
"from fastai.core import *\n",
"from fastai.model import fit\n",
"from fastai.dataset import *\n",
"\n",
"from fastai.structured import *\n",
"from fastai.column_data import *\n",
"\n",
"import torchtext\n",
"from torchtext import vocab, data, datasets\n",
"from torchtext.datasets import language_modeling\n",
"\n",
"from fastai.rnn_reg import *\n",
"from fastai.rnn_train import *\n",
"from fastai.nlp import *\n",
"from fastai.lm_rnn import *\n",
"from fastai.text import *\n",
"\n",
"from fastai.learner import *\n",
"\n",
"import torchtext\n",
"from torchtext import vocab, data\n",
"from torchtext.datasets import language_modeling\n",
"\n",
"from fastai.rnn_reg import *\n",
"from fastai.rnn_train import *\n",
"from fastai.nlp import *\n",
"from fastai.lm_rnn import *\n",
"\n",
"import dill as pickle"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Download dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Big thanks to Julian McAuley from UCSD! More @ http://jmcauley.ucsd.edu/data/amazon/"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"link = \"http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Musical_Instruments_5.json.gz\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2018-03-19 22:33:12-- http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Musical_Instruments_5.json.gz\n",
"Resolving snap.stanford.edu (snap.stanford.edu)... 171.64.75.80\n",
"Connecting to snap.stanford.edu (snap.stanford.edu)|171.64.75.80|:80... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2460495 (2.3M) [application/x-gzip]\n",
"Saving to: 'reviews_Musical_Instruments_5.json.gz.10'\n",
"\n",
"reviews_Musical_Ins 100%[===================>] 2.35M 8.78MB/s in 0.3s \n",
"\n",
"2018-03-19 22:33:12 (8.78 MB/s) - 'reviews_Musical_Instruments_5.json.gz.10' saved [2460495/2460495]\n",
"\n"
]
}
],
"source": [
"!wget {link}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import ast\n",
"import gzip\n",
"\n",
"import pandas as pd\n",
"\n",
"def parse(path):\n",
"    \"\"\"Yield one review record (a dict) per line of a gzipped review file.\n",
"\n",
"    Each line is decoded with ast.literal_eval instead of eval: the file was\n",
"    downloaded from the network, and literal_eval only accepts Python literals\n",
"    (dicts, lists, strings, numbers) rather than executing arbitrary code.\n",
"    The `with` block closes the gzip handle once the generator is exhausted.\n",
"    \"\"\"\n",
"    with gzip.open(path, 'rt', encoding='utf-8') as g:\n",
"        for line in g:\n",
"            yield ast.literal_eval(line)\n",
"\n",
"def getDF(path):\n",
"    \"\"\"Load every record from parse(path) into a DataFrame with a 0..n-1 index.\"\"\"\n",
"    records = dict(enumerate(parse(path)))\n",
"    return pd.DataFrame.from_dict(records, orient='index')\n",
"\n",
"df = getDF('reviews_Musical_Instruments_5.json.gz')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"df['date'] = df['reviewTime'].apply(lambda x: datetime.datetime.strptime(x, '%m %d, %Y'))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"add_datepart(df, 'date', drop=False)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# adding synthetic continuous variables (random placeholders, not real review data)\n",
"np.random.seed(42)  # fixed seed so Restart & Run All reproduces the same columns\n",
"df['secondsSpentOnReview'] = np.random.randint(1, 200, df.shape[0])\n",
"df['amountSpentOnAmazon'] = np.random.randint(100, 2000, df.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>reviewerID</th>\n",
" <th>asin</th>\n",
" <th>reviewerName</th>\n",
" <th>helpful</th>\n",
" <th>reviewText</th>\n",
" <th>overall</th>\n",
" <th>summary</th>\n",
" <th>unixReviewTime</th>\n",
" <th>reviewTime</th>\n",
" <th>date</th>\n",
" <th>...</th>\n",
" <th>Dayofyear</th>\n",
" <th>Is_month_end</th>\n",
" <th>Is_month_start</th>\n",
" <th>Is_quarter_end</th>\n",
" <th>Is_quarter_start</th>\n",
" <th>Is_year_end</th>\n",
" <th>Is_year_start</th>\n",
" <th>Elapsed</th>\n",
" <th>secondsSpentOnReview</th>\n",
" <th>amountSpentOnAmazon</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>A2IBPI20UZIR0U</td>\n",
" <td>1384719342</td>\n",
" <td>cassandra tu \"Yeah, well, that's just like, u...</td>\n",
" <td>[0, 0]</td>\n",
" <td>Not much to write about here, but it does exac...</td>\n",
" <td>5.0</td>\n",
" <td>good</td>\n",
" <td>1393545600</td>\n",
" <td>02 28, 2014</td>\n",
" <td>2014-02-28</td>\n",
" <td>...</td>\n",
" <td>59</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>1393545600</td>\n",
" <td>28</td>\n",
" <td>1080</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>A14VAT5EAX3D9S</td>\n",
" <td>1384719342</td>\n",
" <td>Jake</td>\n",
" <td>[13, 14]</td>\n",
" <td>The product does exactly as it should and is q...</td>\n",
" <td>5.0</td>\n",
" <td>Jake</td>\n",
" <td>1363392000</td>\n",
" <td>03 16, 2013</td>\n",
" <td>2013-03-16</td>\n",
" <td>...</td>\n",
" <td>75</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>1363392000</td>\n",
" <td>124</td>\n",
" <td>1223</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2 rows × 25 columns</p>\n",
"</div>"
],
"text/plain": [
" reviewerID asin \\\n",
"0 A2IBPI20UZIR0U 1384719342 \n",
"1 A14VAT5EAX3D9S 1384719342 \n",
"\n",
" reviewerName helpful \\\n",
"0 cassandra tu \"Yeah, well, that's just like, u... [0, 0] \n",
"1 Jake [13, 14] \n",
"\n",
" reviewText overall summary \\\n",
"0 Not much to write about here, but it does exac... 5.0 good \n",
"1 The product does exactly as it should and is q... 5.0 Jake \n",
"\n",
" unixReviewTime reviewTime date ... Dayofyear \\\n",
"0 1393545600 02 28, 2014 2014-02-28 ... 59 \n",
"1 1363392000 03 16, 2013 2013-03-16 ... 75 \n",
"\n",
" Is_month_end Is_month_start Is_quarter_end Is_quarter_start \\\n",
"0 True False False False \n",
"1 False False False False \n",
"\n",
" Is_year_end Is_year_start Elapsed secondsSpentOnReview \\\n",
"0 False False 1393545600 28 \n",
"1 False False 1363392000 124 \n",
"\n",
" amountSpentOnAmazon \n",
"0 1080 \n",
"1 1223 \n",
"\n",
"[2 rows x 25 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tokenize Text"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"sl=1000\n",
"vocab_size=200000"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"The product does exactly as it should and is quite affordable . I did not realized it was double screened until it arrived , so it was even better than I had expected . As an added bonus , one of the screens carries a small hint of the smell of an old grape candy I used to buy , so for reminiscent 's sake , I can not stop putting the pop filter next to my nose and smelling it after recording . : DIf you needed a pop filter , this will work just as well as the expensive ones , and it may even come with a pleasing aroma like mine did!Buy this product ! :]\""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sample test\n",
"' '.join(Tokenizer().spacy_tok(df[:3].reviewText.values[1]))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"TEXT = data.Field(lower=True, tokenize=Tokenizer().spacy_tok)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"train_test_ratio = .8\n",
"cut = int(train_test_ratio*len(df))\n",
"df_train = df[:cut]\n",
"df_test = df[cut:]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"bs=64\n",
"bptt=70\n",
"PATH='output'"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"md = LanguageModelData.from_dataframes(PATH, TEXT, col='reviewText', train_df=df_train, val_df=df_test, test_df=df_test, bs=bs, bptt=bptt, min_freq=10)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mkdir: cannot create directory 'output': File exists\r\n"
]
}
],
"source": [
"!mkdir {PATH}"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"pickle.dump(TEXT, open(f'{PATH}/TEXT.pkl','wb'))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(183, 3783, 1, 826656)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['<unk>', '<pad>', '.', 'the', ',', 'i', 'a', 'and', 'it', 'to']"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"TEXT.vocab.itos[:10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Language Model"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"em_sz = 50 # size of each embedding vector\n",
"nh = 500 # number of hidden activations per layer\n",
"nl = 3 # number of layers"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"opt_fn = partial(optim.Adam, betas=(0.7, 0.99))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"learner = md.get_model(opt_fn, em_sz, nh, nl,\n",
" dropouti=0.2, dropout=0.2, wdrop=0.2, dropoute=0.1, dropouth=0.1)\n",
"learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)\n",
"learner.clip=0.3"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "00eea86ecd754ebba606f6122d8222e7",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" 99%|█████████▉| 182/183 [00:19<00:00, 9.15it/s, loss=28.3]"
]
}
],
"source": [
"lrf=learner.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEOCAYAAABmVAtTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAIABJREFUeJzt3XmYXNV55/Hv2/uqlnrTvoIWwGAwgiDAQdiGYMc29gTb4Uk82CZRcGbiJcvEiT1JnOSZeGwmGSeOF2ITkTghNkswi23hMEgEgwSSDFqsFS3Q2rp6kXrv6q5654+63Wo13VK3VLduVfXv8zz11K1Tt+q8p6u73j733HOuuTsiIiIABVEHICIi2UNJQUREhikpiIjIMCUFEREZpqQgIiLDlBRERGSYkoKIiAxTUhARkWFKCiIiMkxJQUREhhVFHcBE1NfX+6JFi6IOQ0Qkp2zZsqXF3Rsm85qcSAqLFi1i8+bNUYchIpJTzOzwZF8T2uEjM5tvZs+a2S4z22lmnw7Ka83sJ2a2L7ifEVYMIiIyOWGOKQwCv+fulwDXAf/NzC4FPgc84+5LgWeCxyIikgVCSwrufszdtwbbncAuYC5wO/BAsNsDwAfCikFERCYnI2cfmdki4CpgEzDT3Y9BKnEAjeO8Zo2ZbTazzbFYLBNhiohMeaEnBTOrAh4BPuPuHRN9nbvf5+4r3X1lQ8OkBs9FROQ8hZoUzKyYVEL4F3d/NCg+YWazg+dnA81hxiAiIhMX5tlHBnwH2OXufz3iqceBu4Ltu4AfhBWDiEiuOtU7wI93HKelqz+j9YbZU7gB+CjwDjN7Jbi9B/gScIuZ7QNuCR6LiMgI+5s7uee7W9h5dMJH3dMitMlr7v48YOM8/c6w6hURyQetXXEA6ipLMlqv1j4SEclCrd2ppFCrpCAiIm1KCiIiMqS1K05lSSFlxYUZrVdJQUQkC7V191NbldleAigpiIhkpdbuOLWVpRmvV0lBRCQLtXXHqc/weAIoKYiIZKXWrnjGB5lBSUFEJOu4O23dcY0piIgIdPUPEk8kMz5xDZQURESyzuk5ChpoFhGZ8oZmM6unICIitHVFM5sZlBRERLJOVEtcgJKCiEjWaelOXUOhTmcfiYhIW1ec8uJCKkpCu7rBuJQURESyTFt3NBPXQElBRCTrtHbHIzl0BEoKIiJZRz0FEREZpqQgIiJAat2jlq7+SCaugZKCiEhWOdkzQP9gkrqqzC9xASEmBTO738yazWzHiLIrzWyjmb1iZpvN7Nqw6hcRyUUb9sYAuHZxbST1h9lTWAvcNqrsy8AX3f1K4E+CxyIiEli38ziN1aVcOW96JPWHlhTc/TmgbXQxMC3YrgGOhlW/iEiu6RtIsGFvjFsunUlBgUUSQ6any30GWGdm95JKSNdnuH4Rkaz1/L4WeuIJbr1sVmQxZHqg+ZPAZ919PvBZ4Dvj7Whma4Jxh82xWCxjAYqIROXpnx+nurSIVUvqIosh00nhLuDRYPshYNyBZne/z91XuvvKhoaGjAQnIhKln+5v5ReXNVBSFN2JoZmu+ShwU7D9DmBfhusXEclKiaRzvKOPhXUVkcYR2piCmT0IrAbqzawJ+FPgN4GvmlkR0AesCat+EZFc0t4TJ5F0GqujmZ8wJLSk4O53jvPU1WHVKSKSq5o7UtdQaJxWFmkcmtEsIpIFmjv7ACLvKSgpiIhkgVhnqqfQoKQgIiLNQVJorNbhIxGRKS/W2U91aRHlJYWRxqGkICKSBZo7+2iYFu2hI1BSEBHJCrHOfhoiWi57JCUFEZEs0NzZH/npqKCkICISOXenuaM/8tNRQUlBRCRy3fEEvQMJJQUREYHmjtTEtajnKICSgohI5LJljgIoKYiIRG5oNnOjTkkVEZGhnoJOSRUREZo7+ygpLGB6RXHUoSgpiIhELdbZT0N1KWYWdShKCiIiURtKCtlASUFEJGLNHU
oKIiISaO7sy4qJa6CkICISqfhgkvaegayYowBKCiIikWrpyp45CqCkICISqWyaowAhJgUzu9/Mms1sx6jy3zGzPWa208y+HFb9IiK5IJtmM0O4PYW1wG0jC8zsZuB24Ap3vwy4N8T6RUSyXnNnajG8vB9TcPfngLZRxZ8EvuTu/cE+zWHVLyKSC5o7+jGDuqqSqEMBMj+msAx4u5ltMrMNZnZNhusXEckqsa5+aitKKC7MjiHeogjqmwFcB1wDfN/Mlri7j97RzNYAawAWLFiQ0SBFRDIlmyauQeZ7Ck3Ao57yEpAE6sfa0d3vc/eV7r6yoaEho0GKiGRKrLMvK67NPCTTSeEx4B0AZrYMKAFaMhyDiEjWiHX2Z83pqBDi4SMzexBYDdSbWRPwp8D9wP3Baapx4K6xDh2JiEwF7k6sqz9rTkeFEJOCu985zlO/HladIiK5pL1ngIGEZ826R6AZzSIikcm2OQqgpCAiEpmh2cxT+ewjEREJNHcES1woKYiISHOWrXsESgoiIpE50dFHVWkRFSWZnkc8PiUFEZGIHDnZy9zp5VGHcQYlBRGRiBxp72XeDCUFEREBmtp7mKukICIiHX0DdPQNqqcgIiKpQ0cAc6dXRBzJmZQUREQi0BQkBfUURESEI+09ABpTEBGRVE+hrLiAusrsuAznECUFEZEIHDnZy7wZFZhZ1KGcQUlBRCQCTe3ZN3ENlBRERCLR1N6TdYPMoKQgIpJx3f2DtPcMZN0gMygpiIhk3JGTQ6ejZtccBVBSEBHJuCNZOkcBlBRERDKuKZijME8DzSIi0tTeS0lRAfVV2XNxnSGhJQUzu9/Mms1sxxjP/b6ZuZnVh1W/iEi2agquo1BQkF1zFCDcnsJa4LbRhWY2H7gFeD3EukVEslZTFl5HYUhoScHdnwPaxnjqb4D/AXhYdYuIZLMjWTpHATI8pmBm7weOuPurmaxXRCRb9A0kaOmKZ+VsZoCMXS3azCqAzwO3TnD/NcAagAULFoQYmYhI5pxeMjv75ihAZnsKFwGLgVfN7BAwD9hqZrPG2tnd73P3le6+sqGhIYNhioiEZ2jiWjbOZoYM9hTcfTvQOPQ4SAwr3b0lUzGIiERteI5CliaFME9JfRB4EVhuZk1mdndYdYmI5Iqm9l6KC43G6rKoQxlTaD0Fd7/zHM8vCqtuEZFsdaS9l9k15RRm4RwFmGBPwcw+bWbTLOU7ZrbVzCY0YCwiIqdl65LZQyZ6+OgT7t5B6syhBuDjwJdCi0pEJE8dOZmdF9cZMtGkMNTPeQ/wj8E8g+zs+4iIZKn+wQQnOvqz9nRUmHhS2GJmT5NKCuvMrBpIhheWiEj+OXqyD8je01Fh4gPNdwNXAgfcvcfMakkdQhIRkQk61NoNwILa3O8prAL2uPtJM/t14AvAqfDCEhHJP/tOdAKwbGZVxJGMb6JJ4RtAj5m9ldRidoeBfwotKhGRPLT3RBcN1aVMryiJOpRxTTQpDLq7A7cDX3X3rwLV4YUlIpJ/9p3ozOpeAkw8KXSa2R8BHwWeMrNCoDi8sERE8ou7s6+5i6WN2f3/9ESTwkeAflLzFY4Dc4GvhBaViEieOXKyl554gqX50FMIEsG/ADVm9l6gz901piAiMkH7TnQBsGxmHvQUzOzDwEvAh4APA5vM7I4wAxMRySd7gzOPljZmd09hovMUPg9c4+7NAGbWAPwH8HBYgYmI5JN9zdl/5hFMfEyhYCghBFon8VoRkSkvF848gon3FH5sZuuAB4PHHwF+GE5IIiL5JZlMnXn04ZXzow7lnCaUFNz9D8zsV4AbSC2Ed5+7/3uokYmI5ImhM4+yfZAZJnGRHXd/BHgkxFhERPLS/uahM49y/PCRmXUCPtZTgLv7tFCiEhHJI6fPPMrxnoK7Z38LRESy3N4TXTRWl1JTkf0LQegMIhGRkO1r7syJ8QRQUhARCVUy6exv7sr65S2GhJYUzOx+M2s2sx0jyr5iZrvNbJuZ/b
uZTQ+rfhGRbDC85lEOjCdAuD2FtcBto8p+ArzF3a8A9gJ/FGL9IiKR29ec/RfWGSm0pODuzwFto8qedvfB4OFGYF5Y9YuIZIO9wUJ4SzWmcE6fAH4UYf0iIqHbd6KLmdNKqSnP/jOPIKKkYGafBwZJLcc93j5rzGyzmW2OxWKZC05EJI1y6cwjiCApmNldwHuBXwsu8Tkmd7/P3Ve6+8qGhobMBSgikiZ9Awl2H+/kktm5M893wstcpIOZ3Qb8IXCTu/dksm4RkUzb1nSK+GCSaxbVRh3KhIV5SuqDwIvAcjNrMrO7ga8B1cBPzOwVM/tmWPWLiETtpYOtAFyzaEbEkUxcaD0Fd79zjOLvhFWfiEi22XSwjRWzqrP+wjojaUaziEgIBhNJthxu59rFuXPoCJQURERCsfNoBz3xhJKCiIjAy4dSc3evzaFBZlBSEBEJxUsH21hYV0HjtLKoQ5kUJQURkRBsP3KKq+bn3pqfSgoiImnW0tXPsVN9vGVuTdShTJqSgohImu082gHAZXOUFEREprwdR04BcOmc3FneYoiSgohImu08eooFtRU5szLqSEoKIiJptuNIB2+Zm3u9BFBSEBFJq1M9A7ze1pOT4wmgpCAiklY7j6XGE3LxzCNQUhARSaudR4bOPNLhIxGRKW/TwTbm15ZTX1UadSjnRUlBRCRN+gcTvPBaCzcty92rRSopiIikyZZD7fTEE9y0rDHqUM6bkoKISJqs3xujuNC4/qK6qEM5b0oKIiJpsmFPjGsW1VJZGtpFLUOnpCAikgZHT/ay50Qnq5fn7ngCKCmIiKTFhr0xAFYvz93xBFBSEBFJiw17YsyuKWNpY1XUoVyQ0JKCmd1vZs1mtmNEWa2Z/cTM9gX3M8KqX0QkUwYSSX66v4XVyxsws6jDuSBh9hTWAreNKvsc8Iy7LwWeCR6LiOS0rYfb6ewfzOn5CUNCSwru/hzQNqr4duCBYPsB4ANh1S8ikinr98YoKjCuv7g+6lAuWKbHFGa6+zGA4D63R2RERID1e2K8beEMppXl3vUTRsvagWYzW2Nmm81scywWizocEZExnejoY9exjpw/FXVIppPCCTObDRDcN4+3o7vf5+4r3X1lQ0N+/LBFJP+s23kcgHeumBlxJOmR6aTwOHBXsH0X8IMM1y8iklZPvHqUZTOrWD6rOupQ0iLMU1IfBF4ElptZk5ndDXwJuMXM9gG3BI9FRHLS0ZO9vHyonfddMSfqUNImtAU63P3OcZ56Z1h1iohk0lPbjgHwvrfmT1LI2oFmEZFs98S2o1wxr4ZF9ZVRh5I2SgoiIufhUEs325pO5dWhI1BSEBE5L09uOwrAL18xO+JI0ktJQUTkPDz+6lGuWTSDOdPLow4lrZQUREQmac/xTvae6MqrAeYhSgoiIpP0xKtHKTB4z+X5degIlBRERCbF3Xli21Guv6ie+qrSqMNJOyUFEZFJ2Pp6O4dbe7j9yvw7dARKCiIik/Lo1iOUFRfw7jw8dARKCiIiE9Y/mODJbcf4pctmUVUa2oIQkVJSEBGZoGd3N3Oqd4APXjU36lBCo6QgIjJBj249Qn1VKTfmwRXWxqOkICIyAe3dcZ7d08wHrpxDUWH+fnXmb8tERNLoye3HGEg4H3xb/h46AiUFEZEJeXRrE8tnVnPp7GlRhxIqJQURkXM42NLNz14/yQffNhczizqcUCkpiIicw2M/O4IZfODK/D50BEoKIiLn9MPtx7h2US2zasqiDiV0SgoiImexv7mTfc1debn43ViUFEREzuJH248DcNtbZkUcSWbkdVJwd9w96jBEJIf9cMdxrl44g5nT8v/QEUAki3eY2WeB3wAc2A583N370l3Pn/xgJ/+88fCouoP7VBwjtofKLVUwRvl4r2Ws8tNvg5mN2D79gjP3Ceoe47UjY59IPIUFRmGBUVRoFBYUUDT0+Iz7oLzwzeXFhUZ5cSFlwa28uIDykpGPCykvKQz2KaCsuJDqsm
KmlRXl/ZkZMrUcaulm17EOvvDLl0QdSsZkPCmY2VzgU8Cl7t5rZt8HfhVYm+663rGikdrKEiCVfQh6DX56E8dHbJ9ZzhnlwWtPFwfbp8sZLvdx9xlZzsjy843njH1S5Ul3BhNOIukMJofukwwkkvQOBI8Tp8vP3M+JDybpG0jQP5ic4E86pbjQqK0sobaylPqqkmC7hIbqUuZOL2fejArmzyinobpUyUNywtoXDlFgU+fQEUTUUwjqLTezAaACOBpGJTevaOTmFY1hvPWUkEw6/YNJegcS9A0k6B1I0BtP0D+YoDc+ojyeoKNvgNbuOG1dcVq7+2ntjnO4tYe27jhd/YNnvG9lSSEXz6xmWWMVy2dVc9mcGi6fV5O3q05KbtpzvJN/3niYO69dwLwZFVGHkzEZ/yt09yNmdi/wOtALPO3uT2c6Djm3ggJLHSYqKbyg9+mJD3KkvZem9l7eaO/hQKybvSc6eXZPjIe2NAGpw2IXNVRx5fzpXLu4llVL6pg3o1w9ComEu/PnT+6kqrSI37t1edThZFQUh49mALcDi4GTwENm9uvu/t1R+60B1gAsWLAg02FKGlWUFLF0ZjVLZ1a/6bnWrn62HznFtqZTvPrGSf7f7mYeDhLFnJoyrltSx3VL6viFJbUsqK1QkpCMePzVo/x0fytffP9lw4egpwrL9Nk5ZvYh4DZ3vzt4/F+B69z9t8d7zcqVK33z5s2ZClEi5O7sa+5i04FWNh5oY+OBVlq74wDMrinjxovruWl5A2+/uIGaiuKIo5V8FOvs59a/2cDCukoe+eT1FBbk7j8iZrbF3VdO5jVRHMR9HbjOzCpIHT56J6BvfAFSZ1Etm1nNspnVfHTVItyd12JdvHigjY2vtbJu53Ee2tJEgcHbFsxg9fIGVi9v5LI509SLkAvm7nzhse10xxPc+6ErcjohnK8oxhQ2mdnDwFZgEPgZcF+m45DcYGZc3FjNxY3VfPS6hQwmkrzadJINe2I8uyfGvU/v5d6n99JYXcpNyxq4eUUjNy6tZ1qZehEyeU9uO8a6nSf43LtXcHHjmw93TgUZP3x0PnT4SMYT6+xnw94Y6/c089zeGB19gxQWGFcvnMHNyxtZvbyBFbOq1YuQc2rp6ueWv97AgrpKHrlnVV5cSOd8Dh8pKUjeGEwkeeWNkzy7p5n1e2LsPNoBwKxpZcOHmW5cWq9TX+VN3J17vruFZ3fHeOpTN455UkQuypUxBZFQFBUWsHJRLSsX1fIHv7SCEx19bNgTY/3eZp7adox/e/kNiguNlQtruXlFAzcvb+Tixir1IoSvr3+NdTtP8MfvWZE3CeF8qacgU8JAIsmWw+2s35M61LT7eCcAc6eXD/cirr+ojkr1Iqacp3ce57e+u4X3v3UO//cjV+bVPwk6fCQyQcdO9bJ+T4xndzfz0/0tdMcTlBQW8AtLaocHrJfUV+bVF4S82fP7WvjEAy9zyaxqvvdbqygrvrCJmtlGSUHkPMQHk2w+1Mb6vakksa+5C4B5M8pZtaSO6y+uY9WS+ilxgZWpZOOBVu66/yUW11fy4G9ex4w8nKSmpCCSBk3tPTy7J8bz+2JsPNDGqd4BABbXV7LqojpWBbOsG6pLI45Uztfu4x186BsvMrOmjO+tuY66qvz8LJUURNIsmXR2He/gxddaefG1Vl462EZnsMDfsplVrFpSx6qL6viFxXV5+Z9mPmpq7+FD33yRpDuP/vYNzJ1eHnVIoVFSEAnZYCLJzqMdvPBaKy8eaOXlg230DiQwg0tmTePaxbVcPreGt86vYXF91ZScEZvNtjWd5O4HNtM3kOB7a1Zx6ZxpUYcUKiUFkQyLDybZ1nQy1ZM40Morb5ykJ54AUkuEv2VuDVfMq+GKedO5Yl6NFvWLSCLpPPDCIb68bjf1VaX848eumRKnniopiEQskUyt1bSt6RTbmk6yrekUPz/WQTy4YFFNeXGQJGq4fG4qUcyuKVOiCEl8MMmPdhzjWx
sO8PNjHaxe3sBX7njrlBkPUlIQyUIDiSR7jncGS4SnEsWe450MJlN/e7WVJSyfWc2i+gpmTiujuLCAksIC5kwvZ0FtBfNry6kpL1bimITWrn7+ZdPrfHfjYZo7+1lcX8nv3rKM914xe0r9HJUURHJE30CCXcc62NZ0il3HOth9vJOm9l5auvrH3L+ypJDZQZK4uLHqjJsW/zutu3+Qf3rxMH//7H66+ge5aVkDH7thETctbaBgCo7vaJkLkRxRVlzIVQtmcNWCGWeUDyaSJNzpG0jS1N7DG229vNHWw9FTvRw72cfBlm6e39dCPHH6+tkN1aUsrK0IehWp+wV1FSysrcj762HHB1PrXf10fwsvvNbCK2+cZCDhvOuSRv7wNi1ZcT6UFESySFFhAUVAaVEhNeU1XDan5k37DCaSvNHey/7mLl6LdfFacxdvtPew6WAb//7KEUZ2/suKC5g/o4KFdRUsqK1kQW05DdVlNFSXUl9VwsxpZTmztEd3/yAHW7o52NLNrmMdbD9yii2H2+mJp87+unxuDXffuIRbLm3k6oW1UYebs3Ljt0FEhhUVFrC4vpLF9ZXcwswznusfTHD0ZB+vt/Xwems3h1t7UtttPfx0fyu9A4k3vV91aRGzasqYVVNGXWUJtZWl1FYWM6OyhNqKktR9ZQkzKkqYUVGctiWl3Z2eeIK27jit3XHauvtp7YrT1h0fUZa6P36qlxMdpw+tFRUYFzdWccfV87j+onpWLanTlfjSRElBJI+UFhUOJwxoOOM5d6etO06sq5+WzjjNnX00d/Zz/FQfx071cryjn0Ot3bR3D9AVTNAbS015MTMqiikvKaK0qCB1Ky6ktKiAogJjIOEMJpMkks5AIslAInU/mHAKC4ykO+3Bl33/YHLMOkqKCqirLKGuKpWkljZWDbdrUV0lSxoq826domyhpCAyRZgZdVWlqSUdZp193/7BBCd7BmjrjtPeHaetJ7jvHqCtu5+2ngH6BhL0DybpH0jQ0TtA/2CSwUSSosICiguNwgKjuLCA8uJCqsuKKCowEsEZV8tnVQ/3Sk5/+ZdQV1lKbVUJlSWFeT0Wks2UFETkTUqLCpk5rZCZ07QI4FST+9ebExGRtFFSEBGRYUoKIiIyTElBRESGRZIUzGy6mT1sZrvNbJeZrYoiDhEROVNUZx99Ffixu99hZiVARURxiIjICBlPCmY2DfhF4GMA7h4H4pmOQ0RE3iyKw0dLgBjwj2b2MzP7tplVjt7JzNaY2WYz2xyLxTIfpYjIFJTxpbPNbCWwEbjB3TeZ2VeBDnf/n2d5TQw4fB7V1QCnLnC/sZ4bXXa2x+Nt1wMtE4jtbCbSvsm2baxyte/N2xfavrB+N8cqV/vO3M7U7+a59stE+xa6+5nrnZyLu2f0RmqC/aERj98OPBVSXfdd6H5jPTe67GyPz7K9ORPtm2zb1L7MtC+s302179zty9TvZq62L+OHj9z9OPCGmS0Pit4J/Dyk6p5Iw35jPTe67GyPx9tOh4m832TbNla52jfxeCYqrN/NscrVvsnFNBF5275IrrxmZlcC3wZKgAPAx929PeOBRMjMNvskr4iUS9S+3JbP7cvntsGFty+SU1Ld/RUgbz+UCbov6gBCpvbltnxuXz63DS6wfTlxjWYREckMLXMhIiLDlBRERGSYkoKIiAxTUshCZlZpZlvM7L1Rx5JuZnaJmX0zWBDxk1HHk25m9gEz+wcz+4GZ3Rp1POlmZkvM7Dtm9nDUsaRL8Pf2QPC5/VrU8aTbZD8zJYU0MrP7zazZzHaMKr/NzPaY2X4z+9wE3uoPge+HE+X5S0f73H2Xu98DfJgsOwMtTe17zN1/k9TaXh8JMdxJS1P7Drj73eFGeuEm2db/AjwcfG7vz3iw52Ey7ZvsZ6akkF5rgdtGFphZIfD3wLuBS4E7zexSM7vczJ4cdWs0s3eRmsx3ItPBT8BaLrB9wWveDzwPPJPZ8M9pLWloX+ALweuyyV
rS175st5YJthWYB7wR7JbIYIwXYi0Tb9+kRLV0dl5y9+fMbNGo4muB/e5+AMDM/g243d3/CnjT4SEzuxmoJPWh9prZD909GWrgE5SO9gXv8zjwuJk9BfxreBFPTpo+PwO+BPzI3beGG/HkpOvzywWTaSvQRCoxvEKO/KM8yfZNasWInPgB5Li5nP4vBFK/gHPH29ndP+/unyH1ZfkP2ZIQzmJS7TOz1Wb2t2b2LeCHYQeXBpNqH/A7wLuAO8zsnjADS5PJfn51ZvZN4Coz+6Owg0uz8dr6KPArZvYN0r9USyaN2b7JfmbqKYTPxig754xBd1+b/lBCMan2uft6YH1YwYRgsu37W+Bvwwsn7SbbvlYgF5LdWMZsq7t3Ax/PdDAhGK99k/rM1FMIXxMwf8TjecDRiGIJg9qX2/K9fSPle1vT0j4lhfC9DCw1s8WWuvTorwKPRxxTOql9uS3f2zdSvrc1Le1TUkgjM3sQeBFYbmZNZna3uw8C/x1YB+wCvu/uO6OM83ypfWpfrsj3tobZPi2IJyIiw9RTEBGRYUoKIiIyTElBRESGKSmIiMgwJQURERmmpCAiIsOUFCTtzKwrA3W8/1zLPIdQ52ozu/48XneVmX072P6YmX0t/dFNnpktGr308hj7NJjZjzMVk0RPSUGyVrAU8Jjc/XF3/1IIdZ5tPbDVwKSTAvDHwN+dV0ARc/cYcMzMbog6FskMJQUJlZn9gZm9bGbbzOyLI8ofs9TV5Xaa2ZoR5V1m9udmtglYZWaHzOyLZrbVzLab2Ypgv+H/uM1sbbDy6gtmdsDM7gjKC8zs60EdT5rZD4eeGxXjejP7X2a2Afi0mb3PzDaZ2c/M7D/MbGawTPE9wGfN7BUze3vwX/QjQfteHuuL08yqgSvc/dUxnltoZs8EP5tnzGxBUH6RmW0M3vPPx+p5WepqYU+Z2atmtsPMPhKUXxP8HF41s5fMrDroEfxn8DPcOlZvx8wKzewrIz6r3xrx9GNA3l2RTMbh7rrpltYb0BXc3wrcR2r1xgLgSeAXg+dqg/tyYAdQFzx24MMj3usQ8DvB9m8D3w62PwZ8LdheCzwU1HEpqTXlAe4gtTx3ATALaAfuGCPe9cDXRzyewenZ/r96llQwAAADYklEQVQB/J9g+8+A3x+x378CNwbbC4BdY7z3zcAjIx6PjPsJ4K5g+xPAY8H2k8CdwfY9Qz/PUe/7K6SWVh96XAOUAAeAa4KyaaRWQq4AyoKypcDmYHsRsCPYXgN8IdguBTYDi4PHc4HtUf9e6ZaZm5bOljDdGtx+FjyuIvWl9BzwKTP7YFA+PyhvJXXlq0dGvc+jwf0WUpdOHMtjnrr2xM/NbGZQdiPwUFB+3MyePUus3xuxPQ/4npnNJvVFe3Cc17wLuNRseMXiaWZW7e6dI/aZDcTGef2qEe35Z+DLI8o/EGz/K3DvGK/dDtxrZv8beNLd/9PMLgeOufvLAO7eAaleBfA1M7uS1M932RjvdytwxYieVA2pz+Qg0AzMGacNkmeUFCRMBvyVu3/rjEKz1aS+UFe5e4+ZrQfKgqf73H30JRH7g/sE4//O9o/YtlH3E9E9YvvvgL9298eDWP9snNcUkGpD71net5fTbTuXCS9E5u57zexq4D3AX5nZ06QO84z1Hp8ldXnXtwYx942xj5Hqka0b47kyUu2QKUBjChKmdcAnzKwKwMzmWuo6vzVAe5AQVgDXhVT/86SuqFUQ9B5WT/B1NcCRYPuuEeWdQPWIx0+TWpUSgOA/8dF2ARePU88LpJY3htQx++eD7Y2kDg8x4vkzmNkcoMfdv0uqJ/E2YDcwx8yuCfapDgbOa0j1IJLAR4GxBvDXAZ80s+LgtcuCHgakehZnPUtJ8oeSgoTG3Z8mdfjjRTPbDjxM6kv1x0CRmW0D/oLUl2AYHiF14ZEdwLeATcCpCbzuz4CHzOw/gZYR5U8AHxwaaAY+BawMBmZ/zhhXt3L33UBNMOA82qeAjw
c/h48Cnw7KPwP8rpm9ROrw01gxXw68ZGavAJ8H/tLd48BHgL8zs1eBn5D6L//rwF1mtpHUF3z3GO/3bVLX8t0anKb6LU73ym4GnhrjNZKHtHS25DUzq3L3LjOrA14CbnD34xmO4bNAp7t/e4L7VwC97u5m9qukBp1vDzXIs8fzHHC7u7dHFYNkjsYUJN89aWbTSQ0Y/0WmE0LgG8CHJrH/1aQGhg04SerMpEiYWQOp8RUlhClCPQURERmmMQURERmmpCAiIsOUFEREZJiSgoiIDFNSEBGRYUoKIiIy7P8DUQAb6jVd59cAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learner.sched.plot()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f46d44b483b14ce28b37b9d36c9030ee",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" 49%|████▉ | 90/183 [00:09<00:09, 9.56it/s, loss=11.1] \n",
" 50%|█████ | 92/183 [00:09<00:09, 9.55it/s, loss=10.9]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Exception in thread Thread-4:\n",
"Traceback (most recent call last):\n",
" File \"/home/ubuntu/src/anaconda3/envs/fastai/lib/python3.6/threading.py\", line 916, in _bootstrap_inner\n",
" self.run()\n",
" File \"/home/ubuntu/src/anaconda3/envs/fastai/lib/python3.6/site-packages/tqdm/_tqdm.py\", line 144, in run\n",
" for instance in self.tqdm_cls._instances:\n",
" File \"/home/ubuntu/src/anaconda3/envs/fastai/lib/python3.6/_weakrefset.py\", line 60, in __iter__\n",
" for itemref in self.data:\n",
"RuntimeError: Set changed size during iteration\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss \n",
" 0 6.766552 6.431041 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[6.4310412]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# hint: this trains for a single epoch to illustrate the flow; if you run it, increase the number of epochs for proper results\n",
"learner.fit(3e-1, 1, wds=1e-6, cycle_len=1, cycle_mult=1, cycle_save_name='amzn-exp-1')"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"learner.save_encoder('amzn-exp-1-enc')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"learner.load_encoder('amzn-exp-1-enc')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Text Demo"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'To my surprise , the product was'"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m=learner.model\n",
"ss=\"To my surprise, the product was\"\n",
"s = [Tokenizer().spacy_tok(ss)]\n",
"t=TEXT.numericalize(s)\n",
"' '.join(s[0])"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# Set batch size to 1\n",
"m[0].bs=1\n",
"# Turn off dropout\n",
"m.eval()\n",
"# Reset hidden state\n",
"m.reset()\n",
"# Get predictions from model\n",
"res,*_ = m(t)\n",
"# Put the batch size back to what it was\n",
"m[0].bs=bs"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['.', '<unk>', 'the', ',', 'i', 'a', 'and', 'it', 'to', 'is']"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nexts = torch.topk(res[-1], 10)[1]\n",
"[TEXT.vocab.itos[o] for o in to_np(nexts)]"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"To my surprise, the product was \n",
"\n",
". plate plate dig fender . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . ...\n"
]
}
],
"source": [
"print(ss,\"\\n\")\n",
"for i in range(200):\n",
" n=res[-1].topk(2)[1]\n",
" n = n[1] if n.data[0]==0 else n[0]\n",
" print(TEXT.vocab.itos[n.data[0]], end=' ')\n",
" res,*_ = m(n[0].unsqueeze(0))\n",
"print('...')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classifier"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['reviewerID', 'asin', 'reviewerName', 'helpful', 'reviewText',\n",
" 'overall', 'summary', 'unixReviewTime', 'reviewTime', 'date', 'Year',\n",
" 'Month', 'Week', 'Day', 'Dayofweek', 'Dayofyear', 'Is_month_end',\n",
" 'Is_month_start', 'Is_quarter_end', 'Is_quarter_start', 'Is_year_end',\n",
" 'Is_year_start', 'Elapsed', 'secondsSpentOnReview',\n",
" 'amountSpentOnAmazon'],\n",
" dtype='object')"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.columns"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10261"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb_vars = ['helpful']\n",
"cat_vars = ['reviewerID', 'asin','Year','reviewText','Month', 'Week',\n",
" 'Day', 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start',\n",
" 'Is_quarter_end', 'Is_quarter_start', 'Is_year_end', 'Is_year_start']\n",
"\n",
"contin_vars = ['secondsSpentOnReview','amountSpentOnAmazon']\n",
"\n",
"n = len(df); n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"for v in cat_vars: df[v] = df[v].astype('category').cat.as_ordered()\n",
"for v in contin_vars: df[v] = df[v].astype('float32')\n",
"dep = 'overall'\n",
"joined = df[cat_vars+contin_vars+emb_vars+[dep, 'date']]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>reviewerID</th>\n",
" <th>asin</th>\n",
" <th>Year</th>\n",
" <th>reviewText</th>\n",
" <th>Month</th>\n",
" <th>Week</th>\n",
" <th>Day</th>\n",
" <th>Dayofweek</th>\n",
" <th>Dayofyear</th>\n",
" <th>Is_month_end</th>\n",
" <th>Is_month_start</th>\n",
" <th>Is_quarter_end</th>\n",
" <th>Is_quarter_start</th>\n",
" <th>Is_year_end</th>\n",
" <th>Is_year_start</th>\n",
" <th>secondsSpentOnReview</th>\n",
" <th>amountSpentOnAmazon</th>\n",
" <th>helpful</th>\n",
" <th>overall</th>\n",
" <th>date</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>A2IBPI20UZIR0U</td>\n",
" <td>1384719342</td>\n",
" <td>2014</td>\n",
" <td>Not much to write about here, but it does exac...</td>\n",
" <td>2</td>\n",
" <td>9</td>\n",
" <td>28</td>\n",
" <td>4</td>\n",
" <td>59</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>28.0</td>\n",
" <td>1080.0</td>\n",
" <td>[0, 0]</td>\n",
" <td>5.0</td>\n",
" <td>2014-02-28</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>A14VAT5EAX3D9S</td>\n",
" <td>1384719342</td>\n",
" <td>2013</td>\n",
" <td>The product does exactly as it should and is q...</td>\n",
" <td>3</td>\n",
" <td>11</td>\n",
" <td>16</td>\n",
" <td>5</td>\n",
" <td>75</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>124.0</td>\n",
" <td>1223.0</td>\n",
" <td>[13, 14]</td>\n",
" <td>5.0</td>\n",
" <td>2013-03-16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" reviewerID asin Year \\\n",
"0 A2IBPI20UZIR0U 1384719342 2014 \n",
"1 A14VAT5EAX3D9S 1384719342 2013 \n",
"\n",
" reviewText Month Week Day Dayofweek \\\n",
"0 Not much to write about here, but it does exac... 2 9 28 4 \n",
"1 The product does exactly as it should and is q... 3 11 16 5 \n",
"\n",
" Dayofyear Is_month_end Is_month_start Is_quarter_end Is_quarter_start \\\n",
"0 59 True False False False \n",
"1 75 False False False False \n",
"\n",
" Is_year_end Is_year_start secondsSpentOnReview amountSpentOnAmazon \\\n",
"0 False False 28.0 1080.0 \n",
"1 False False 124.0 1223.0 \n",
"\n",
" helpful overall date \n",
"0 [0, 0] 5.0 2014-02-28 \n",
"1 [13, 14] 5.0 2013-03-16 "
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"joined.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10261"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idxs = get_cv_idxs(n, val_pct=150000/n)\n",
"joined_samp = joined.iloc[idxs].set_index(\"date\")\n",
"samp_size = len(joined_samp); samp_size"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"# Use the full dataset instead of the random sample drawn in the previous cell\n",
"samp_size = n\n",
"joined_samp = joined.set_index(\"date\")"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>reviewerID</th>\n",
" <th>asin</th>\n",
" <th>Year</th>\n",
" <th>reviewText</th>\n",
" <th>Month</th>\n",
" <th>Week</th>\n",
" <th>Day</th>\n",
" <th>Dayofweek</th>\n",
" <th>Dayofyear</th>\n",
" <th>Is_month_end</th>\n",
" <th>Is_month_start</th>\n",
" <th>Is_quarter_end</th>\n",
" <th>Is_quarter_start</th>\n",
" <th>Is_year_end</th>\n",
" <th>Is_year_start</th>\n",
" <th>secondsSpentOnReview</th>\n",
" <th>amountSpentOnAmazon</th>\n",
" <th>helpful</th>\n",
" <th>overall</th>\n",
" </tr>\n",
" <tr>\n",
" <th>date</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2014-02-28</th>\n",
" <td>A2IBPI20UZIR0U</td>\n",
" <td>1384719342</td>\n",
" <td>2014</td>\n",
" <td>Not much to write about here, but it does exac...</td>\n",
" <td>2</td>\n",
" <td>9</td>\n",
" <td>28</td>\n",
" <td>4</td>\n",
" <td>59</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>28.0</td>\n",
" <td>1080.0</td>\n",
" <td>[0, 0]</td>\n",
" <td>5.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2013-03-16</th>\n",
" <td>A14VAT5EAX3D9S</td>\n",
" <td>1384719342</td>\n",
" <td>2013</td>\n",
" <td>The product does exactly as it should and is q...</td>\n",
" <td>3</td>\n",
" <td>11</td>\n",
" <td>16</td>\n",
" <td>5</td>\n",
" <td>75</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>124.0</td>\n",
" <td>1223.0</td>\n",
" <td>[13, 14]</td>\n",
" <td>5.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" reviewerID asin Year \\\n",
"date \n",
"2014-02-28 A2IBPI20UZIR0U 1384719342 2014 \n",
"2013-03-16 A14VAT5EAX3D9S 1384719342 2013 \n",
"\n",
" reviewText Month Week Day \\\n",
"date \n",
"2014-02-28 Not much to write about here, but it does exac... 2 9 28 \n",
"2013-03-16 The product does exactly as it should and is q... 3 11 16 \n",
"\n",
" Dayofweek Dayofyear Is_month_end Is_month_start Is_quarter_end \\\n",
"date \n",
"2014-02-28 4 59 True False False \n",
"2013-03-16 5 75 False False False \n",
"\n",
" Is_quarter_start Is_year_end Is_year_start secondsSpentOnReview \\\n",
"date \n",
"2014-02-28 False False False 28.0 \n",
"2013-03-16 False False False 124.0 \n",
"\n",
" amountSpentOnAmazon helpful overall \n",
"date \n",
"2014-02-28 1080.0 [0, 0] 5.0 \n",
"2013-03-16 1223.0 [13, 14] 5.0 "
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"joined_samp.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"cat_sz = [(c, len(joined_samp[c].cat.categories)+1) for c in cat_vars]\n",
"emb_szs = [(c, min(50, (c+1)//2)) for _,c in cat_sz]"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"df, y, nas, mapper = proc_df(joined_samp, 'overall', do_scale=True, skip_flds=emb_vars)\n",
"yl = np.log(y)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"train_ratio = 0.75\n",
"train_size = int(samp_size * train_ratio); train_size\n",
"val_idx = list(range(train_size, len(df)))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
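"# NOTE(review): this date window appears to match no rows of this dataset, so val_idx is empty;\n",
"# m.fit() below then fails in validate() with 'need at least one array to stack'.\n",
"# Pick a window inside the data's actual date range (check df.index.min()/max()).\n",
"val_idx = np.flatnonzero(\n",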
" (df.index<=datetime.datetime(2014,9,17)) & (df.index>=datetime.datetime(2014,8,1)))"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"def inv_y(a): return np.exp(a)\n",
"\n",
"def exp_rmspe(y_pred, targ):\n",
" targ = inv_y(targ)\n",
" pct_var = (targ - inv_y(y_pred))/targ\n",
" return math.sqrt((pct_var**2).mean())\n",
"\n",
"max_log_y = np.max(yl)\n",
"y_range = (0, max_log_y*1.2)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"md = ColumnarModelData.from_data_frame(PATH, val_idx, df, yl, cat_flds=cat_vars, bs=128)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"m = md.get_learner(emb_szs, len(df.columns)-len(cat_vars),\n",
" 0.04, 1, [1000,500], [0.001,0.01], y_range=y_range)\n",
"lr = 1e-3"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "726a84008ff94bae8197c7fb05c06fd7",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/src/anaconda3/envs/fastai/lib/python3.6/site-packages/numpy/core/fromnumeric.py:2957: RuntimeWarning: Mean of empty slice.\n",
" out=out, **kwargs)\n",
"/home/ubuntu/src/anaconda3/envs/fastai/lib/python3.6/site-packages/numpy/core/_methods.py:80: RuntimeWarning: invalid value encountered in double_scalars\n",
" ret = ret.dtype.type(ret / rcount)\n"
]
},
{
"ename": "ValueError",
"evalue": "need at least one array to stack",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-45-8857f3701be9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mexp_rmspe\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/fastai/fastai/learner.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, lrs, n_cycle, wds, **kwargs)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0mlayer_opt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_layer_opt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_gen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer_opt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_cycle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 216\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwarm_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/fastai/fastai/learner.py\u001b[0m in \u001b[0;36mfit_gen\u001b[0;34m(self, model, data, layer_opt, n_cycle, cycle_len, cycle_mult, cycle_save_name, best_save_name, use_clr, metrics, callbacks, use_wd_sched, norm_wds, wds_sched_mult, **kwargs)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mn_epoch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum_geom\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcycle_len\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcycle_len\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcycle_mult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_cycle\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m return fit(model, data, n_epoch, layer_opt.opt, self.crit,\n\u001b[0;32m--> 162\u001b[0;31m metrics=metrics, callbacks=callbacks, reg_fn=self.reg_fn, clip=self.clip, **kwargs)\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_layer_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_layer_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/fastai/fastai/model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(model, data, epochs, opt, crit, metrics, callbacks, stepper, **kwargs)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 106\u001b[0;31m \u001b[0mvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalidate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mval_dl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 107\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0mprint_stats\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mdebias_loss\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/fastai/fastai/model.py\u001b[0m in \u001b[0;36mvalidate\u001b[0;34m(stepper, dl, metrics)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mto_np\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 129\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 130\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_prediction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/src/anaconda3/envs/fastai/lib/python3.6/site-packages/numpy/core/shape_base.py\u001b[0m in \u001b[0;36mstack\u001b[0;34m(arrays, axis, out)\u001b[0m\n\u001b[1;32m 347\u001b[0m \u001b[0marrays\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marrays\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0marrays\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 349\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'need at least one array to stack'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 350\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[0mshapes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: need at least one array to stack"
]
}
],
"source": [
"m.fit(lr, 3, metrics=[exp_rmspe])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment