Skip to content

Instantly share code, notes, and snippets.

@bharadwaj6
Created May 25, 2018 21:27
Show Gist options
  • Save bharadwaj6/8457ff541731f913f2d26eedfef4d803 to your computer and use it in GitHub Desktop.
Save bharadwaj6/8457ff541731f913f2d26eedfef4d803 to your computer and use it in GitHub Desktop.
Telugu_Language_Model
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import pathlib\n",
"\n",
"from fastai.text import *\n",
"\n",
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Preparation"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"BOS = 'xbos' # beginning-of-sentence tag\n",
"FLD = 'xfld' # data field tag\n",
"\n",
"PATH = pathlib.Path(\"data/teluguwiki/data\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"LM_PATH=Path('data/telugu_lm/')\n",
"LM_PATH.mkdir(exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"471\n"
]
},
{
"data": {
"text/plain": [
"['data/teluguwiki/data/AC/wiki_97',\n",
" 'data/teluguwiki/data/AC/wiki_00',\n",
" 'data/teluguwiki/data/AC/wiki_69',\n",
" 'data/teluguwiki/data/AC/wiki_14',\n",
" 'data/teluguwiki/data/AC/wiki_20']"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"LANG_FILENAMES = [str(f) for f in PATH.rglob(\"*/*\")]\n",
"print(len(LANG_FILENAMES))\n",
"LANG_FILENAMES[0:5]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Each file under PATH holds one JSON record per line; parse them all into a\n",
"# list and build the DataFrame once. `with` closes each file after reading.\n",
"records = []\n",
"for fname in LANG_FILENAMES:\n",
"    with open(fname, encoding='utf-8') as f:\n",
"        records.extend(json.loads(line) for line in f)\n",
"\n",
"LANG_TEXT = pd.DataFrame(records)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"LANG_TEXT.to_csv(f\"{LM_PATH}/Wiki_Telugu_Corpus.csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"LANG_TEXT = pd.read_csv(f\"{LM_PATH}/Wiki_Telugu_Corpus.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"(LANG_TEXT.assign(labels = 0)\n",
" .pipe(lambda x: x[['labels', 'text']])\n",
" .to_csv(f\"{LM_PATH}/Wiki_Telugu_Corpus2.csv\", header=None, index=False))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Some statistics of Telugu Wikipedia"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Getting rid of the title name in the text field\n",
"def split_title_from_text(text):\n",
"    \"\"\"Return `text` without its first paragraph (the article title).\n",
"\n",
"    Paragraphs are separated by blank lines. The remaining paragraphs are\n",
"    re-joined with the same blank-line separator so that the last word of one\n",
"    paragraph is not glued onto the first word of the next. Text consisting of\n",
"    a single paragraph is returned unchanged.\n",
"    \"\"\"\n",
"    paragraphs = text.split(\"\\n\\n\")\n",
"    if len(paragraphs) >= 2:\n",
"        return \"\\n\\n\".join(paragraphs[1:])\n",
"    return text\n",
"\n",
"LANG_TEXT['text'] = LANG_TEXT['text'].apply(split_title_from_text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Number of documents"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(69001, 4)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"LANG_TEXT.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Number of words in all the documents"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"22174830"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# str.split() with no pattern splits on any run of whitespace (spaces and\n",
"# newlines) and drops empty strings, so words on either side of a line break\n",
"# are counted separately and repeated spaces do not add phantom words.\n",
"LANG_TEXT['text'].str.split().str.len().sum()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Number of unique tokens across documents"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2023529"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Split each document separately: joining documents with '' would glue the\n",
"# last word of one article onto the first word of the next.\n",
"len({word for doc in LANG_TEXT['text'] for word in doc.split()})"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def get_texts(df, n_lbls=1):\n",
"    \"\"\"Tokenize the text columns of `df`; return (token lists, labels).\n",
"\n",
"    Columns 0..n_lbls-1 hold integer labels and every later column is a text\n",
"    field. Each document starts with the BOS tag and each field is prefixed\n",
"    with the FLD tag plus its 1-based field number.\n",
"    \"\"\"\n",
"    labels = df.iloc[:, range(n_lbls)].values.astype(np.int64)\n",
"    texts = f'\\n{BOS} {FLD} 1 ' + df[n_lbls].astype(str)\n",
"    for i in range(n_lbls + 1, len(df.columns)):\n",
"        # column i is field number i - n_lbls + 1 (the first field is 1)\n",
"        texts += f' {FLD} {i - n_lbls + 1} ' + df[i].astype(str)\n",
"\n",
"    # partition_by_cores splits the texts into one sublist per core, which\n",
"    # proc_all_mp tokenizes in parallel. Lower and upper case is handled\n",
"    # inside the tokenizer.\n",
"    tok = Tokenizer().proc_all_mp(partition_by_cores(texts))\n",
"    return tok, list(labels)\n",
"\n",
"def get_all(df, n_lbls):\n",
"    \"\"\"Run get_texts on every chunk yielded by `df` and concatenate the results.\"\"\"\n",
"    tok, labels = [], []\n",
"    for i, r in enumerate(df):\n",
"        print(i)  # progress: index of the chunk being tokenized\n",
"        tok_, labels_ = get_texts(r, n_lbls)\n",
"        tok += tok_\n",
"        labels += labels_\n",
"    return tok, labels"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"LANG_TEXT = pd.read_csv(f\"{LM_PATH}/Wiki_Telugu_Corpus2.csv\", header=None)#, chunksize=5000)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# random_state makes the split reproducible: np.random.seed is only called in\n",
"# the next cell, after the split has already been drawn.\n",
"trn_texts, val_texts = sklearn.model_selection.train_test_split(\n",
"    LANG_TEXT, test_size=0.1, random_state=42)  # split the data into train and validation sets"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"trn_idx = np.random.permutation(len(trn_texts)) # generate a random ordering\n",
"val_idx = np.random.permutation(len(val_texts))\n",
"\n",
"df_trn = trn_texts.iloc[trn_idx,:] # sort things randomly\n",
"df_val = val_texts.iloc[val_idx,:] # sort things randomly\n",
"\n",
"df_trn.columns = ['labels', 'text']\n",
"df_val.columns = ['labels', 'text']\n",
"\n",
"df_trn.to_csv(LM_PATH/'train.csv', header=False, index=False)\n",
"df_val.to_csv(LM_PATH/'test.csv', header=False, index=False) # saving the data in our new format to disk"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"chunksize = 10000\n",
"df_trn = pd.read_csv(LM_PATH/'train.csv', header=None, chunksize=chunksize)\n",
"df_val = pd.read_csv(LM_PATH/'test.csv', header=None, chunksize=chunksize)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"1\n",
"2\n",
"3\n",
"4\n",
"5\n",
"0\n"
]
}
],
"source": [
"tok_trn, trn_labels = get_all(df_trn, 1)\n",
"tok_val, val_labels = get_all(df_val, 1)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"# create a tmp directory to store the upcoming numpy arrays\n",
"(LM_PATH/'tmp').mkdir(exist_ok=True)\n",
"\n",
"# save the train and validation tokens in the tmp directories\n",
"np.save(LM_PATH/'tmp'/'tok_trn.npy', tok_trn)\n",
"np.save(LM_PATH/'tmp'/'tok_val.npy', tok_val)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"# The per-document token lists have different lengths, so they were saved as\n",
"# object arrays; np.load only unpickles those when allow_pickle=True\n",
"# (the default became False in NumPy 1.16.3).\n",
"tok_trn = np.load(LM_PATH/'tmp'/'tok_trn.npy', allow_pickle=True)\n",
"tok_val = np.load(LM_PATH/'tmp'/'tok_val.npy', allow_pickle=True)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(',', 1182001),\n",
" ('\\n', 563593),\n",
" ('\\n\\n', 365503),\n",
" ('నుండి', 315059),\n",
" ('ఉన్నాయి.', 257793),\n",
" ('దూరంలో', 213086),\n",
" ('గ్రామం', 213014),\n",
" ('ఉంది.', 212279),\n",
" ('10', 185585),\n",
" ('గ్రామంలో', 160000),\n",
" ('\"', 153647),\n",
" ('ఈ', 143467),\n",
" ('మరియు', 141523),\n",
" ('కి.మీ.', 130526),\n",
" ('(', 130198),\n",
" (')', 127110),\n",
" ('5', 123026),\n",
" ('కేంద్రం', 115188),\n",
" ('సమీప', 112802),\n",
" ('.', 101798),\n",
" ('ఒక', 88165),\n",
" ('సౌకర్యం', 79118),\n",
" ('ద్వారా', 75279),\n",
" ('కూడా', 74616),\n",
" ('పైబడిన', 72048)]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Identify the most common tokens and numericalizing the text\n",
"freq = Counter(p for o in tok_trn for p in o) \n",
"freq.most_common(25)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"# Truncating our vocab to ignore the rare words\n",
"max_vocab = 60000\n",
"min_freq = 5\n",
"\n",
"itos = [o for o,c in freq.most_common(max_vocab) if c>min_freq] # getting rid of the rare words\n",
"itos.insert(0, '_pad_') # \n",
"itos.insert(0, '_unk_') # itos is the list of all the strings in the vocab"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"60002"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# creating a index-key dictionary for our vocabulary\n",
"stoi = collections.defaultdict(lambda:0, {v:k for k,v in enumerate(itos)})\n",
"len(itos)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"# creating a index representation for our train and validation dataset\n",
"trn_lm = np.array([[stoi[o] for o in p] for p in tok_trn])\n",
"val_lm = np.array([[stoi[o] for o in p] for p in tok_val])"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"# saving our indexed representation of our dataset to disk\n",
"# we also save the index-word mapping to retrieve the complete text representation from these numpy arrays\n",
"np.save(LM_PATH/'tmp'/'trn_ids.npy', trn_lm)\n",
"np.save(LM_PATH/'tmp'/'val_ids.npy', val_lm)\n",
"with open(LM_PATH/'tmp'/'itos.pkl', 'wb') as f:  # closes/flushes the file deterministically\n",
"    pickle.dump(itos, f)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"# Loading the indexed representation of our dataset from disk\n",
"# we also load the index-word mapping to help us convert the indexes to word datasets, if need be.\n",
"# The id sequences have different lengths and were stored as object arrays,\n",
"# so np.load needs allow_pickle=True (default False since NumPy 1.16.3).\n",
"trn_lm = np.load(LM_PATH/'tmp'/'trn_ids.npy', allow_pickle=True)\n",
"val_lm = np.load(LM_PATH/'tmp'/'val_ids.npy', allow_pickle=True)\n",
"with open(LM_PATH/'tmp'/'itos.pkl', 'rb') as f:\n",
"    itos = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(60002, 52100)"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# checking vocabulary size\n",
"vs=len(itos)\n",
"vs,len(trn_lm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Setup"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# ! wget -nH -r -np http://files.fast.ai/models/wt103/\n",
"# mv models/ {LM_PATH}"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"em_sz,nh,nl = 400,1150,3\n",
"\n",
"PRE_PATH = LM_PATH/'models'/'wt103'\n",
"PRE_LM_PATH = PRE_PATH/'fwd_wt103.h5'\n",
"\n",
"itos2 = pickle.load((PRE_PATH/'itos_wt103.pkl').open('rb')) # mapping the itos from wiki to our own mapping\n",
"stoi2 = collections.defaultdict(lambda:-1, {v:k for k,v in enumerate(itos2)})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we train from scratch so these are unused\n",
"# wgts = torch.load(PRE_LM_PATH, map_location=lambda storage, loc: storage)\n",
"\n",
"# enc_wgts = to_np(wgts['0.encoder.weight'])\n",
"# row_m = enc_wgts.mean(0)\n",
"\n",
"# wgts['0.encoder.weight'] = T(new_w)\n",
"# wgts['0.encoder_with_dropout.embed.weight'] = T(np.copy(new_w))\n",
"# wgts['1.decoder.weight'] = T(np.copy(new_w))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Language Model"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"wd=1e-7\n",
"bptt=70\n",
"bs=52\n",
"opt_fn = partial(optim.Adam, betas=(0.8, 0.99))"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"em_sz,nh,nl = 400,1150,3"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"trn_dl = LanguageModelLoader(np.concatenate(trn_lm), bs, bptt)\n",
"val_dl = LanguageModelLoader(np.concatenate(val_lm), bs, bptt)\n",
"md = LanguageModelData(PATH, 1, vs, trn_dl, val_dl, bs=bs, bptt=bptt)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"drops = np.array([0.25, 0.1, 0.2, 0.02, 0.15])*0.7 # if you're overfitting, increase this. Underfitting? decrease this."
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"learner= md.get_model(opt_fn, em_sz, nh, nl, \n",
" dropouti=drops[0], dropout=drops[1], wdrop=drops[2], dropoute=drops[3], dropouth=drops[4])\n",
"\n",
"learner.metrics = [accuracy]\n",
"learner.unfreeze()"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6f8bf3e319054d1b95146542e461a3dd",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" 70%|███████ | 3969/5657 [13:44<05:50, 4.81it/s, loss=14.3]"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fa10b07e4e0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#find suitable learning rates\n",
"learner.lr_find(1e-07, 1e2)\n",
"learner.sched.plot()"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"lr = 1e-3\n",
"lrs = lr"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7a1e2d12bd954b3eab51080184821a40",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0%| | 5/5657 [00:01<29:30, 3.19it/s, loss=11] \n",
" 0%| | 6/5657 [00:01<28:00, 3.36it/s, loss=11]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Exception in thread Thread-11:\n",
"Traceback (most recent call last):\n",
" File \"/opt/conda/envs/fastai/lib/python3.6/threading.py\", line 916, in _bootstrap_inner\n",
" self.run()\n",
" File \"/opt/conda/envs/fastai/lib/python3.6/site-packages/tqdm/_tqdm.py\", line 144, in run\n",
" for instance in self.tqdm_cls._instances:\n",
" File \"/opt/conda/envs/fastai/lib/python3.6/_weakrefset.py\", line 60, in __iter__\n",
" for itemref in self.data:\n",
"RuntimeError: Set changed size during iteration\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 3.387137 3.396025 0.547392 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[array([3.39602]), 0.547391530683386]"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learner.fit(lr, 1, wds=wd, use_clr=(32,2), cycle_len=1) # last layer is the embedding weights"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"learner.save('lm_telugu_fromscratch')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learner.load('lm_telugu_fromscratch')"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb366b95415c43ceac50074c42acde17",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 3.139451 3.198013 0.558166 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[array([3.19801]), 0.5581661824732331]"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learner.fit(lrs, 1, wds=wd, use_clr=(20,10), cycle_len=1)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"24.483832456834897"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# perplexity approximation\n",
"math.exp(3.198013)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9d05b9e48fff418785f767f94282e1dc",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=15), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 3.168684 3.187562 0.556568 \n",
" 34%|███▍ | 1937/5657 [06:46<13:00, 4.76it/s, loss=3.36]"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-56-4dbe5f88ace9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mlearner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muse_clr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcycle_len\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m15\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/notebooks/courses/dl2/fastai/learner.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, lrs, n_cycle, wds, **kwargs)\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 251\u001b[0m \u001b[0mlayer_opt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_layer_opt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 252\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_gen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer_opt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_cycle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 253\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwarm_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/notebooks/courses/dl2/fastai/learner.py\u001b[0m in \u001b[0;36mfit_gen\u001b[0;34m(self, model, data, layer_opt, n_cycle, cycle_len, cycle_mult, cycle_save_name, best_save_name, use_clr, use_clr_beta, metrics, callbacks, use_wd_sched, norm_wds, wds_sched_mult, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0mn_epoch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum_geom\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcycle_len\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcycle_len\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcycle_mult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_cycle\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m return fit(model, data, n_epoch, layer_opt.opt, self.crit,\n\u001b[0;32m--> 199\u001b[0;31m metrics=metrics, callbacks=callbacks, reg_fn=self.reg_fn, clip=self.clip, **kwargs)\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_layer_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_layer_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/notebooks/courses/dl2/fastai/model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(model, data, epochs, opt, crit, metrics, callbacks, stepper, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mbatch_num\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstepper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 126\u001b[0m \u001b[0mavg_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg_loss\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mavg_mom\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mavg_mom\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[0mdebias_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg_loss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mavg_mom\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mbatch_num\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/notebooks/courses/dl2/fastai/model.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, xs, y, epoch)\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mraw_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcrit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreg_fn\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreg_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxtra\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 54\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# Gradient clipping\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip_grad_norm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainable_params_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/autograd/variable.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, retain_variables)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \"\"\"\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_variables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(variables, grad_variables, retain_graph, create_graph, retain_variables)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m Variable._execution_engine.run_backward(\n\u001b[0;32m---> 99\u001b[0;31m variables, grad_variables, retain_graph)\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"learner.fit(lrs, 1, wds=wd, use_clr=(20,10), cycle_len=15)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"learner.save('lm_telugu_fromscratch_1_partial')"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"learner.save_encoder('adam1_enc')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"m = learner.model\n",
"# pickle.dump(m,open(f'wiki_lang.pkl','wb'))"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"SequentialRNN(\n",
" (0): RNN_Encoder(\n",
" (encoder): Embedding(60002, 400, padding_idx=1)\n",
" (encoder_with_dropout): EmbeddingDropout(\n",
" (embed): Embedding(60002, 400, padding_idx=1)\n",
" )\n",
" (rnns): ModuleList(\n",
" (0): WeightDrop(\n",
" (module): LSTM(400, 1150, dropout=0.105)\n",
" )\n",
" (1): WeightDrop(\n",
" (module): LSTM(1150, 1150, dropout=0.105)\n",
" )\n",
" (2): WeightDrop(\n",
" (module): LSTM(1150, 400, dropout=0.105)\n",
" )\n",
" )\n",
" (dropouti): LockedDropout(\n",
" )\n",
" (dropouths): ModuleList(\n",
" (0): LockedDropout(\n",
" )\n",
" (1): LockedDropout(\n",
" )\n",
" (2): LockedDropout(\n",
" )\n",
" )\n",
" )\n",
" (1): LinearDecoder(\n",
" (decoder): Linear(in_features=400, out_features=60002)\n",
" (dropout): LockedDropout(\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# TEXT = pickle.load(open(f'{PATH}models/TEXT.pkl','rb'))\n",
"# m = pickle.load(open(f'{PATH}models/wiki_lang.pkl','rb'))\n",
"m[0].bs=1\n",
"m.eval()"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"def _numericalize(ss):\n",
"    \"\"\"Tokenize `ss` and return a (seq_len, 1) tensor of ids from `stoi`.\n",
"\n",
"    Tokens missing from the vocabulary map to 0 (`_unk_`) via the defaultdict.\n",
"    NOTE(review): training used fastai's Tokenizer, while tokenize() may split\n",
"    punctuation differently; consider Tokenizer().proc_text for exact parity.\n",
"    \"\"\"\n",
"    ids = np.array([stoi[w] for w in tokenize(ss)], dtype=np.int64)\n",
"    return V(T(ids[:, None]))\n",
"\n",
"def gen_text(ss, topk):\n",
"    \"\"\"Return the `topk` most likely next tokens after the prompt `ss`.\n",
"\n",
"    Encodes and decodes with this notebook's `stoi`/`itos`, i.e. the vocabulary\n",
"    the model was trained on; a different vocabulary (such as a separately\n",
"    pickled TEXT field) would map the model's output indices to wrong words.\n",
"    \"\"\"\n",
"    t = _numericalize(ss)\n",
"    m.reset()\n",
"    pred, *_ = m(t)\n",
"    pred_i = torch.topk(pred[-1], topk)[1]\n",
"    return [itos[o] for o in to_np(pred_i)]\n",
"\n",
"def gen_sentences(ss, nb_words):\n",
"    \"\"\"Greedily generate `nb_words` tokens that follow the prompt `ss`.\n",
"\n",
"    At each step the top prediction is used, falling back to the second-best\n",
"    one when the top index is 0 or 1 (`_unk_` / `_pad_` in `itos`).\n",
"    \"\"\"\n",
"    result = []\n",
"    t = _numericalize(ss)\n",
"    m.reset()\n",
"    pred, *_ = m(t)\n",
"    for _step in range(nb_words):\n",
"        pred_i = pred[-1].topk(2)[1]\n",
"        pred_i = pred_i[1] if pred_i.data[0] < 2 else pred_i[0]\n",
"        result.append(itos[pred_i.data[0]])\n",
"        pred, *_ = m(pred_i[0].unsqueeze(0))\n",
"    return result"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"TEXT = pickle.load(open(f'data/teluguwiki/models/TEXT.pkl','rb'))"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['<unk>',\n",
" '1',\n",
" 'బాగా',\n",
" 'మరియు',\n",
" 'కింది',\n",
" 'ఈ',\n",
" '31',\n",
" 'వస్తుంది',\n",
" 'విడుదల',\n",
" 'లేవు']"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_sentence = \"ఆయన కుటుంబం ఆయనకు వారి స్తోమత ప్రకారం వైద్యాన్ని అందించింది.\"\n",
"gen_text(test_sentence, 10)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1కేంద్రంtotototoగుర్తిస్తూఈకేంద్రంtoగుర్తిస్తూఈకేంద్రంకానిఉత్పత్తి:కేంద్రంఉత్పత్తి:?నుండి,ప్రాథమికలోజిల్లాచేసినవచ్చాయిఆగష్టునీపాఠశాలనుండిబాగాపాఠశాలమూడుకేంద్రం235అన్నికింది%ఇంటింటికీఅక్కడేకిఇంకాగంటలఆఫీసుఅలోపతిఅవసరాల./కంకర'"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"''.join(gen_sentences(test_sentence, 50))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment