Skip to content

Instantly share code, notes, and snippets.

Created March 10, 2018 17:17
Show Gist options
  • Save anonymous/0dd0df21cf404cf2bb51d0148c8b7d8b to your computer and use it in GitHub Desktop.
Save anonymous/0dd0df21cf404cf2bb51d0148c8b7d8b to your computer and use it in GitHub Desktop.
fastai.text imdb example
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
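"source": "## IMDb\n\nLanguage modelling on the IMDb movie-review dataset with `fastai.text`: write the reviews to CSV, tokenize them, build a vocabulary, and train a language model on them."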
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from fastai.text import *  # provides Tokenizer, partition_by_cores, LanguageModelLoader, LanguageModelData, accuracy\n\n# Make explicit the names used below that this notebook never imported directly\nimport collections\nimport html\nimport pickle\nimport re\nfrom collections import Counter\nfrom functools import partial\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\nimport sklearn.model_selection\nfrom torch import optim",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Special tokens that the tokenisation step below prepends to every review\nBOS = 'xbos'  # beginning-of-text marker\nFLD = 'xfld'  # start-of-field marker, followed by the field number\n\nPATH = Path('data/aclImdb/')  # raw IMDb reviews, read from its train/ and test/ sub-directories",
"execution_count": 2,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
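"source": "## Standardize format\n\nWrite the labelled reviews (classifier data) and all reviews (language-model data) to header-less CSVs with the columns `labels, text`."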
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "CLASSES = ['neg', 'pos']\n\ndef get_texts(path):\n    # Read every review under path/<label>/ for each label in CLASSES.\n    # Returns (texts, labels) where labels[i] is the CLASSES index of texts[i].\n    # NOTE: later cells redefine get_texts with a different signature.\n    texts, labels = [], []\n    for idx, label in enumerate(CLASSES):\n        # Path.glob order is filesystem-dependent; sort it for reproducibility\n        for fname in sorted((path/label).glob('*.*')):\n            # read_text opens and closes the file; assumes UTF-8 rather than the locale default\n            texts.append(fname.read_text(encoding='utf-8'))\n            labels.append(idx)\n    return texts, labels\n\ntrn_texts, trn_labels = get_texts(PATH/'train')\nval_texts, val_labels = get_texts(PATH/'test')",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "(len(trn_texts), len(val_texts))",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "(25000, 25000)"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Column order for the CSVs written below: label first, then review text\ncol_names = ['labels', 'text']",
"execution_count": 5,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "CLAS_PATH = Path('data/imdb_clas/')\n# parents=True also creates any missing parent directory instead of raising FileNotFoundError\nCLAS_PATH.mkdir(parents=True, exist_ok=True)",
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "df_trn = pd.DataFrame({'text': trn_texts, 'labels': trn_labels}, columns=col_names)\ndf_val = pd.DataFrame({'text': val_texts, 'labels': val_labels}, columns=col_names)\n\ndf_trn.to_csv(CLAS_PATH/'train.csv', header=False, index=False)\ndf_val.to_csv(CLAS_PATH/'test.csv', header=False, index=False)\n\n# Context manager: classes.txt is flushed and closed here, not whenever the\n# anonymous file object happens to be garbage-collected\nwith (CLAS_PATH/'classes.txt').open('w') as f:\n    f.writelines(f'{o}\\n' for o in CLASSES)",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def get_texts(path):\n    # Read every review under path/all/ and return them as a list of strings (no labels).\n    # NOTE: shadows the labelled get_texts(path) defined earlier.\n    # sorted() makes the order reproducible (Path.glob order is filesystem-dependent);\n    # read_text closes each file; assumes UTF-8 rather than the locale default.\n    return [fname.read_text(encoding='utf-8') for fname in sorted((path/'all').glob('*.*'))]\n\nall_texts = get_texts(PATH/'train')\nall_texts += get_texts(PATH/'test')",
"execution_count": 32,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Fixed random_state so the language-model train/validation split is reproducible\ntrn_texts, val_texts = sklearn.model_selection.train_test_split(all_texts, test_size=0.1, random_state=42)",
"execution_count": 33,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "(len(trn_texts), len(val_texts))",
"execution_count": 34,
"outputs": [
{
"data": {
"text/plain": "(90000, 10000)"
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "LM_PATH = Path('data/imdb_lm/')\n# parents=True also creates any missing parent directory instead of raising FileNotFoundError\nLM_PATH.mkdir(parents=True, exist_ok=True)",
"execution_count": 7,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Labels are irrelevant for the language model, so every row gets a dummy label of 0\ndf_trn = pd.DataFrame({'text': trn_texts, 'labels': [0] * len(trn_texts)}, columns=col_names)\ndf_val = pd.DataFrame({'text': val_texts, 'labels': [0] * len(val_texts)}, columns=col_names)\n\ndf_trn.to_csv(LM_PATH/'train.csv', header=False, index=False)\ndf_val.to_csv(LM_PATH/'test.csv', header=False, index=False)",
"execution_count": 35,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Language model tokens"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "chunksize = 24000  # rows per DataFrame chunk yielded by read_csv below",
"execution_count": 36,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "re1 = re.compile(r' +')\n\n# (old, new) pairs applied in this exact order; e.g. the lone-backslash rule must\n# run last so it does not break the escaped-newline and escaped-quote patterns.\n_FIXUP_REPLACEMENTS = [\n    ('#39;', \"'\"), ('amp;', '&'), ('#146;', \"'\"), ('nbsp;', ' '), ('#36;', '$'),\n    ('\\\\n', \"\\n\"), ('quot;', \"'\"), ('<br />', \"\\n\"), ('\\\\\"', '\"'), ('<unk>', 'u_n'),\n    (' @.@ ', '.'), (' @-@ ', '-'), ('\\\\', ' \\\\ '),\n]\n\ndef fixup(x):\n    # Clean one raw review: apply the literal replacements above, unescape any\n    # remaining HTML entities, then collapse runs of spaces to a single space.\n    for old, new in _FIXUP_REPLACEMENTS:\n        x = x.replace(old, new)\n    return re1.sub(' ', html.unescape(x))",
"execution_count": 37,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def get_texts(df, n_lbls=1):\n    # Tokenize one DataFrame chunk read with header=None.\n    # Columns 0..n_lbls-1 hold integer labels; every later column is a text field.\n    # Each document starts with a newline, BOS and 'FLD 1', and the k-th text field\n    # is preceded by 'FLD k'. Returns (token lists, list of per-row label arrays).\n    labels = df.iloc[:, range(n_lbls)].values.astype(np.int64)\n    texts = f'\\n{BOS} {FLD} 1 ' + df[n_lbls].astype(str)\n    # column i is text field number i - n_lbls + 1 (the first field, column n_lbls, is 1);\n    # using i - n_lbls here would tag the second field as 1 again\n    for i in range(n_lbls + 1, len(df.columns)):\n        texts += f' {FLD} {i - n_lbls + 1} ' + df[i].astype(str)\n    texts = texts.apply(fixup).values.astype(str)\n\n    tok = Tokenizer().proc_all_mp(partition_by_cores(texts))\n    return tok, list(labels)",
"execution_count": 38,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def get_all(df, n_lbls):\n    # Tokenize every chunk yielded by a chunked read_csv reader and\n    # concatenate the per-chunk (tokens, labels) results.\n    tok, labels = [], []\n    for i, r in enumerate(df):\n        print(i)  # progress: index of the chunk being tokenized\n        tok_, labels_ = get_texts(r, n_lbls)\n        tok += tok_\n        labels += labels_\n    return tok, labels",
"execution_count": 39,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# read_csv with chunksize returns a one-shot iterator of DataFrames, so re-run\n# this cell before every call to get_all below.\ndf_trn = pd.read_csv(LM_PATH/'train.csv', header=None, chunksize=chunksize)\ndf_val = pd.read_csv(LM_PATH/'test.csv', header=None, chunksize=chunksize)",
"execution_count": 40,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# one label column (the dummy 0) precedes the review text\ntok_trn, trn_labels = get_all(df_trn, n_lbls=1)\ntok_val, val_labels = get_all(df_val, n_lbls=1)",
"execution_count": 41,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "0\n1\n2\n3\n0\n"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "(LM_PATH/'tmp').mkdir(exist_ok=True)  # cache dir for the token/id arrays and itos.pkl",
"execution_count": 42,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Documents have different lengths, so save explicit object arrays of token lists;\n# recent NumPy versions refuse to build such ragged arrays implicitly.\nnp.save(LM_PATH/'tmp'/'tok_trn.npy', np.array(tok_trn, dtype=object))\nnp.save(LM_PATH/'tmp'/'tok_val.npy', np.array(tok_val, dtype=object))",
"execution_count": 43,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Optional: also dump the training tokens as plain text (each document's tokens\n# space-joined, newline-terminated). Disabled by default.\nWRITE_JOINED_TEXT = False\nif WRITE_JOINED_TEXT:\n    with open(LM_PATH/'tmp'/'trn_joined.txt', 'w', encoding='utf-8') as f:\n        f.writelines(' '.join(o) + '\\n' for o in tok_trn)",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# The token files are object arrays (pickled lists): NumPy >= 1.16.3 only loads them with\n# allow_pickle=True. Only load files this notebook wrote itself - unpickling can run code.\ntok_trn = np.load(LM_PATH/'tmp'/'tok_trn.npy', allow_pickle=True)\ntok_val = np.load(LM_PATH/'tmp'/'tok_val.npy', allow_pickle=True)",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"scrolled": true,
"trusted": true
},
"cell_type": "code",
"source": "# token -> number of occurrences over all training documents\nfreq = Counter(token for doc in tok_trn for token in doc)\nfreq.most_common(25)",
"execution_count": 44,
"outputs": [
{
"data": {
"text/plain": "[('the', 1210343),\n ('.', 992949),\n (',', 986569),\n ('and', 587674),\n ('a', 583580),\n ('of', 525539),\n ('to', 485458),\n ('is', 393569),\n ('it', 341409),\n ('in', 337946),\n ('i', 308226),\n ('this', 270821),\n ('that', 261198),\n ('\"', 237159),\n (\"'s\", 221032),\n ('-', 188062),\n ('was', 180556),\n ('\\n\\n', 179179),\n ('as', 165827),\n ('with', 159218),\n ('for', 158785),\n ('movie', 157922),\n ('but', 150381),\n ('film', 144309),\n ('you', 124324)]"
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "max_vocab = 60000  # keep at most this many of the most frequent tokens\nmin_freq = 2       # NOTE: the vocab filter is c > min_freq, i.e. tokens need at least 3 occurrences",
"execution_count": 48,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# index -> token: '_unk_' (id 0), '_pad_' (id 1), then up to max_vocab of the most\n# frequent tokens that occur more than min_freq times\nitos = ['_unk_', '_pad_'] + [tok for tok, cnt in freq.most_common(max_vocab) if cnt > min_freq]\n# token -> index; tokens outside the vocabulary map to 0 ('_unk_')\nstoi = collections.defaultdict(lambda: 0, {tok: idx for idx, tok in enumerate(itos)})\nlen(itos)",
"execution_count": 49,
"outputs": [
{
"data": {
"text/plain": "60002"
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Map tokens to ids (unknown tokens -> 0). Documents differ in length, so build explicit\n# object arrays; recent NumPy versions raise on implicitly ragged nested lists.\ntrn_lm = np.array([[stoi[o] for o in p] for p in tok_trn], dtype=object)\nval_lm = np.array([[stoi[o] for o in p] for p in tok_val], dtype=object)",
"execution_count": 50,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "np.save(LM_PATH/'tmp'/'trn_ids.npy', trn_lm)\nnp.save(LM_PATH/'tmp'/'val_ids.npy', val_lm)\n# Context manager: itos.pkl is flushed and closed before anything reads it back\nwith open(LM_PATH/'tmp'/'itos.pkl', 'wb') as f:\n    pickle.dump(itos, f)",
"execution_count": 54,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Language model"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "wd = 1e-7    # weight decay, passed to learner.fit(wds=...)\nbptt = 70    # passed as bptt to the LM loaders (presumably the sequence length per step)\nbs = 52      # batch size for the LM loaders\nem_sz, nh, nl = 400, 1150, 3  # presumably embedding size, hidden units, layers - see get_model\nopt_fn = partial(optim.Adam, betas=(0.8, 0.99))",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# *_ids.npy are object arrays (one id list per document): NumPy >= 1.16.3 needs\n# allow_pickle=True to read them. Only load files this notebook wrote itself.\ntrn_lm = np.load(LM_PATH/'tmp'/'trn_ids.npy', allow_pickle=True)\nval_lm = np.load(LM_PATH/'tmp'/'val_ids.npy', allow_pickle=True)\n# flatten the per-document id lists into one continuous id stream\ntrn_lm = np.concatenate(trn_lm)\nval_lm = np.concatenate(val_lm)\n\n# Context manager closes the file; itos.pkl is a pickle this notebook wrote itself\nwith open(LM_PATH/'tmp'/'itos.pkl', 'rb') as f:\n    itos = pickle.load(f)\nvs = len(itos)\n\ntrn_dl = LanguageModelLoader(trn_lm, bs, bptt)\nval_dl = LanguageModelLoader(val_lm, bs, bptt)\n# NOTE(review): the 1 is presumably the padding id ('_pad_' is itos[1]) - confirm against LanguageModelData\nmd = LanguageModelData(PATH, 1, vs, trn_dl, val_dl, bs=bs, bptt=bptt)",
"execution_count": 9,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Order matches the get_model call below: dropouti, dropout, wdrop, dropoute, dropouth.\n# The trailing factor scales all five together.\ndrops = np.array([0.25, 0.1, 0.2, 0.02, 0.15]) * 1.0",
"execution_count": 10,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learner = md.get_model(\n    opt_fn, em_sz, nh, nl,\n    dropouti=drops[0], dropout=drops[1], wdrop=drops[2],\n    dropoute=drops[3], dropouth=drops[4],\n)\n# report accuracy alongside the losses, then unfreeze all layers for training\nlearner.metrics = [accuracy]\nlearner.unfreeze()",
"execution_count": 15,
"outputs": []
},
{
"metadata": {
"collapsed": true,
"trusted": true
},
"cell_type": "code",
"source": "# LR range test. The previous end_lr=1e12 spent most of the sweep at LRs above 1,\n# where the loss is expected to diverge; 10 still spans the useful range.\n# NOTE(review): confirm 10 against the LR_Finder behaviour in this fastai version.\nlearner.lr_find(start_lr=1e-6, end_lr=10)",
"execution_count": 13,
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "216b7053bfb547029f09f8d6bb8f0353",
"version_major": 2,
"version_minor": 0
},
"text/plain": "A Jupyter Widget"
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": " 29%|██▊ | 1967/6872 [06:00<14:59, 5.45it/s, loss=7.06]"
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-13-7ca087bb6cc1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mlearner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlr_find\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart_lr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1e-6\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mend_lr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1e12\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/data1/jhoward/git/fastai/courses/dl2/fastai/learner.py\u001b[0m in \u001b[0;36mlr_find\u001b[0;34m(self, start_lr, end_lr, wds, linear)\u001b[0m\n\u001b[1;32m 256\u001b[0m \u001b[0mlayer_opt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_layer_opt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart_lr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLR_Finder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer_opt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrn_dl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_lr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlinear\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 258\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_gen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer_opt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 259\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'tmp'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 260\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/data1/jhoward/git/fastai/courses/dl2/fastai/learner.py\u001b[0m in \u001b[0;36mfit_gen\u001b[0;34m(self, model, data, layer_opt, n_cycle, cycle_len, cycle_mult, cycle_save_name, best_save_name, use_clr, metrics, callbacks, use_wd_sched, norm_wds, wds_sched_mult, **kwargs)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mn_epoch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum_geom\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcycle_len\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcycle_len\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcycle_mult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_cycle\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m return fit(model, data, n_epoch, layer_opt.opt, self.crit,\n\u001b[0;32m--> 162\u001b[0;31m metrics=metrics, callbacks=callbacks, reg_fn=self.reg_fn, clip=self.clip, **kwargs)\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_layer_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_layer_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/data1/jhoward/git/fastai/courses/dl2/fastai/model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(model, data, epochs, opt, crit, metrics, callbacks, stepper, **kwargs)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0mbatch_num\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstepper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0mavg_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg_loss\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mavg_mom\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mavg_mom\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0mdebias_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg_loss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mavg_mom\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mbatch_num\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/data1/jhoward/git/fastai/courses/dl2/fastai/model.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, xs, y)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# Gradient clipping\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip_grad_norm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainable_params_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mraw_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0;31m# Decay the first and second moment running average coefficient\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0mexp_avg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbeta1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mbeta1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbeta2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maddcmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mbeta2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'eps'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
{
"metadata": {
"collapsed": true,
"trusted": true
},
"cell_type": "code",
"source": "learner.sched.plot()  # plot what the LR_Finder created by lr_find recorded",
"execution_count": 14,
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEOCAYAAACEiBAqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAIABJREFUeJzt3XecVOXZ//HPNduXbbAsvSwdEUV0QbGB3fhoLNGoKWKJmmiKKab/8iR5kmieJE9iTIzBmhg1iRqNvWEENdKbIB2kl11YYBvb5vr9MYNZcZFd2Nkz5ft+veY1Z86cmXPdOzDfuU+5j7k7IiKSukJBFyAiIsFSEIiIpDgFgYhIilMQiIikOAWBiEiKUxCIiKQ4BYGISIpTEIiIpDgFgYhIilMQiIikuPSgC2iL7t27e2lpadBliIgklLlz51a4e8nBlkuIICgtLWXOnDlBlyEiklDMbF1bltOmIRGRFKcgEBFJcQoCEZEUpyAQEUlxMQsCM7vfzLab2eIW8y4zsyVmFjazslitW0RE2i6WPYIHgXP3m7cYuASYHsP1iohIO8Ts8FF3n25mpfvNWwpgZrFa7Qcs3bKHjZV1hAxCISNkRpoZIYvUkBay959LD0UeZ6SFSDvI433zRESSQUKcR3CoHp65jr/MWB+T985MC5GVESInI43sjLTofYisFtP7nsvOSKMgJ4PueZkUd8miW5dMuudl0r9bLtkZaTGpT0SkreI2CMzsBuAGgAEDBhzSe9w0aShXjBtAc9gJ+74bhMNOszvu0BydDoedxmanOew0hcOR+2anKew0h8M0RR83hsM0Njl1jc3sbXGLPA5T19jMrtqG96f3PVdd34T7B+sLGQws7sKoPgWcNqIHZx3Rk8LcjMP904mItEvcBoG7TwGmAJSVlflBFm9Vn6Ic+hTldGhdh6qpOUxlbSM7aurZUd1ARXU9q8trWLmtitlrd/Lcoi3kZaVzy5nDuO7kQZ22+UxEJG6DINmkp4Uoyc+iJD/rQ8+Fw86iTbu5c+pKfvLcUqrrm7jlzOEBVCkiqSiWh48+CrwNjDCzjWZ2nZldbGYbgQnAc2b2UqzWn0hCIeOY/kXcO7mMi8f25c7XVrGxsjboskQkRcQsCNz9Snfv7e4Z7t7P3e9z9yej01nu3tPdz4nV+hORmXHrOSMAeGhGm8aKEhE5bDqzOM70Kcrh5KHdeXHxVnz/vcsiIjGgIIhDZx7Rg3U7alldXhN0KSKSAhQEcei0kT0AeG3ZtoArEZFUoCCIQ/265jKyVz6vLt0edCkikgIUBHHqjCN6MHddJbvrGoMuRUSSnIIgTk0c3oPmsPP26oqgSxGRJKcgiFNjBxTRJTONN1YqCEQkthQEcSojLcSEIcUKAhGJOQVBHDtlWAnrd9aybocOIxWR2FEQxLFThnUHUK9ARGJKQRDHBnXvQt+iHKatKA+6FBFJYgqCOGZmTBxRwr9XVdDQFA66HBFJUgqCODdpeAk1Dc3MWbcz6FJEJEkpCOLciUO7k5Fm2jwkIjGjIIhzeVnplA3sxrTlCgIRiQ0FQQKYNKKEZVur2LK7LuhSRCQJKQgSwMQRJQBM1+YhEYkBBUECGNEzn14F2byuzUMiEgMKggRgZkwaUcKbKytobNZhpCLSsRQECWLi8BKq6puYv35X0KWISJJRECSIk4Z1Jz1kvL5cF6sRkY6lIEgQBdkZHDuwq84nEJEOpyBIIBOHl7Bk8x62V+0NuhQRSSIKggQyKXoYqU4uE5GOpCBIIKN6F9CrIJtXl24LuhQRSSIKggRiZpw1qifTVpRT19AcdDkikiQUBAnmnCN7sbcxrJ3GItJhYhYEZna/mW03s8Ut5nUzs1fMbGX0vmus1p+sjh/cjcKcDF5esjXoUkQkScSyR/AgcO5+874NTHX3YcDU6GNph4y0EGeN6skr725jb6M2D4nI4YtZELj7dGD/q6lcCP
wpOv0n4KJYrT+ZXTCmD1X1TRp7SEQ6RGfvI+jp7lsAovc9Onn9SeGkIcUUd8nkmUWbgy5FRJJA3O4sNrMbzGyOmc0pL9cv35bS00Kcd1Rvpi7dRk19U9DliEiC6+wg2GZmvQGi9wccOMfdp7h7mbuXlZSUdFqBieKCMX3Y2xjWOQUictg6OwieBiZHpycD/+zk9SeNsoFd6V2YzTMLtXlIRA5PLA8ffRR4GxhhZhvN7DrgduAsM1sJnBV9LIcgFDLOP7o301aUs6u2IehyRCSBxfKooSvdvbe7Z7h7P3e/z913uPsZ7j4ser//UUXSDhce05fGZudp9QpE5DDE7c5iObjRfQsZ3beAR2aux92DLkdEEpSCIMFdMW4Ay7ZWsXDj7qBLEZEEpSBIcBce04ecjDT+Omt90KWISIJSECS4/OwMLhjTm6cXbqZa5xSIyCFQECSBK8YPoLahmacXaKexiLSfgiAJjO1fxIie+fx1tjYPiUj7KQiSgJlx5fj+LNq4m8WbtNNYRNpHQZAkLh7bj6z0kHoFItJuCoIkUZibwXlH9eaf8zdT26CdxiLSdgqCJHLl+AFU1Tfx3KItQZciIglEQZBExpV2ZUhJFx6asU5nGotImykIkoiZce3Jg1i0cTdvr9kRdDkikiAUBEnmE8f2o3teFn+ctiboUkQkQSgIkkx2RhrXnFTKtBXlLN2yJ+hyRCQBKAiS0GeOH0huZhr3TFevQEQOTkGQhApzM7hy/ACeXriZTbvqgi5HROKcgiBJXXvyIBy4/821QZciInFOQZCk+hbl8PExfXh01np21zYGXY6IxDEFQRK74dTB1DY085eZ64IuRUTimIIgiR3Ru4CJw0t44K217G1sDrocEYlTCoIkd+PEwVRUN/CPeZuCLkVE4pSCIMlNGFzM0f0KufeNNTSHNeyEiHyYgiDJmRk3njqENRU1vPLutqDLEZE4pCBIAeeO7sWAbrncPW21BqMTkQ9REKSAtJBx/SmDWLBhF7Pfqwy6HBGJMwqCFHHpcf3p1iWTP05bHXQpIhJnFAQpIiczjckTSpm6bDsrtlUFXY6IxJFAgsDMvmJmi81siZndEkQNqeiqCQPJyUhjigajE5EWOj0IzGw0cD0wHhgDnG9mwzq7jlTUtUsml4/rzz8XbGLr7r1BlyMicSKIHsERwAx3r3X3JmAacHEAdaSk604eRNjhrtdXBV2KiMSJIIJgMXCqmRWbWS5wHtA/gDpSUv9uuVw5vj8Pz1zPqu3VQZcjInGg04PA3ZcCPwdeAV4EFgJN+y9nZjeY2Rwzm1NeXt7JVSa3r545nMy0kHoFIgIEtLPY3e9z92Pd/VRgJ7CylWWmuHuZu5eVlJR0fpFJrDgviyvHD+CfCzazYWdt0OWISMCCOmqoR/R+AHAJ8GgQdaSy608dRMjQEUQiEth5BE+Y2bvAM8DN7q7TXTtZ78IcLhnbj7/P2UB5VX3Q5YhIgILaNHSKu49y9zHuPjWIGiQyRHVjc5g/vK6zjUVSmc4sTmGDS/K47Lj+/GXGOjZWal+BSKpSEKS4r5w5DAzuePVD++tFJEUoCFJcn6IcrjphIE/M28hKjUEkkpIUBMJNpw0lNzOdX728IuhSRCQACgKhW5dMrj9lMC8u2co7G3cHXY6IdDIFgQBw7cmlFOZkcMdU9QpEUo2CQADIz87g+lMG8erS7eoViKQYBYG8b/KJ6hWIpCIFgbxPvQKR1KQgkA9Qr0Ak9SgI5APUKxBJPQoC+ZDJJ5ZSlJvBb15Vr0AkFSgI5EMivYLBTF22nUUbdwVdjojEmIJAWnXVhIEU5WZoDCKRFNCmIDCzr5hZgUXcZ2bzzOzsWBcnwVGvQCR1tLVHcK277wHOBkqAa4DbY1aVxIV9vYLfqFcgktTaGgQWvT8PeMDdF7aYJ0lqX6/gtWXbWbhBvQKRZNXWIJhrZi8TCYKXzCwfCMeuLIkX+44gumOqegUiyaqtQXAd8G1gnLvXAh
lENg9JksvLSlevQCTJtTUIJgDL3X2XmX0G+D6gs41SxL5ewW/VKxBJSm0Ngj8AtWY2BvgmsA74c8yqkriSl5XOtScNYuqy7by7eU/Q5YhIB2trEDS5uwMXAne4+x1AfuzKkngzeUIpeVnp/O5f6hWIJJu2BkGVmX0H+CzwnJmlEdlPICmiMDeDa08q5fl3trJks7YKiiSTtgbB5UA9kfMJtgJ9gV/ErCqJS9edMpiC7HR+/Yp6BSLJpE1BEP3yfxgoNLPzgb3urn0EKaYwJ3JewatLt+kIIpEk0tYhJj4JzAIuAz4JzDSzS2NZmMSna04eRNfcDH71ikYmFUkWbd009D0i5xBMdvergPHA/4tdWRKv8rLSuXHiEKavKGfOezuDLkdEOkBbgyDk7ttbPN7RjtdKkrlqwkC652Xyq5fVKxBJBm39Mn/RzF4ys6vN7GrgOeD5Q12pmX3VzJaY2WIze9TMsg/1vaTz5Wamc9Okoby9Zgf/Xl0RdDkicpjaurP4VmAKcDQwBpji7t86lBWaWV/gy0CZu48G0oArDuW9JDifOn4AvQqy+b+XVxA5xUREElWbN++4+xPu/jV3/6q7P3mY600HcswsHcgFNh/m+0kny85I4+bThzJnXSXTVpQHXY6IHIaPDAIzqzKzPa3cqszskMYacPdNwC+B9cAWYLe7v3wo7yXBurysP/265vCLl5YTDqtXIJKoPjII3D3f3QtaueW7e8GhrNDMuhIZqmIQ0AfoEh3Ibv/lbjCzOWY2p7xcvzjjUWZ6iG+cPYIlm/fwzCJ16kQSVRBH/pwJrHX3cndvBP4BnLj/Qu4+xd3L3L2spKSk04uUtvn4mD6M6l3AL19eTkOTLlEhkoiCCIL1wAlmlmtmBpwBLA2gDukAoZDxzXNHsGFnHY/OWh90OSJyCDo9CNx9JvA4MA94J1rDlM6uQzrOxOElnDC4G3e+tpKa+qagyxGRdgrkpDB3/293H+nuo939s+5eH0Qd0jHMjG+dO5KK6gbufWNt0OWISDvp7GDpEGMHdOXcI3txzxtr2FGtXBdJJAoC6TDfOGc4tQ1N/P5fq4MuRUTaQUEgHWZoj3wuO64/f5mxjo2VtUGXIyJtpCCQDnXLWcPA0MVrRBKIgkA6VO/CHK4+sZR/zN/I8q1VQZcjIm2gIJAOd9OkIeRlpfOLl5YFXYqItIGCQDpcUW4mn584hFeXbme2Ll4jEvcUBBIT15xUSkl+Fj9/YZmGqRaJcwoCiYnczHRuOXMYc9ZV8uLirUGXIyIfQUEgMXN5WX9G9srnp88vZW9jc9DliMgBKAgkZtLTQvzg/FFsrKzjvjc19IRIvFIQSEydOLQ7Z4/qye//tYpte/YGXY6ItEJBIDH3vf86gqZm5/YXdDipSDxSEEjMDSzuwucnDubJ+Zt4ffn2oMsRkf0oCKRT3Hz6UIb2yON7Ty6mWtcsEIkrCgLpFFnpafz8E0ezeXcdv3hRm4hE4omCQDrNcQO7cvWJpfx5xjqdcSwSRxQE0qm+cfYI+hbl8K0nFuncApE4oSCQTtUlK52fXXwUa8pr+O1UDVUtEg8UBNLpTh1ewmXH9ePuaat5a1VF0OWIpDwFgQTihx8/ksEleXz50fls3a0TzUSCpCCQQHTJSufuzxxLXWMzNz8yj4amcNAliaQsBYEEZmiPfH7+iaOZu66S7z75DuGwhqsWCUJ60AVIartgTB9Wba/mjqkrKcjO4AcXjAq6JJGUoyCQwN1y5jB21zVy/1trGdQ9l89OKA26JJGUok1DEjgz4/+dP4rTRpTwo2feZe46nWwm0pkUBBIX0kLGHVeOpXdRNrf8bQG7ahuCLkkkZXT6piEzGwH8rcWswcAP3P03nV2LxJeC7Ax+c/lYrpwyg+v/PIeHrjue7Iy0oMsSaZcNO2v53WurWLp1D5t3RQ6NDhlkpoeorGkg7PDJsn7cfPpQNlXWsWlXHZt31bGmvIZV26sJmTGiVz7nju5Fn6
IcBnXvEvOaLcgLi5tZGrAJON7d1x1oubKyMp8zZ07nFSaBenbRZr74yHzOObInv//UsaSnqeMq8a857Dw2ZwP/+9JyauqbGD+oG32LcgiFjOZmp6q+kV4FOVTWNvDUgk3s/9XbrUsm/bvlAvDOxl3sO4jurk8fy3lH9T6kmsxsrruXHWy5oHcWnwGs/qgQkNRz/tF9KK+q50fPvMv3n1rMbZcchZkFXZbIAbk733x8EU/M28jQHnk8cv3xjOxVcMDlrzt5EK8v307frjkMLclnQHEuhTkZ7z9ftbeRuesqWbW9mtNH9oh5/UEHwRXAowHXIHHompMGUVFdz+//tZqczDS+de5IbSaSuLR9z15ue2EZT87fxE2ThvC1s4YftBc7um8ho/sWHvD5/OwMJo3owaQRsQ8BCDAIzCwT+DjwnQM8fwNwA8CAAQM6sTKJF18/awR76pp44K33eGnxVj4zYSCTJ5TSJSvo3y+Sqtyd9TtrWbplD2+v3sGaihpmrd1JQ3OYr545nC+dPpRQKPF6r4HtIzCzC4Gb3f3sgy2rfQSp7d+rK7jj1ZXMXLuTXgXZfH7iYC4r669AkE6xYlsVT8zbyJPzNrG9qv79+V0y0xhckseo3gVce/IgRvTKD7DK1rV1H0GQQfBX4CV3f+BgyyoIBGDuukp+9vxS5q6rpCA7nWtOGsQNpw5WIEhMTF9Rzv+9soIFG3aRHjJOGFzM0B55lORnMap3AROGFMf95sq4DgIzywU2AIPdfffBllcQSEtz11Vyz/Q1vLhkK3lZ6dx46mAmDCmmS1Y6I3vla8eyHJbG5jA/fuZdHpqxjoHFuVx0TF+umjCQ4rysoEtrt7gOgvZSEEhr5q+v5A+vr+bld7e9P69rbgbHDezGeUf1YuLwkoT8zyvBCYedmx+ZxwuLt3LdyYO49ZwRcf+r/6MkyuGjIods7ICuTLmqjLUVNbxXUUN5dT2z1+7kzVUVvLp0GxlpxjlH9uLYAV3ZXddIafdcxvbvSmZ6iO55WdQ1NlO1t5FdtY1kpYcozsuiKCfjsHb2Vdc34e7srGmgR342OZmJ+yWSapqaw3z7H+/wwuKt3HrOCG4+bWjQJXUaBYEkvEHdu7x/9uUny/rT2BxmwYZdPP/OFv4xbxPPLtryodeY8aETeiAy1MWwHnmcMLiYkb3yCYWM9TtqaXbHoq9bW1FDeVU9aytq6FOUwzH9iyjKzeTlJVtZtrXq/ffKSg9RVtqVo/sVMXF4CUf1LdT+jDh2+wvLeHzuRm6aNISbJg0JupxOpU1DktQamsJU1zeRl5XOim1VLNq4m/qmZnbWNFCQnUFBTjpFuZnUN4XZUV1PeVU9izbuZubaHTQ2R/5vpIWMfX2EprBTWpxLj/xsBhTnsnlXHQs27KK2oZlxpV05eWgJze7075rD0i1VvLGynLUVNTSFncy0EBOGFHPmqJ6cc2RPeuRnB/eHkQ/49+oKPnXPTCZPGMiPLhwddDkdRvsIRA5DbUMTO6ojA9+V5Ge9v524Oeyk7bfpqKk5TE1D8wfODG2pam8jb66sYN76Sl55dxvv7ailS2YaP75wNBeP7ZuQx50nk3DY+fjv36SyppGpX5+Y0PsE9qcgEIlD7s6KbdV898l3mLuukiN6F3D3Z45lYHHsBxaT1j01fxO3/G0Bv758DBeP7Rd0OR2qrUGg0bxEOpFFR5b8+40TuOOKY9iyu45P/OFt5q2vDLq0lOTu3PnaSkb1LuDCMX2DLicwCgKRAKSFjAuP6ctjN06gS1Yan713JrPW6oI8nW32e5WsLq/h6pNKU3oTnYJAJEDDeubz2I0T6FmYzafvncETczcGXVJK+eus9eRnpXP+0Yc2zHOyUBCIBKxHQTZPfuEkxpV24+uPLeTXr6wgHI7/fXeJrr6pmZeWbOW/ju5NbmZqH9arIBCJA4W5GTx4zXguObYvd0xdydUPzqaxORx0WUlt9tpKahqaOWtUz6BLCZyCQCROZK
aH+NVlY/jxhUcyfUU5P31uadAlJbWpy7aRlR7ixCHdgy4lcKndHxKJM2bGVRNKWVtRwwNvvUe/rjl87pTBQZeVdNyd15Zt58QhxRoGBPUIROLS9/9rFB8b3YufPLeUfy7YFHQ5SWdNRQ3rdtRy+hHaLAQKApG4lBYyfn35MYwf1I1vPLaQt1ZVBF1SUpkdPVT3pCHFAVcSHxQEInEqOyONe64qY3D3PG58aC6LNx300h3SRvPX76IoN+P9wQpTnYJAJI4V5mTwp2vHU5CdzlX3z2J5i9FN5dDN31DJ2P5FuohRlIJAJM71Kszm4etPICPNmHz/LCprGoIuKaHt2dvIyu3VjB3QNehS4oaCQCQBDOrehfsmj2NHTT23Pr6IRBgsMl4t2rAbdxg7oCjoUuKGgkAkQYzuW8i3P3YEry7dxr1vrA26nIS1bOseAI7sUxhwJfFDQSCSQK49qZSPje7F7S8u087jQ7RiWxXd8zLp1iUz6FLihoJAJIGYGbdfcjTFXTK58aG5lFfVB11Swlm5vZphPfKDLiOuKAhEEkxhbgb3XFVGRXU91/1pNnv2NgZdUsJwd1Ztq2ZYz7ygS4krCgKRBDSmfxF3ffpY3t28h6vvn0V1fVPQJSWErXv2UlXfxLAeCoKWFAQiCeqMI3py55VjWbhxN5fd/Tabd9UFXVLcW7mtGohcB0L+Q0EgksA+dlRvpnz2ODburOUz981kd602E32UldujQaAewQcoCEQS3BlH9OS+q8exYWctX3x0HvVNzUGXFLdWbquiuEsmxXlZQZcSVxQEIklg/KBu/Ozio3hjZQUn3vYaL7yzJeiS4tLK7dUMVW/gQwIJAjMrMrPHzWyZmS01swlB1CGSTC4r68+D14yjOC+TLzw8j+v/PIdV2zU20T7uzioFQauCujDNHcCL7n6pmWUCuQHVIZJUJo3owUlDu3PvG2v53WsrOe+35VxzUilXjBuQ8iNtVtY2sruukcElCoL9dXqPwMwKgFOB+wDcvcHdd3V2HSLJKiMtxBcmDWHaN0/jzCN6cM/0NZzz6+nc+8Ya9jam7v6DtRWRHcWDUzwQWxPEpqHBQDnwgJnNN7N7zUyfjEgH656XxV2fPo63v3MG4wZ15SfPLeXc30xn+oryoEsLxJryGoCU7xm1JoggSAeOBf7g7mOBGuDb+y9kZjeY2Rwzm1Nenpr/cEU6Qs+CbB7+3Ak8dN34yDWR75/FdQ/OZsnm1BqraG1FDekho1/XnKBLiTtBBMFGYKO7z4w+fpxIMHyAu09x9zJ3LyspKenUAkWS0SnDSnjhK6fw9bOGM+u9nVxw55t87k+z+fPb71HXkPybjN7bUcOAbrmkp+lgyf11+l/E3bcCG8xsRHTWGcC7nV2HSCrKzkjjS2cM441vnsbHRvdm1tqd/OCfSzj7N9NYsCG5d9WtKa/RZqEDCCoavwQ8bGaLgGOAnwVUh0hKKsrN5PefPpaF/302j15/AuEwXDllBo/MXE84nHwXvQmHnfd2KAgOJJAgcPcF0c0+R7v7Re5eGUQdIqnOzJgwpJgnbz6RQd278N0n32H8z17l5kfmJdU5CFv37GVvY5hSBUGrtLFMROiRn81zXz6ZO68cy8ThPXhjRTnn3fEmj8xcH3RpHWJtReSIIR062rqgTigTkThjZlwwpg8XjOlDRfVIvvb3hXz3yXd4ZuFmbjlzGMcPLg66xEO2LwgGlSgIWqMgEJEP6Z6XxQNXj+OBt9Zy97TVXD5lBsf0LyI9ZPTtmsPl4/pz4pDuQZfZZmsrasjJSKNnfnbQpcQlbRoSkValhYzPnTKYqV+fxK3njCA9ZGSkhZi2opxP3TOTP05bnTDDXq+tqGFgcS6hkAVdSlxSj0BEPlJhTgY3nzaUm08bCsDexmY+dc8MbnthGf/70nKOG9iV00f24LzRvenfLQez+PuyXVtRwxG9dTGaA1EQiEi7ZGek8fcbJzB/wy5eX76dl5ds4/YXlnH7C8
vIz06noSnM8J75/OzioziqX2HQ5bKrtoH3dtRw0TF9gy4lbikIRKTd0tNCjCvtxrjSbtx6zkiWb63izVUVLNiwix75WTyzcDMX3fUWFx3Tl9F9CxjVu4Dxg7oF0lv49+oduMOJQxN3Z3esKQhE5LCN6JXPiF7/2fTy5TOGcdvzS/nr7A08MS8yb0z/Ir521nCOH9SNuoZmGpvDlORnxTQc3J3H5myguEsmx/Qvitl6Ep25x/9ZhGVlZT5nzpygyxCRdqqub2LzrjrmrqvkVy+voKK6/gPPlxbncs6RvRhc0oULj+lLZnQcoMPdqdvUHOb5xVu5c+pKVm6v5lvnjuQLk4Yc1nsmIjOb6+5lB11OQSAinaGmvonnFm2hvLqe3Mw0ahua+cVLyz+0XHrImDi8hDH9i+hVkE16mtE9L4vddY3UNjTRpyiHU4aV4O6s2FZNUW4GPQuyCYedUMjYs7eRrzw6n38tL2dAt1y+ePpQLj22X0oeMaQgEJG419QcZnddI0u3VPH84i2kWeQQ1acXbqKiuuGAr+vfLYdte+ppaAoDsO87fnBJHtt276WqvonvnjeSq08cRGZ66h4lryAQkYRWU9/E+p211DU2s3X3XnoWZJGfncGrS7cxY81O3J1zjuzFnr2NbKysIxx2lm2tYnjPPC49rj/jB3ULugmBa2sQaGexiMSlLlnpHNG74EPzh/fM56ZJnV9PMkvdPpOIiAAKAhGRlKcgEBFJcQoCEZEUpyAQEUlxCgIRkRSnIBARSXEKAhGRFJcQZxab2W5gZYtZhcDuAzxuOd0dqOiAEvZf36Esd6DnWpvf1vYF2daDLZts7T3Ycm1pV2vz2tJ2fbaHRp8tDHT3koO+2t3j/gZMaevj/abnxGL9h7LcgZ5rbX472hdYW1OtvQdbri3tOlhbD9Refbb6bDvys23tliibhp5px+P9n4vF+g9luQM919r8trYvyLYebNlka+/BlmtLu1qb15n/lvXZHtpyyfbZfkhCbBo6VGY2x9sw4FIySKW2Qmq1N5XaCqnV3nhpa6L0CA7VlKAL6ESp1FZIrfamUlshtdobF21N6h6BiIgcXLL3CERE5CAUBCIiKU5BICKS4lIyCMwsZGY/NbM7zWzqytVkAAAH3ElEQVRy0PXEmplNMrM3zOxuM5sUdD2xZmZdzGyumZ0fdC2xZmZHRD/Xx83sC0HXE2tmdpGZ3WNm/zSzs4OuJ5bMbLCZ3Wdmj8d6XQkXBGZ2v5ltN7PF+80/18yWm9kqM/v2Qd7mQqAv0AhsjFWtHaGD2utANZBNHLe3g9oK8C3g77GpsuN0RHvdfam7fx74JBD4YYgfpYPa+5S7Xw9cDVwew3IPSwe1dY27XxfbSqN1JdpRQ2Z2KpEvtT+7++jovDRgBXAWkS+62cCVQBpw235vcW30VunufzSzx9390s6qv706qL0V7h42s57A/7n7pzur/vbooLYeTeS0/Wwi7X62c6pvv45or7tvN7OPA98Gfufuj3RW/e3VUe2Nvu5XwMPuPq+Tym+XDm5rzL+jEu7i9e4+3cxK95s9Hljl7msAzOyvwIXufhvwoc0DZrYRaIg+bI5dtYevI9rbQiWQFYs6O0IHfbanAV2AUUCdmT3v7uGYFn6IOuqzdfengafN7DkgboOggz5fA24HXojXEIAO/38bcwkXBAfQF9jQ4vFG4PiPWP4fwJ1mdgowPZaFxUi72mtmlwDnAEXA72JbWodrV1vd/XsAZnY10Z5QTKvreO39bCcBlxAJ+OdjWllstPf/7peAM4FCMxvq7nfHsrgO1t7Pthj4KTDWzL4TDYyYSJYgsFbmHXCbl7vXAp2y7S1G2tvefxAJv0TUrra+v4D7gx1fSqdo72f7OvB6rIrpBO1t72+B38aunJhqb1t3AJ+PXTn/kXA7iw9gI9C/xeN+wOaAaukMqdTeVGorqL3J3N64bWuyBMFsYJiZDTKzTOAK4OmAa4qlVGpvKrUV1N5kbm
/ctjXhgsDMHgXeBkaY2UYzu87dm4AvAi8BS4G/u/uSIOvsKKnU3lRqK6i9ydzeRGtrwh0+KiIiHSvhegQiItKxFAQiIilOQSAikuIUBCIiKU5BICKS4hQEIiIpTkEgHc7MqjthHR9v45DUHbnOSWZ24iG8bqyZ3RudvtrM4mK8JzMr3X+Y5FaWKTGzFzurJgmGgkDiVnTY3la5+9PufnsM1vlR429NAtodBMB3gTsPqaCAuXs5sMXMTgq6FokdBYHElJndamazzWyRmf2oxfynLHIVsSVmdkOL+dVm9mMzmwlMMLP3zOxHZjbPzN4xs5HR5d7/ZW1mD5rZb83s32a2xswujc4Pmdld0XU8a2bP73tuvxpfN7Ofmdk04CtmdoGZzTSz+Wb2qpn1jA4p/Hngq2a2wMxOif5afiLavtmtfVmaWT5wtLsvbOW5gWY2Nfq3mWpmA6Lzh5jZjOh7/ri1HpZFrsL2nJktNLPFZnZ5dP646N9hoZnNMrP86C//N6J/w3mt9WrMLM3MftHis7qxxdNPAXF5DQvpIO6um24degOqo/dnA1OIjLoYAp4FTo0+1y16nwMsBoqjjx34ZIv3eg/4UnT6JuDe6PTVRC7EAvAg8Fh0HaOIjPkOcCmRoZlDQC8i12O4tJV6XwfuavG4K/856/5zwK+i0z8EvtFiuUeAk6PTA4Clrbz3acATLR63rPsZYHJ0+lrgqej0s8CV0enP7/t77ve+nwDuafG4EMgE1gDjovMKiIwwnAtkR+cNA+ZEp0uBxdHpG4DvR6ezgDnAoOjjvsA7Qf+70i12t2QZhlri09nR2/zo4zwiX0TTgS+b2cXR+f2j83cQuVDQE/u9z74htOcSGXu/NU955NoD71rkSmwAJwOPRedvNbN/fUStf2sx3Q/4m5n1JvLluvYArzkTGGX2/ujCBWaW7+5VLZbpDZQf4PUTWrTnIeB/W8y/KDr9CPDLVl77DvBLM/s58Ky7v2FmRwFb3H02gLvvgUjvAfidmR1D5O87vJX3Oxs4ukWPqZDIZ7IW2A70OUAbJAkoCCSWDLjN3f/4gZmRi6mcCUxw91oze53IpSUB9rr7/leNq4/eN3Pgf7P1LaZtv/u2qGkxfSeRS3o+Ha31hwd4TYhIG+o+4n3r+E/bDqbNA3+5+wozOw44D7jNzF4msgmntff4KrANGBOteW8ryxiRntdLrTyXTaQdkqS0j0Bi6SXgWjPLAzCzvmbWg8ivzcpoCIwETojR+t8EPhHdV9CTyM7etigENkWnJ7eYXwXkt3j8MpHRJAGI/uLe31Jg6AHW828iQxFDZBv8m9HpGUQ2/dDi+Q8wsz5Arbv/hUiP4VhgGdDHzMZFl8mP7vwuJNJTCAOfJXKN3P29BHzBzDKirx0e7UlApAfxkUcXSWJTEEjMuPvLRDZtvG1m7wCPE/kifRFIN7NFwP8Q+eKLhSeIXAxkMfBHYCawuw2v+yHwmJm9AVS0mP8McPG+ncXAl4Gy6M7Vd2nlalLuvozIZRXz938u+vpron+HzwJfic6/Bfiamc0ismmptZqPAmaZ2QLge8BP3L0BuJzIZVgXAq8Q+TV/FzDZzGYQ+VKvaeX97gXeBeZFDyn9I//pfZ0GPNfKayRJaBhqSWpmlufu1Ra5/uss4CR339rJNXwVqHL3e9u4fC5Q5+5uZlcQ2XF8YUyL/Oh6phO5yHplUDVIbGkfgSS7Z82siMhO3//p7BCI+gNwWTuWP47Izl0DdhE5oigQZlZCZH+JQiCJqUcgIpLitI9ARCTFKQhERFKcgkBEJMUpCEREUpyCQEQkxSkIRERS3P8HzXGAgykWkFMAAAAASUVORK5CYII=\n",
"text/plain": "<matplotlib.figure.Figure at 0x7f636b96bcf8>"
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "lr=2e-3",
"execution_count": 17,
"outputs": []
},
{
"metadata": {
"scrolled": false,
"trusted": true
},
"cell_type": "code",
"source": "learner.fit(lr, 1, wds=wd, use_clr=(32,5), cycle_len=5)",
"execution_count": 18,
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4572e0541c9040f49bd8760ef3ba2691",
"version_major": 2,
"version_minor": 0
},
"text/plain": "A Jupyter Widget"
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": " 0%| | 32/6872 [00:05<20:55, 5.45it/s, loss=10.4] \n 0%| | 33/6872 [00:05<20:39, 5.52it/s, loss=10.4]"
},
{
"name": "stderr",
"output_type": "stream",
"text": "Exception in thread Thread-4:\nTraceback (most recent call last):\n File \"/home/jhoward/anaconda3/lib/python3.6/threading.py\", line 916, in _bootstrap_inner\n self.run()\n File \"/home/jhoward/anaconda3/lib/python3.6/site-packages/tqdm/_tqdm.py\", line 144, in run\n for instance in self.tqdm_cls._instances:\n File \"/home/jhoward/anaconda3/lib/python3.6/_weakrefset.py\", line 60, in __iter__\n for itemref in self.data:\nRuntimeError: Set changed size during iteration\n\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "epoch trn_loss val_loss accuracy \n 0 4.780974 4.623935 0.24265 \n 1 4.581623 4.390652 0.264661 \n 2 4.489372 4.297412 0.274303 \n 3 4.413826 4.237745 0.280554 \n 4 4.360214 4.205443 0.28436 \n\n"
},
{
"data": {
"text/plain": "[4.2054429, 0.28435954039275879]"
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"scrolled": false,
"trusted": true
},
"cell_type": "code",
"source": "learner.fit(lr, 1, wds=wd, use_clr=(32,10), cycle_len=10)",
"execution_count": 20,
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb68e8d4a8a84b7fb9c1056bc532358c",
"version_major": 2,
"version_minor": 0
},
"text/plain": "A Jupyter Widget"
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": "epoch trn_loss val_loss accuracy \n 0 4.476262 4.28758 0.275657 \n 1 4.45527 4.267123 0.27728 \n 2 4.43806 4.23657 0.280728 \n 3 4.400014 4.207645 0.283538 \n 4 4.363546 4.181832 0.286277 \n 5 4.358071 4.155788 0.289054 \n 6 4.316609 4.137524 0.291174 \n 7 4.298223 4.118048 0.293109 \n 8 4.245867 4.106094 0.294576 \n 9 4.251366 4.095499 0.295898 \n\n"
},
{
"data": {
"text/plain": "[4.095499, 0.29589772825184546]"
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learner.save('lm')\nlearner.save_encoder('lm_enc')",
"execution_count": 21,
"outputs": []
},
{
"metadata": {
"scrolled": false,
"trusted": true
},
"cell_type": "code",
"source": "learner.fit(lr/2, 1, wds=wd, use_clr=(32,10), cycle_len=20)",
"execution_count": 22,
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f433bcd115694b30b70afadcb8ad6655",
"version_major": 2,
"version_minor": 0
},
"text/plain": "A Jupyter Widget"
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": "epoch trn_loss val_loss accuracy \n 0 4.29133 4.106632 0.294595 \n 1 4.302012 4.130576 0.291868 \n 2 4.290114 4.125361 0.292259 \n 3 4.299547 4.116803 0.293074 \n 4 4.279625 4.111819 0.293562 \n 5 4.27477 4.105453 0.294507 \n 6 4.270169 4.097149 0.295253 \n 7 4.24441 4.091094 0.295958 \n 8 4.266024 4.084123 0.296673 \n 9 4.229016 4.077878 0.297265 \n 10 4.234577 4.071088 0.297814 \n 11 4.201124 4.067428 0.298446 \n 12 4.211761 4.061707 0.299196 \n 13 4.206094 4.057941 0.299647 \n 14 4.192077 4.052402 0.30016 \n 15 4.177019 4.047873 0.300711 \n 16 4.182471 4.043671 0.301255 \n 17 4.157439 4.040637 0.301536 \n 18 4.181218 4.037622 0.301881 \n 19 4.167479 4.036257 0.302133 \n\n"
},
{
"data": {
"text/plain": "[4.0362568, 0.30213308110538678]"
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Classifier tokens"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "df_trn = pd.read_csv(CLAS_PATH/'train.csv', header=None, chunksize=chunksize)\ndf_val = pd.read_csv(CLAS_PATH/'test.csv', header=None, chunksize=chunksize)",
"execution_count": 9,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "tok_trn, trn_labels = get_all(df_trn, 1)\ntok_val, val_labels = get_all(df_val, 1)",
"execution_count": 10,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "0\n1\n0\n1\n"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "(CLAS_PATH/'tmp').mkdir(exist_ok=True)\n\nnp.save(CLAS_PATH/'tmp'/'tok_trn.npy', tok_trn)\nnp.save(CLAS_PATH/'tmp'/'tok_val.npy', tok_val)\n\nnp.save(CLAS_PATH/'tmp'/'trn_labels.npy', trn_labels)\nnp.save(CLAS_PATH/'tmp'/'val_labels.npy', val_labels)",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Token lists are variable-length, so they are stored as object arrays;\n# numpy >= 1.16.3 refuses to load object arrays unless allow_pickle=True.\ntok_trn = np.load(CLAS_PATH/'tmp'/'tok_trn.npy', allow_pickle=True)\ntok_val = np.load(CLAS_PATH/'tmp'/'tok_val.npy', allow_pickle=True)\n\ntrn_labels = np.load(CLAS_PATH/'tmp'/'trn_labels.npy')\nval_labels = np.load(CLAS_PATH/'tmp'/'val_labels.npy')",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"scrolled": true,
"trusted": true
},
"cell_type": "code",
"source": "freq = Counter(p for o in tok_trn for p in o)\nfreq.most_common(25)",
"execution_count": 13,
"outputs": [
{
"data": {
"text/plain": "[('the', 335844),\n ('.', 277583),\n (',', 275297),\n ('and', 163775),\n ('a', 162489),\n ('of', 145813),\n ('to', 135629),\n ('is', 110387),\n ('it', 95826),\n ('in', 93847),\n ('i', 86730),\n ('this', 75735),\n ('that', 73495),\n ('\"', 65053),\n (\"'s\", 62103),\n ('-', 52852),\n ('was', 50493),\n ('\\n\\n', 49832),\n ('as', 46849),\n ('for', 44290),\n ('with', 44076),\n ('movie', 43840),\n ('but', 42441),\n ('film', 40027),\n (')', 34632)]"
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "itos = pickle.load((LM_PATH/'tmp'/'itos.pkl').open('rb'))\nstoi = collections.defaultdict(lambda:0, {v:k for k,v in enumerate(itos)})\nlen(itos)",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "60002"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Numericalize with the LM vocab; tokens missing from itos map to 0 via the stoi defaultdict.\n# The id lists are variable-length (one per review), so dtype=object keeps them as a 1-D\n# array of lists (newer numpy raises on ragged input without an explicit object dtype).\ntrn_clas = np.array([[stoi[o] for o in p] for p in tok_trn], dtype=object)\nval_clas = np.array([[stoi[o] for o in p] for p in tok_val], dtype=object)",
"execution_count": 21,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Save classifier ids under CLAS_PATH, next to the classifier labels. LM_PATH/'tmp' holds the\n# language model's files (itos.pkl is read from there), so reusing it risks overwriting LM data.\nnp.save(CLAS_PATH/'tmp'/'trn_ids.npy', trn_clas)\nnp.save(CLAS_PATH/'tmp'/'val_ids.npy', val_clas)",
"execution_count": 23,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Classifier"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Object arrays need allow_pickle=True to load on numpy >= 1.16.3.\ntrn_clas = np.load(CLAS_PATH/'tmp'/'trn_ids.npy', allow_pickle=True)\nval_clas = np.load(CLAS_PATH/'tmp'/'val_ids.npy', allow_pickle=True)",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "bptt,em_sz,nh,nl = 70,400,1150,3\nvs = len(itos)\nopt_fn = partial(optim.Adam, betas=(0.8, 0.99))\nbs = 48",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "trn_labels = np.squeeze(np.load(CLAS_PATH/'tmp'/'trn_labels.npy'))\nval_labels = np.squeeze(np.load(CLAS_PATH/'tmp'/'val_labels.npy'))",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "trn_labels -= trn_labels.min()\nval_labels -= val_labels.min()\nc=int(trn_labels.max())+1",
"execution_count": 9,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "trn_ds = TextDataset(trn_clas, trn_labels)\nval_ds = TextDataset(val_clas, val_labels)\ntrn_samp = SortishSampler(trn_clas, key=lambda x: len(trn_clas[x]), bs=bs//2)\nval_samp = SortSampler(val_clas, key=lambda x: len(val_clas[x]))\ntrn_dl = DataLoader(trn_ds, bs//2, transpose=True, num_workers=1, pad_idx=1, sampler=trn_samp)\nval_dl = DataLoader(val_ds, bs, transpose=True, num_workers=1, pad_idx=1, sampler=val_samp)\nmd = ModelData(PATH, trn_dl, val_dl)",
"execution_count": 12,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "dps = np.array([0.4,0.5,0.05,0.3,0.4])*0.5",
"execution_count": 13,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "m = get_rnn_classifer(bptt, 20*70, c, vs, emb_sz=em_sz, n_hid=nh, n_layers=nl, pad_token=1,\n layers=[em_sz*3, 50, c], drops=[dps[4], 0.1],\n dropouti=dps[0], wdrop=dps[1], dropoute=dps[2], dropouth=dps[3])",
"execution_count": 14,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = RNN_Learner(md, TextModel(to_gpu(m)), opt_fn=opt_fn)\nlearn.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)\nlearn.clip=25.\nlearn.metrics = [accuracy]",
"execution_count": 15,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "lr=3e-3\nlrm = 2.6\nlrs = np.array([lr/(lrm**4), lr/(lrm**3), lr/(lrm**2), lr/lrm, lr])",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "wd = 1e-6\nlearn.load_encoder('lm_enc')",
"execution_count": 16,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.freeze_to(-1)",
"execution_count": 17,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.lr_find(lrs/1000)\nlearn.sched.plot()",
"execution_count": 60,
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ef6d96007c3b40808b64198c2ac726cb",
"version_major": 2,
"version_minor": 0
},
"text/plain": "A Jupyter Widget"
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": " 80%|███████▉ | 625/782 [01:39<00:24, 6.31it/s, loss=1.28] "
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.fit(lrs, 1, wds=wd, cycle_len=1, use_clr=(8,3))",
"execution_count": 63,
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4f97330902094518adf87102cd20430b",
"version_major": 2,
"version_minor": 0
},
"text/plain": "A Jupyter Widget"
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": " 2%|▏ | 12/782 [00:02<02:13, 5.75it/s, loss=0.599]\n 2%|▏ | 14/782 [00:02<02:06, 6.09it/s, loss=0.563]"
},
{
"name": "stderr",
"output_type": "stream",
"text": "Exception in thread Thread-11:\nTraceback (most recent call last):\n File \"/home/jhoward/anaconda3/lib/python3.6/threading.py\", line 916, in _bootstrap_inner\n self.run()\n File \"/home/jhoward/anaconda3/lib/python3.6/site-packages/tqdm/_tqdm.py\", line 144, in run\n for instance in self.tqdm_cls._instances:\n File \"/home/jhoward/anaconda3/lib/python3.6/_weakrefset.py\", line 60, in __iter__\n for itemref in self.data:\nRuntimeError: Set changed size during iteration\n\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "epoch trn_loss val_loss accuracy \n 0 0.267757 0.203153 0.924017 \n\n"
},
{
"data": {
"text/plain": "[0.20315324, 0.92401694367303877]"
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.save('clas_0')",
"execution_count": 64,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.load('clas_0')",
"execution_count": 18,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.unfreeze()",
"execution_count": 19,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "lr=2e-3\nlrs = np.array([lr/(lrm**4), lr/(lrm**3), lr/(lrm**2), lr/lrm, lr])",
"execution_count": 21,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.fit(lrs, 1, wds=wd, cycle_len=20, use_clr=(32,10))",
"execution_count": 22,
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f39b72fd0e81450dbfed7f4e8ff384c6",
"version_major": 2,
"version_minor": 0
},
"text/plain": "A Jupyter Widget"
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": " 0%| | 0/1042 [00:00<?, ?it/s] \n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "Exception in thread Thread-4:\nTraceback (most recent call last):\n File \"/home/jhoward/anaconda3/lib/python3.6/threading.py\", line 916, in _bootstrap_inner\n self.run()\n File \"/home/jhoward/anaconda3/lib/python3.6/site-packages/tqdm/_tqdm.py\", line 144, in run\n for instance in self.tqdm_cls._instances:\n File \"/home/jhoward/anaconda3/lib/python3.6/_weakrefset.py\", line 60, in __iter__\n for itemref in self.data:\nRuntimeError: Set changed size during iteration\n\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "epoch trn_loss val_loss accuracy \n 0 0.262144 0.194416 0.927967 \n 1 0.216391 0.178578 0.933277 \n 2 0.207458 0.168077 0.938316 \n 3 0.174846 0.169113 0.939435 \n 4 0.149987 0.1651 0.942075 \n 5 0.10919 0.189178 0.937436 \n 6 0.10541 0.194798 0.939507 \n 7 0.090148 0.19878 0.938644 \n 8 0.054718 0.218309 0.939859 \n 9 0.053275 0.257329 0.931718 \n 10 0.050572 0.25061 0.935765 \n 11 0.045674 0.241273 0.941067 \n 12 0.031045 0.289027 0.934925 \n 13 0.036667 0.281847 0.938508 \n 14 0.023639 0.273566 0.939955 \n 15 0.025651 0.278519 0.937508 \n 16 0.02167 0.274886 0.940147 \n 17 0.01852 0.293825 0.940227 \n 18 0.012914 0.291758 0.939467 \n 19 0.024635 0.295659 0.938716 \n\n"
},
{
"data": {
"text/plain": "[0.29565862, 0.93871561041720308]"
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.save('clas_1')",
"execution_count": 23,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.6.4",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"toc": {
"threshold": 4,
"number_sections": true,
"toc_cell": false,
"toc_window_display": false,
"toc_section_display": "block",
"sideBar": true,
"navigate_menu": true,
"moveMenuLeft": true,
"widenNotebook": false,
"colors": {
"hover_highlight": "#DAA520",
"selected_highlight": "#FFD700",
"running_highlight": "#FF0000",
"wrapper_background": "#FFFFFF",
"sidebar_border": "#EEEEEE",
"navigate_text": "#333333",
"navigate_num": "#000000"
},
"nav_menu": {
"height": "86px",
"width": "252px"
}
},
"gist": {
"id": "",
"data": {
"description": "fastai.text imdb example",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment