@ohmeow
Created March 8, 2018 20:28
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai.text import *\n",
"import pdb\n",
"\n",
"import spacy\n",
"spacy_en = spacy.load('en')\n",
"\n",
"# pandas and plotting config\n",
"pd.set_option('display.max_rows', 1000)\n",
"pd.set_option('display.max_columns', 1000)\n",
"pd.set_option('display.max_colwidth', -1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"imdbEr.txt imdb.vocab \u001b[0m\u001b[01;34mmodels\u001b[0m/ README \u001b[01;34mtest\u001b[0m/ \u001b[01;34mtmp\u001b[0m/ \u001b[01;34mtrain\u001b[0m/\r\n"
]
}
],
"source": [
"PATH = 'data/aclImdb'\n",
"TRN_PATH = f'{PATH}/train'\n",
"VAL_PATH = f'{PATH}/test'\n",
"\n",
"os.makedirs(f'{PATH}/models', exist_ok=True)\n",
"os.makedirs(f'{PATH}/tmp', exist_ok=True)\n",
"\n",
"%ls {PATH}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What is in the training folder?"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['0_0.txt',\n",
" '0_3.txt',\n",
" '0_9.txt',\n",
" '10000_0.txt',\n",
" '10000_4.txt',\n",
" '10000_8.txt',\n",
" '1000_0.txt',\n",
" '10001_0.txt',\n",
" '10001_10.txt',\n",
" '10001_4.txt']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trn_files = !ls {TRN_PATH}/all\n",
"trn_files[:10]"
]
},
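{
"cell_type": "markdown",
"metadata": {},
"source": [
"The listing shows names such as `0_0.txt`, `0_3.txt` and `0_9.txt`. The aclImdb README describes review files as `[id]_[rating].txt`, and `all/` appears to combine the `neg`, `pos` and `unsup` folders, so the same id can show up once per source folder. Below is a minimal sketch for decoding these names. It assumes that convention, and it also assumes (not verified here) that unlabeled reviews carry rating `0`. `parse_review_name` and `label_from_rating` are hypothetical helpers, not part of fastai."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def parse_review_name(fname):\n",
"    # hypothetical helper: '10001_10.txt' -> (10001, 10)\n",
"    stem = fname[:-4] if fname.endswith('.txt') else fname\n",
"    review_id, rating = stem.split('_')\n",
"    return int(review_id), int(rating)\n",
"\n",
"def label_from_rating(rating):\n",
"    # README convention: neg <= 4, pos >= 7; rating 0 is assumed to mark unsup reviews\n",
"    if rating == 0:\n",
"        return 'unsup'\n",
"    return 'pos' if rating >= 7 else 'neg'\n",
"\n",
"names = ['0_0.txt', '0_3.txt', '0_9.txt', '10001_10.txt']  # taken from the listing above\n",
"parsed = [parse_review_name(n) for n in names]\n",
"print(parsed)                                      # [(0, 0), (0, 3), (0, 9), (10001, 10)]\n",
"print([label_from_rating(r) for _, r in parsed])   # ['unsup', 'neg', 'pos', 'pos']"
]
},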
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I admit, the great majority of films released before say 1933 are just not for me. Of the dozen or so \"major\" silents I have viewed, one I loved (The Crowd), and two were very good (The Last Command and City Lights, that latter Chaplin circa 1931).<br /><br />So I was apprehensive about this one, and humor is often difficult to appreciate (uh, enjoy) decades later. I did like the lead actors, but thought little of the film.<br /><br />One intriguing sequence. Early on, the guys are supposed to get \"de-loused\" and for about three minutes, fully dressed, do some schtick. In the background, perhaps three dozen men pass by, all naked, white and black (WWI ?), and for most, their butts, part or full backside, are shown. Was this an early variation of beefcake courtesy of Howard Hughes?\n"
]
}
],
"source": [
"for line in open(f'{TRN_PATH}/all/0_0.txt', encoding='utf-8'):\n",
" print(line)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What does a review look like?"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"I have to say when a name like Zombiegeddon and an atom bomb on the front cover I was expecting a flat out chop-socky fung-ku, but what I got instead was a comedy. So, it wasn't quite was I was expecting, but I really liked it anyway! The best scene ever was the main cop dude pulling those kids over and pulling a Bad Lieutenant on them!! I was laughing my ass off. I mean, the cops were just so bad! And when I say bad, I mean The Shield Vic Macky bad. But unlike that show I was laughing when they shot people and smoked dope.<br /><br />Felissa Rose...man, oh man. What can you say about that hottie. She was great and put those other actresses to shame. She should work more often!!!!! I also really liked the fight scene outside of the building. That was done really well. Lots of fighting and people getting their heads banged up. FUN! Last, but not least Joe Estevez and William Smith were great as the...well, I wasn't sure what they were, but they seemed to be having fun and throwing out lines. I mean, some of it didn't make sense with the rest of the flick, but who cares when you're laughing so hard! All in all the film wasn't the greatest thing since sliced bread, but I wasn't expecting that. It was a Troma flick so I figured it would totally suck. It's nice when something surprises you but not totally sucking.<br /><br />Rent it if you want to get stoned on a Friday night and laugh with your buddies. Don't rent it if you are an uptight weenie or want a zombie movie with lots of flesh eating.<br /><br />P.S. Uwe Boil was a nice touch.\""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"review = !cat {TRN_PATH}/all/{trn_files[6]}\n",
"review[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How many words are in the training and validation datasets?"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"17486581\r\n"
]
}
],
"source": [
"!find {TRN_PATH}/all -name '*.txt' | xargs cat | wc -w"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5686719\r\n"
]
}
],
"source": [
"!find {VAL_PATH}/all -name '*.txt' | xargs cat | wc -w"
]
},
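{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shell pipelines above count whitespace-separated words. Below is a pure-Python sketch of the same count for environments without `find` and `xargs`. `count_words` is a hypothetical helper, and the demo writes two made-up reviews to a temporary folder; on the real data it would be called as `count_words(f'{TRN_PATH}/all')`. Counting per file can give a slightly higher total than `xargs cat | wc -w`. This is because `cat` adds no separator between files, so when a file lacks a trailing newline its last word merges with the next file's first word."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"def count_words(folder, suffix='.txt'):\n",
"    # count whitespace-separated words in every matching file under folder\n",
"    total = 0\n",
"    for root, _, files in os.walk(folder):\n",
"        for name in files:\n",
"            if name.endswith(suffix):\n",
"                with open(os.path.join(root, name), encoding='utf-8') as f:\n",
"                    total += sum(len(line.split()) for line in f)\n",
"    return total\n",
"\n",
"with tempfile.TemporaryDirectory() as d:\n",
"    for i, text in enumerate(['a great movie', 'not good at all']):\n",
"        with open(os.path.join(d, f'{i}_0.txt'), 'w', encoding='utf-8') as f:\n",
"            f.write(text)\n",
"    print(count_words(d))  # 7 (3 + 4 words)"
]
},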
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tokenize using the new `fastai.text` package."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"I have to say when a name like Zombiegeddon and an atom bomb on the front cover I was expecting a flat out chop - socky fung - ku , but what I got instead was a comedy . So , it was n't quite was I was expecting , but I really liked it anyway ! The best scene ever was the main cop dude pulling those kids over and pulling a Bad Lieutenant on them ! ! I was laughing my ass off . I mean , the cops were just so bad ! And when I say bad , I mean The Shield Vic Macky bad . But unlike that show I was laughing when they shot people and smoked dope . \\n\\n Felissa Rose ... man , oh man . What can you say about that hottie . She was great and put those other actresses to shame . She should work more often ! ! ! ! ! I also really liked the fight scene outside of the building . That was done really well . Lots of fighting and people getting their heads banged up . FUN ! Last , but not least Joe Estevez and William Smith were great as the ... well , I was n't sure what they were , but they seemed to be having fun and throwing out lines . I mean , some of it did n't make sense with the rest of the flick , but who cares when you 're laughing so hard ! All in all the film was n't the greatest thing since sliced bread , but I was n't expecting that . It was a Troma flick so I figured it would totally suck . It 's nice when something surprises you but not totally sucking . \\n\\n Rent it if you want to get stoned on a Friday night and laugh with your buddies . Do n't rent it if you are an uptight weenie or want a zombie movie with lots of flesh eating . \\n\\n P.S. Uwe Boil was a nice touch .\""
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"' '.join(spacy_tok(review[0]))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(25000, 25000)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trn_docs, trn_labels = texts_labels_from_folders(TRN_PATH, ['neg', 'pos'])\n",
"val_docs, val_labels = texts_labels_from_folders(VAL_PATH, ['neg', 'pos'])\n",
"\n",
"len(trn_docs), len(val_docs)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 59 s, sys: 196 ms, total: 59.2 s\n",
"Wall time: 59.2 s\n",
"CPU times: user 54 s, sys: 200 ms, total: 54.2 s\n",
"Wall time: 54.2 s\n"
]
}
],
"source": [
"f_tok = Tokenizer()\n",
"%time trn_docs_pp = f_tok.proc_all(trn_docs)\n",
"%time val_docs_pp = f_tok.proc_all(val_docs)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"class Vocab:\n",
"    def __init__(self, tokens, min_freq=1, max_size=None,\n",
"                 specials=('<unk>', '<pad>', '<bos>', '<eos>'), unk_idx=0):\n",
"        self.min_freq = max(min_freq, 1)\n",
"        self.specials = list(specials)\n",
"        self.unk_idx = unk_idx\n",
"\n",
"        self.tokens = list(specials)\n",
"        # max_size limits the number of non-special tokens\n",
"        self.max_size = None if max_size is None else max_size + len(self.tokens)\n",
"\n",
"        # token -> count, excluding the special symbols (Counter ignores missing keys on del)\n",
"        self.token_freqs = Counter(tokens)\n",
"        for t in self.specials: del self.token_freqs[t]\n",
"\n",
"        # sort by frequency (descending), breaking ties alphabetically\n",
"        ranked = sorted(self.token_freqs.items(), key=lambda tup: (-tup[1], tup[0]))\n",
"\n",
"        for token, freq in ranked:\n",
"            if freq < self.min_freq or len(self.tokens) == self.max_size:\n",
"                break\n",
"            self.tokens.append(token)\n",
"\n",
"        # plain dict: unknown tokens are handled by stoi() via .get(), and unlike a\n",
"        # defaultdict holding a local lambda, a dict can be pickled\n",
"        self.vocab_stoi = {tok: i for i, tok in enumerate(self.tokens)}\n",
"\n",
"    def stoi(self, token):\n",
"        return self.vocab_stoi.get(token, self.unk_idx)\n",
"\n",
"    def itos(self, idx):\n",
"        return self.tokens[idx]\n",
"\n",
"    def token_freq(self, token):\n",
"        return self.token_freqs.get(token, 0)"
]
},
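{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Vocab.__init__` orders tokens by descending frequency and breaks ties alphabetically. That is the same as a single sort on the composite key `(-count, token)`. Below is a standalone sketch of the same vocabulary rule on a toy token list. It assumes, as in the class above, that `max_size` counts only the non-special tokens; `build_itos` is a hypothetical name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"def build_itos(tokens, min_freq=1, max_size=None,\n",
"               specials=('<unk>', '<pad>', '<bos>', '<eos>')):\n",
"    # count tokens, ignoring the special symbols\n",
"    counts = Counter(t for t in tokens if t not in specials)\n",
"    # highest frequency first; ties broken alphabetically\n",
"    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))\n",
"    kept = [t for t, c in ranked if c >= min_freq]\n",
"    if max_size is not None:\n",
"        kept = kept[:max_size]\n",
"    return list(specials) + kept\n",
"\n",
"itos = build_itos('the cat sat on the mat the end'.split(), max_size=3)\n",
"stoi = {t: i for i, t in enumerate(itos)}\n",
"print(itos)                # ['<unk>', '<pad>', '<bos>', '<eos>', 'the', 'cat', 'end']\n",
"print(stoi.get('dog', 0))  # 0: an unseen token maps to <unk>"
]
},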
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"class LanguageDataset(torch.utils.data.Dataset):\n",
"    def __init__(self, docs, newline_bos=True, newline_eos=True, vocab=None, min_freq=1, max_size=None):\n",
"        # concatenate all tokenized docs into one stream, optionally marking doc boundaries\n",
"        self.tokens = []\n",
"        for d in docs:\n",
"            if newline_bos: self.tokens.append('<bos>')\n",
"            self.tokens += d\n",
"            if newline_eos: self.tokens.append('<eos>')\n",
"\n",
"        # reuse an existing vocab (e.g. the training one) or build a new one\n",
"        self.vocab = vocab if vocab is not None else Vocab(self.tokens, min_freq, max_size)\n",
"\n",
"        # a single row holding the whole numericalized stream, so len(self) == 1\n",
"        self.data = np.array([[self.vocab.stoi(t) for t in self.tokens]])\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.data[idx]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 9.47 s, sys: 124 ms, total: 9.59 s\n",
"Wall time: 9.59 s\n",
"CPU times: user 12 ms, sys: 0 ns, total: 12 ms\n",
"Wall time: 12.5 ms\n"
]
}
],
"source": [
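"# NOTE: val_docs_pp is also part of the training stream, so val_ds below overlaps the\n",
"# training data and its loss will be optimistic\n",
"%time trn_ds = LanguageDataset(trn_docs_pp + val_docs_pp, newline_bos=False, min_freq=10)\n",
"%time val_ds = LanguageDataset(val_docs_pp[:100], newline_bos=False, vocab=trn_ds.vocab)"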
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"14101165 14101165 29563 1\n",
"26210 26210 29563 1\n"
]
}
],
"source": [
"print(len(trn_ds[0]), len(trn_ds.tokens), len(trn_ds.vocab.tokens), len(trn_ds))\n",
"print(len(val_ds[0]), len(val_ds.tokens), len(val_ds.vocab.tokens), len(val_ds))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 16 193 33 407 45 4 3469 59 17 28]\n",
"['i', 'ca', \"n't\", 'understand', 'all', 'the', 'hype', 'about', 'this', 'movie']\n"
]
}
],
"source": [
"print(trn_ds[0][:10])\n",
"print([ trn_ds.vocab.itos(idx) for idx in trn_ds[0][:10] ])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"bsz = 64\n",
"bptt = 70"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"trn_dl = LanguageModelLoader(trn_ds[0], bsz, bptt)\n",
"val_dl = LanguageModelLoader(val_ds[0], bsz, bptt)"
]
},
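{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, `LanguageModelLoader` treats the whole token stream as `bsz` parallel columns and walks down them roughly `bptt` tokens at a time. The target for each input token is the token that follows it. The numpy sketch below illustrates that idea on a toy stream. `batchify` and `get_batch` are hypothetical helpers, and this is not fastai's code. If fastai's loader follows this scheme, the 14,101,165-token stream gives 14101165 // 64 = 220,330 rows, and 220330 // 70 - 1 = 3146, which is consistent with `len(md.trn_dl)` below. fastai also appears to randomize the sequence length around `bptt`. That would explain the 68-row input batch and its 4352 = 68 * 64 flattened targets shown further down."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def batchify(ids, bs):\n",
"    # drop the tail that does not fill all bs columns, then lay the stream out\n",
"    # as bs parallel columns: result has shape (n_rows, bs)\n",
"    n_rows = len(ids) // bs\n",
"    return np.asarray(ids[:n_rows * bs]).reshape(bs, n_rows).T\n",
"\n",
"def get_batch(source, i, seq_len):\n",
"    # inputs: rows i .. i+seq_len-1; targets: the same rows shifted down by one\n",
"    # token and flattened, as the loss function expects\n",
"    seq_len = min(seq_len, len(source) - 1 - i)\n",
"    x = source[i:i + seq_len]\n",
"    y = source[i + 1:i + 1 + seq_len].reshape(-1)\n",
"    return x, y\n",
"\n",
"stream = np.arange(22)         # stand-in for trn_ds[0]; the last 2 ids get dropped\n",
"src = batchify(stream, bs=4)   # column j holds stream[5*j : 5*j + 5]\n",
"x, y = get_batch(src, 0, seq_len=3)\n",
"print(src.shape, x.shape, y.shape)     # (5, 4) (3, 4) (12,)\n",
"print(x[:, 0], y.reshape(3, 4)[:, 0])  # [0 1 2] [1 2 3]"
]
},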
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# next(iter(trn_dl))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"md = LanguageModelData(PATH, 1, len(trn_ds.vocab.tokens), trn_dl, val_dl, bptt=bptt, min_freq=10)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3146, 29563, 1, 14101165)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(md.trn_dl), md.nt, len(trn_ds), len(trn_ds.tokens)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([68, 64])\n",
"torch.Size([4352])\n"
]
},
{
"data": {
"text/plain": [
"(None, None)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch = next(iter(md.trn_dl))\n",
"print(batch[0].size()), print(batch[1].size())\n",
"\n",
"# batch"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"em_sz = 200 # size of each embedding vector\n",
"nh = 500 # number of hidden activations per layer\n",
"nl = 3 # number of layers"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"opt_fn = partial(optim.Adam, betas=(0.7, 0.99))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"learner = md.get_model(opt_fn, em_sz, nh, nl,\n",
" dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05)\n",
"learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)\n",
"learner.clip=0.3"
]
},
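{
"cell_type": "markdown",
"metadata": {},
"source": [
"`seq2seq_reg` with `alpha` and `beta` adds the activation regularization (AR) and temporal activation regularization (TAR) penalties described for AWD-LSTM (Merity et al., *Regularizing and Optimizing LSTM Language Models*). Below is a rough numpy sketch of those two terms. It assumes `h` and `h_dropped` are the final layer's outputs without and with dropout. fastai's exact implementation may differ in details, and `ar_tar_penalty` is a hypothetical name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def ar_tar_penalty(h, h_dropped, alpha=2.0, beta=1.0):\n",
"    # h, h_dropped: final-layer outputs of shape (seq_len, batch, hidden),\n",
"    # without and with dropout applied\n",
"    ar = alpha * np.mean(h_dropped ** 2)  # AR: penalize large activations\n",
"    # TAR: penalize big changes between consecutive time steps\n",
"    tar = beta * np.mean((h[1:] - h[:-1]) ** 2) if len(h) > 1 else 0.0\n",
"    return ar + tar\n",
"\n",
"# the penalty is added to the cross-entropy loss during training\n",
"h = np.ones((3, 2, 4))\n",
"print(ar_tar_penalty(h, h))  # 2.0: AR = 2 * mean(1), TAR = 0 because h is constant over time"
]
},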
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"# batch_iter = iter(md.val_dl)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"# b = next(batch_iter) # x, y\n",
"# p = learner.model(V(b[0])) # predictions\n",
"\n",
"# b[0].shape, b[1].shape, p[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss \n",
" 0 4.697988 4.57248 \n",
" 1 4.500699 4.348201 \n",
" 2 4.371368 4.243143 \n",
" 3 4.428565 4.301117 \n",
" 4 4.335875 4.188483 \n",
" 5 4.251823 4.102633 \n",
" 6 4.216845 4.063244 \n",
" 7 4.34847 4.217101 \n",
" 8 4.313081 4.169003 \n",
" 9 4.266176 4.126935 \n",
" 10 4.240518 4.092284 \n",
" 11 4.22108 4.034062 \n",
" 12 4.149041 3.991385 \n",
" 13 4.129282 3.971003 \n",
" 14 4.112093 3.967932 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[3.9679315]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learner.fit(3e-3, 4, wds=1e-6, cycle_len=1, cycle_mult=2)"
]
},
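{
"cell_type": "markdown",
"metadata": {},
"source": [
"With 4 cycles (the second argument to `fit`), `cycle_len=1` and `cycle_mult=2`, the cycles last 1, 2, 4 and 8 epochs. That is 15 in total, matching epochs 0 to 14 in the log above. The training loss jumps at epochs 3 and 7, where new cycles start and the learning rate goes back up to its maximum. Below is a sketch of that schedule, assuming plain cosine annealing within each cycle. fastai's own version updates the rate every mini-batch and may differ in details such as warm-up. `cycle_lengths` and `cosine_lr` are hypothetical helpers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def cycle_lengths(n_cycle, cycle_len, cycle_mult):\n",
"    # length in epochs of each annealing cycle\n",
"    return [cycle_len * cycle_mult ** i for i in range(n_cycle)]\n",
"\n",
"def cosine_lr(max_lr, it, its_per_cycle):\n",
"    # cosine annealing from max_lr (it = 0) down towards 0 (it = its_per_cycle)\n",
"    return max_lr / 2 * (1 + math.cos(math.pi * it / its_per_cycle))\n",
"\n",
"lengths = cycle_lengths(4, 1, 2)\n",
"print(lengths, sum(lengths))  # [1, 2, 4, 8] 15\n",
"# epochs at which a new cycle (and a learning-rate restart) begins\n",
"starts = [sum(lengths[:i]) for i in range(len(lengths))]\n",
"print(starts)                 # [0, 1, 3, 7]\n",
"# learning rate at the start, middle and end of one cycle\n",
"print([cosine_lr(3e-3, it, 100) for it in (0, 50, 100)])"
]
},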
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"learner.save_encoder('imdb_adam1_enc_full')\n",
"# learner.load_encoder('imdb_adam1_enc')"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss \n",
" 0 4.269407 4.146606 \n",
" 1 4.274711 4.125793 \n",
" 2 4.248354 4.102614 \n",
" 3 4.216576 4.070951 \n",
" 4 4.189146 4.038531 \n",
" 5 4.177811 4.007429 \n",
" 6 4.126941 3.989224 \n",
" 7 4.088423 3.938474 \n",
" 8 4.064484 3.912925 \n",
" 9 4.073764 3.916723 \n",
" 10 4.273544 4.104843 \n",
" 11 4.233156 4.096601 \n",
" 12 4.218732 4.076135 \n",
" 13 4.183688 4.045218 \n",
" 14 4.170473 4.009149 \n",
" 15 4.138231 3.98437 \n",
" 16 4.115118 3.936274 \n",
" 17 4.087664 3.909125 \n",
" 18 4.050396 3.890166 \n",
" 19 4.037233 3.894554 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[3.8945541]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# learner.fit(3e-3, 4, wds=1e-6, cycle_len=10, cycle_save_name='imdb_adam2_4_10')\n",
"learner.fit(3e-3, 2, wds=1e-6, cycle_len=10, cycle_save_name='imdb_adam2_c2_cl10_full')"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"\n",
"learner.save_encoder('imdb_adam2_enc_full')\n",
"# learner.load_encoder('imdb_adam2_enc')"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss \n",
" 0 4.066651 3.888733 \n",
" 1 4.071485 3.899539 \n",
" 2 4.091433 3.904781 \n",
" 3 4.057991 3.891943 \n",
" 4 4.048481 3.880265 \n",
" 5 4.042373 3.879733 \n",
" 6 4.057358 3.877355 \n",
" 7 4.06226 3.876944 \n",
" 8 4.044253 3.882051 \n",
" 9 4.033689 3.868109 \n",
" 10 4.057388 3.864538 \n",
" 11 4.022268 3.86107 \n",
" 12 4.013979 3.862605 \n",
" 13 4.011391 3.857065 \n",
" 14 4.046247 3.853321 \n",
" 15 4.010649 3.847025 \n",
" 16 4.044526 3.84362 \n",
" 17 4.012577 3.855527 \n",
" 18 4.038569 3.851449 \n",
" 19 4.011175 3.884566 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[3.8845663]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learner.fit(3e-4, 1, wds=1e-6, cycle_len=20, cycle_save_name='imdb_adam3_c1_cl20_full')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"learner.load_cycle('imdb_adam3_c1_cl20_full', 0) # load the weights saved at the end of cycle 0 (the only cycle in the fit above)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the sentiment analysis section, we'll only need half of the language model (the encoder), so we save that part."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# learner.save_encoder('imdb_adam3_enc_full')\n",
"learner.load_encoder('imdb_adam3_enc_full')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Language models are generally evaluated with perplexity, which is simply `exp()` of the cross-entropy loss we used; lower is better."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"math.exp(4.115031)  # validation loss I got from a run with a separate validation set"
]
},
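{
"cell_type": "markdown",
"metadata": {},
"source": [
"Perplexity is the exponential of the mean per-token cross-entropy (in nats). A model that spreads its probability evenly over k candidate words therefore has perplexity k. Below is a small sketch with made-up probabilities; `perplexity` is a hypothetical helper. It also converts the final validation loss of the last fit above into a perplexity."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def perplexity(target_probs):\n",
"    # target_probs: probability the model assigned to each actual next token\n",
"    mean_nll = -sum(math.log(p) for p in target_probs) / len(target_probs)  # mean cross-entropy\n",
"    return math.exp(mean_nll)\n",
"\n",
"print(round(perplexity([0.25, 0.25, 0.25, 0.25]), 6))  # 4.0: as uncertain as a fair 4-way guess\n",
"print(round(math.exp(3.884566), 1))  # 48.6, from the last fit's final validation loss"
]
},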
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "Can't pickle local object 'Vocab.__init__.<locals>.<lambda>'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-28-9483b10d351a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrn_ds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvocab\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'{PATH}/models/vocab_full.pkl'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m: Can't pickle local object 'Vocab.__init__.<locals>.<lambda>'"
]
}
],
"source": [
"with open(f'{PATH}/models/vocab_full.pkl', 'wb') as f:\n",
"    pickle.dump(trn_ds.vocab, f)"
]
},
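{
"cell_type": "markdown",
"metadata": {},
"source": [
"The error above was raised by the `defaultdict(lambda x: self.unk_idx)` stored on the `Vocab` instance. A lambda created inside `__init__` is a local object that `pickle` cannot serialize. In addition, because `default_factory` is called with no arguments, `lambda x:` would also fail on a missing key. Below is a minimal reproduction with two picklable alternatives: a plain `dict` queried with `.get()`, or `defaultdict(int)` when the unknown index is 0. The classes are hypothetical stand-ins. Depending on the Python version, the failure surfaces as `AttributeError` or `pickle.PicklingError`, so the sketch catches both."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"from collections import defaultdict\n",
"\n",
"class LambdaVocab:\n",
"    def __init__(self):\n",
"        # default_factory is a lambda local to __init__: pickle cannot serialize it\n",
"        self.stoi = defaultdict(lambda: 0)\n",
"\n",
"class PlainVocab:\n",
"    def __init__(self):\n",
"        self.stoi = {'<unk>': 0}      # plain dict; look tokens up with .get(tok, 0)\n",
"\n",
"class IntDefaultVocab:\n",
"    def __init__(self):\n",
"        self.stoi = defaultdict(int)  # int() == 0 == <unk> index, and int pickles fine\n",
"\n",
"try:\n",
"    pickle.dumps(LambdaVocab())\n",
"except (AttributeError, pickle.PicklingError) as e:\n",
"    print('cannot pickle:', e)\n",
"\n",
"for cls in (PlainVocab, IntDefaultVocab):\n",
"    restored = pickle.loads(pickle.dumps(cls()))\n",
"    print(cls.__name__, restored.stoi.get('movie', 0))  # unknown token -> 0"
]
},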
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can play around with our language model a bit to check that it seems to be working OK. First, let's create a short bit of text to 'prime' a set of predictions. We'll use our `Vocab` to numericalize it so we can feed it to our language model."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"vocab = trn_ds.vocab  # or: vocab = pickle.load(open(f'{PATH}/models/vocab_full.pkl', 'rb'))"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'. the monster was crazy and i cried'"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# create a short bit of text to \"prime\" the predictions, then numericalize it with our Vocab\n",
"# so we can feed it into our language model\n",
"m = learner.model\n",
"# ss = \"\"\". I laughed so hard when\"\"\"\n",
"ss = \"\"\". The monster was crazy and I cried \"\"\"\n",
"s = [f_tok.proc_text(ss)]\n",
"t = np.array([vocab.stoi(tok) for tok in s[0]])  # unknown tokens map to <unk> (index 0)\n",
"' '.join(s[0])"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
" 6 4 769 24 952 7 16 3600\n",
"[torch.cuda.LongTensor of size 1x8 (GPU 0)]"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t=T(t)\n",
"t = V(t.unsqueeze(0).cuda())\n",
"t"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We haven't yet added methods to make it easy to test a language model, so we'll need to manually go through the steps."
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"m[0].bs = 1 # set batch size = 1\n",
"m.eval() # turn-off dropout\n",
"m.reset() # reset hidden state\n",
"res, *_ = m(t) # get predictions from model\n",
"m[0].bs = bsz # put batch size back to what it was"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res[-1].size() # the prediction based on the full sentence; the last prediction\n",
"len(res) # the number of words in \"t\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see what the top 10 predictions were for the next word after our short text:"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[',', 'at', 'out', '.', 'and', 'by', 'in', 'when', 'for', 'a']"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# top 10 predictions for next word\n",
"nexts = torch.topk(res[-1], 10)[1] # return the 10 indexes of the top 10 predictions\n",
"[vocab.tokens[o] for o in to_np(nexts)] # [TEXT.vocab.itos[o] for o in to_np(nexts)]"
]
},
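{
"cell_type": "markdown",
"metadata": {},
"source": [
"`torch.topk` returns the k largest values and their indexes. Because softmax preserves order, the top k of the raw scores are the same words as the top k of the probabilities. Below is a numpy equivalent on a made-up five-word vocabulary; the scores and words are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"itos = ['<unk>', 'the', 'a', 'movie', 'film']  # hypothetical vocabulary\n",
"logits = np.array([0.1, 2.0, -1.0, 3.5, 0.7])   # hypothetical scores for the next word\n",
"\n",
"k = 3\n",
"top = np.argsort(logits)[::-1][:k]  # indexes of the k largest scores, best first\n",
"print([itos[i] for i in top])       # ['movie', 'the', 'film']\n",
"\n",
"# softmax turns scores into probabilities without changing their order\n",
"probs = np.exp(logits - logits.max())\n",
"probs /= probs.sum()\n",
"print([itos[i] for i in np.argsort(probs)[::-1][:k]])  # same three words"
]
},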
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and let's see if our model can generate a bit more text all by itself!"
]
},
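{
"cell_type": "markdown",
"metadata": {},
"source": [
"The generation cell below uses greedy decoding. At each step it picks the most likely next token, falling back to the runner-up when the top guess is `<unk>`. It then prints that token and feeds it back into the model, whose hidden state carries the context. Below is a self-contained sketch of the same selection rule. A hypothetical fixed score table stands in for the model, so unlike the LSTM it only looks at the previous token. Greedy decoding is known to fall into loops, which may explain the repeated phrases in the output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical next-token score table: row = current token id, column = next token id\n",
"itos = ['<unk>', 'the', 'film', 'is', 'good', '.']\n",
"scores = np.array([\n",
"    [0, 0, 0, 0, 0, 0],  # after <unk>\n",
"    [5, 0, 3, 0, 1, 0],  # after 'the': <unk> scores highest, 'film' second\n",
"    [0, 0, 0, 4, 0, 1],  # after 'film'\n",
"    [0, 0, 0, 0, 4, 1],  # after 'is'\n",
"    [0, 0, 0, 0, 0, 4],  # after 'good'\n",
"    [0, 4, 0, 0, 0, 0],  # after '.'\n",
"], dtype=float)\n",
"\n",
"def greedy_generate(start_id, n_steps, unk_id=0):\n",
"    out, cur = [], start_id\n",
"    for _ in range(n_steps):\n",
"        best2 = np.argsort(scores[cur])[::-1][:2]  # two highest-scoring next tokens\n",
"        # same rule as the cell below: skip <unk> if it is the top guess\n",
"        cur = int(best2[1] if best2[0] == unk_id else best2[0])\n",
"        out.append(itos[cur])\n",
"    return out\n",
"\n",
"print(' '.join(greedy_generate(start_id=1, n_steps=5)))  # film is good . the"
]
},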
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
". The monster was crazy and I cried \n",
"\n",
", and the film is a bit of a disappointment . the film is a bit of a let down , but it 's not a bad film . it 's a very good film , but it 's not a bad film . <eos> i have seen this movie ...\n"
]
}
],
"source": [
"# try to generate more text\n",
"print(ss, \"\\n\")\n",
"\n",
"for i in range(50):\n",
"    n = res[-1].topk(2)[1]                # indexes of the two most likely next tokens\n",
"    n = n[1] if n.data[0] == 0 else n[0]  # skip <unk> (index 0) if it is the top guess\n",
"    print(vocab.itos(n.data[0]), end=' ')\n",
"    res, *_ = m(n[0].unsqueeze(0))        # feed the chosen token back into the model\n",
" \n",
"print('...')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}