Skip to content

Instantly share code, notes, and snippets.

@morganmcg1
Created December 21, 2020 22:37
Show Gist options
  • Save morganmcg1/b59087db0edea2ad5e6774d20de3d6d3 to your computer and use it in GitHub Desktop.
Save morganmcg1/b59087db0edea2ad5e6774d20de3d6d3 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mon Dec 21 22:32:59 2020 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 450.66 Driver Version: 450.66 CUDA Version: 11.0 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 GeForce RTX 208... Off | 00000000:65:00.0 Off | N/A |\n",
"| 31% 38C P8 12W / 250W | 28MiB / 11019MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| 0 N/A N/A 1263 G /usr/lib/xorg/Xorg 9MiB |\n",
"| 0 N/A N/A 1318 G /usr/bin/gnome-shell 14MiB |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmorgan\u001b[0m (use `wandb login --relogin` to force relogin)\n"
]
}
],
"source": [
"import sys\n",
"\n",
"# On Colab the dependencies are not preinstalled, so install them into the running kernel.\n",
"if 'google.colab' in sys.modules:\n",
"    # `%pip` (rather than `!pip`) makes sure the install targets this kernel's environment.\n",
"    %pip install -qq einops axial-positional-embedding fastai datasets\n",
"    # GitHub no longer serves the unauthenticated git:// protocol, so install over https.\n",
"    %pip install -qq git+https://github.com/arampacha/reformer_fastai.git\n",
"    %pip install -qqq wandb\n",
"\n",
"# Authenticate with Weights & Biases (a no-op prompt when already logged in).\n",
"!wandb login"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from fastai.basics import *\n",
"from fastai.text.all import *\n",
"from reformer_fastai.tokenizers import SubwordTextEncoder"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load WMT Data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>de</th>\n",
" <th>en</th>\n",
" <th>is_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Buchen Sie ein Ferienhaus oder eine Ferienwohnung in Mauritius Nordküste, Mauritius Ostküste, &amp; Mauritius Westküste &amp; Inselmitte direkt beim Vermieter.</td>\n",
" <td>Why stay in a hotel when you can have a fully serviced Mauritius villa rental on Mauritius. Just contact us by the booking form or directly by telephone and we can start the process of finding the perfect holiday villa.... Find any special requirements that you have for your holiday, the Mauritius villa rental you want, trips and excursions.</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>In diesem Zusammenhang und vor dem Hintergrund der Schwachstellen, die bei dem Versuch, die griechische Initiative voranzubringen, deutlich geworden sind, ist der Hinweis wichtig, dass die Kompetenzen von Europol im Januar 2002 ausgeweitet worden sind und nun auch den illegalen Handel mit menschlichen Organen und Geweben einschließen.</td>\n",
" <td>In this context, and in view of the weaknesses detected when trying to move forward the Greek initiative, it is important to mention that Europol's competences were extended in January 2002 to include illicit trade in human organs and tissues.</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" de \\\n",
"0 Buchen Sie ein Ferienhaus oder eine Ferienwohnung in Mauritius Nordküste, Mauritius Ostküste, & Mauritius Westküste & Inselmitte direkt beim Vermieter. \n",
"1 In diesem Zusammenhang und vor dem Hintergrund der Schwachstellen, die bei dem Versuch, die griechische Initiative voranzubringen, deutlich geworden sind, ist der Hinweis wichtig, dass die Kompetenzen von Europol im Januar 2002 ausgeweitet worden sind und nun auch den illegalen Handel mit menschlichen Organen und Geweben einschließen. \n",
"\n",
" en \\\n",
"0 Why stay in a hotel when you can have a fully serviced Mauritius villa rental on Mauritius. Just contact us by the booking form or directly by telephone and we can start the process of finding the perfect holiday villa.... Find any special requirements that you have for your holiday, the Mauritius villa rental you want, trips and excursions. \n",
"1 In this context, and in view of the weaknesses detected when trying to move forward the Greek initiative, it is important to mention that Europol's competences were extended in January 2002 to include illicit trade in human organs and tissues. \n",
"\n",
" is_test \n",
"0 0 \n",
"1 0 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tiny_df=pd.read_feather('WMT14_TINY')\n",
"tiny_df.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"## Load full WMT14 train dataset from HuggingFace datasets\n",
"#!pip install -qq datasets\n",
"#from datasets import load_dataset\n",
"#train_dataset = load_dataset('wmt_t2t')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"df = tiny_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataloaders"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Row indices of the train / test partitions of TINY_WMT14 (selected via the `is_test` column)\n",
"train_split = tiny_df.loc[tiny_df.is_test==0].index.values\n",
"test_split = tiny_df.loc[tiny_df.is_test==1].index.values\n",
"\n",
"# Get Vocab for tokenizer: download it only when it is missing, so the cell works in a\n",
"# fresh environment and does not re-download on every run.\n",
"import os\n",
"if not os.path.exists('./vocab.translate_ende_wmt32k.32768.subwords'):\n",
"    !wget -q https://raw.githubusercontent.com/tensorflow/tensor2tensor/master/tensor2tensor/test_data/vocab.translate_ende_wmt32k.32768.subwords\n",
"\n",
"# Set up Sub-Word tokenizer with vocab\n",
"tok = SubwordTextEncoder(filename='./vocab.translate_ende_wmt32k.32768.subwords', add_bos=True, seq_len=256)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LMTensorText([ 0, 4308, 105, 16, 49, 954, 11888, 33707, 5, 26494,\n",
" 16501, 5])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tok('hey is this working? <EOS>')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# def add_eos(text):\n",
"# return text + tok.EOS\n",
"\n",
"def add_eos_id(ids, keep_size=True):\n",
"    \"Append the EOS token id to `ids`. With `keep_size==True` the last id is dropped first, so a non-empty sequence keeps its length\"\n",
"    # One-element tensor holding the tokenizer's EOS id (reads the notebook-level `tok`).\n",
"    eos = LMTensorText(tok.EOS_ID).unsqueeze(0)\n",
"    # Either make room for EOS by dropping the final id, or keep every id and grow by one.\n",
"    kept = ids[:-1] if keep_size else ids\n",
"    return torch.cat([kept, eos])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 139 ms, sys: 6.88 ms, total: 146 ms\n",
"Wall time: 145 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# (train indices, test indices) pair handed to Datasets below\n",
"splits = train_split, test_split\n",
"\n",
"# Get text lengths to enable faster init with SortedDL\n",
"# These are character counts of the raw German strings, not token counts.\n",
"# NOTE: `df` is the same object as `tiny_df`, so this also adds the column to `tiny_df`.\n",
"df['de_lens'] = df['de'].str.len()\n",
"\n",
"# Per-field pipelines: read the column, tokenize with `tok`, then apply add_eos_id\n",
"# (called with its default keep_size=True).\n",
"en_tfms = [ColReader(\"en\"), tok, add_eos_id]\n",
"de_tfms = [ColReader(\"de\"), tok, add_eos_id]\n",
"\n",
"# Set up datasets: field 0 is English, field 1 is German; `splits` selects the rows of each subset\n",
"dsets = Datasets(df, [en_tfms, de_tfms], splits=splits)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.13 s, sys: 442 ms, total: 1.57 s\n",
"Wall time: 1.57 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# Get dataloader\n",
"# SortedDL is pre-configured with the German character lengths of the training rows (`res`).\n",
"srtd_dl = partial(SortedDL, shuffle=True, res=df['de_lens'].values[splits[0]])\n",
"# One kwargs dict per split: nothing extra for train, German lengths of the test rows as `val_res`.\n",
"# NOTE(review): presumably these precomputed lengths let SortedDL skip measuring every item - confirm against fastai docs.\n",
"dl_kwargs = [{},{'val_res': df['de_lens'].values[splits[1]]}]\n",
"\n",
"# Define padding\n",
"# Both fields (indices 0 and 1, i.e. English and German) are padded with the tokenizer's PAD id.\n",
"pad_seq2seq = partial(pad_input, pad_idx=tok.PAD_ID, pad_fields=[0,1])\n",
"\n",
"# Set up dataloaders\n",
"# Batch size 16; padding is applied to each batch via the `before_batch` hook.\n",
"dls = dsets.dataloaders(bs=16, before_batch=pad_seq2seq, dl_type = srtd_dl, dl_kwargs = dl_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>text_</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>&lt;pad&gt;* @constructor * @param {GMarker} marker * @param {String} text * @param {Number} padding */ function Tooltip(marker, text, padding){ this.marker_ = marker; this.text_ = text; this.padding_ = padding; } Tooltip.prototype = new GOverlay(); Tooltip.prototype.initialize = function(map){ var div = document.createElement(\"div\"); div.appendChild(document.createTextNode(this.text_)); div.className = 'tooltip'; div.style.position = 'absolute'; div.style.visibility = 'hidden'; div.style.backgroundColor = '#FFFFFF'; div.style.fontWeight = 'bold'; div.style.width = '200px'; div.style.height = '22px'&lt;EOS&gt;</td>\n",
" <td>&lt;pad&gt;* @constructor * @param {GMarker} marker * @param {String} text * @param {Number} padding */ function Tooltip(marker, text, padding){ this.marker_ = marker; this.text_ = text; this.padding_ = padding; } Tooltip.prototype = new GOverlay(); Tooltip.prototype.initialize = function(map){ var div = document.createElement(\"div\"); div.appendChild(document.createTextNode(this.text_)); div.className = 'tooltip'; div.style.position = 'absolute'; div.style.visibility = 'hidden'; div.style.backgroundColor = '#FFFFFF'; div.style.fontWeight = 'bold'; div.style.width = '200px'; div.style.height = '22px'&lt;EOS&gt;</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>&lt;pad&gt;Thereis, however, one point to which allusion has already been made and which it is important to underline from the outset. We have an internal market concept, we have the Lisbon and Gothenburg strategies, we have European competition policy, we have the Financial Services Action Plan and the Risk Capital Action Plan, we have Article 2, which obliges Member States to pursue a course of economic policy coordination, we have the Stability and Growth Pact, we have the euro, we have EU enlargement and hence the expansion of the internal market into a home market, and we have the eco-social market economy as our model of economic governance, which means competitiveness within a free market, promoting social cohesion while being mindful of its responsibility towards nature and people&lt;EOS&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;</td>\n",
" <td>&lt;pad&gt;Esist aber am Beginn zu betonen - es ist auch bereits angeschnitten worden -, wir haben ein Binnenmarktkonzept, wir haben die Lissabon- und Göteborg-Strategie, wir haben die europäische Wettbewerbspolitik, wir haben den Aktionsplan für Finanzdienstleistungen, den Aktionsplan für Risikokapital, wir haben den Artikel 2, der die Mitgliedstaaten zur Koordination verpflichtet, wir haben den Stabilitäts- und Wachstumspakt, wir haben den Euro, wir haben die Erweiterung der Europäischen Union und damit die Erweiterung des Binnenmarktes zum Heimatmarkt, wir haben das Ordnungsmodell der ökosozialen Marktwirtschaft, was Wettbewerbsfähigkeit in einem freien Markt bedeutet, der den sozialen Zusammenhalt fördert und sich seiner Verantwortung für Natur und Mensch bewusst ist&lt;EOS&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt
;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;&lt;pad&gt;</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch(max_n=2)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((16, 256), (16, 256))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"o = dls.one_batch()\n",
"o[0].size(), o[1].size()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# !wget https://github.com/haws74516/en_ga_ds/raw/main/en_ga.zip\n",
"# "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# import zipfile\n",
"# path = Path()\n",
"# with zipfile.ZipFile('en_ga.zip', 'r') as f:\n",
"# f.extractall(path)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# path = Path('.')\n",
"\n",
"# df = pd.read_csv(path/'en_ga.csv', index_col=0)\n",
"# df.head()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# %%time\n",
"# def add_eos(text):\n",
"# return f'{BOS} ' + text + f' {EOS}'\n",
"\n",
"# dblock = DataBlock(blocks=(TextBlock.from_df('en', tok_text_col='en', rules=[add_eos]),\n",
"# TextBlock.from_df('ga', tok_text_col='ga', rules=[add_eos])),\n",
"# get_x=ColReader('en'),\n",
"# get_y=ColReader('ga'), \n",
"# splitter=RandomSplitter())\n",
"\n",
"# dsets = dblock.datasets(df)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# %%time\n",
"# pad_seq2seq = partial(pad_input, pad_fields=[0,1])\n",
"\n",
"# dl_kwargs = [{'res':df['en_len'].values[dsets.splits[0]]},\n",
"# {'val_res':df['en_len'].values[dsets.splits[1]]}]\n",
"\n",
"# dls = dsets.dataloaders(bs=8, dl_type=SortedDL, before_batch=pad_seq2seq, shuffle_train=True,\n",
"# num_workers=2, dl_kwargs=dl_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# dls.show_batch(max_n=4)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# o = dls.one_batch()\n",
"# o[0].size(), o[0][0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Learner"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# use shared vocab\n",
"enc_vocab_sz=dec_vocab_sz=tok.vocab_size\n",
"#enc_vocab_sz=dec_vocab_sz=30000\n",
"# model dim\n",
"d_model = 768"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# enc_vocab_sz = len(dls.vocab[0])\n",
"# # dec_vocab_sz = len(dls.vocab[1])\n",
"# dec_vocab_sz=enc_vocab_sz"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"from reformer_fastai.transformer import TransformerEncDec\n",
"from reformer_fastai.core import CombineInputOutputCallback, LossTargetShiftCallback, RemoveEOSCallback\n",
"from reformer_fastai.optimizers import adafactor"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"import pdb"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# dls.cpu(), dls.device"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(125104044, True)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Training callbacks imported from reformer_fastai.core; RemoveEOSCallback is given the tokenizer's EOS id.\n",
"# NOTE(review): exact behaviour of each callback is defined in reformer_fastai.core - confirm there.\n",
"cbs = [CombineInputOutputCallback(), LossTargetShiftCallback(), RemoveEOSCallback(eos_idx=tok.EOS_ID)]\n",
"# cbs = [CombineInputOutputCallback(), LossTargetShiftCallback()] #, RemoveEOSCallback(eos_idx=tok.EOS_ID)]\n",
"\n",
"# Encoder-decoder transformer over the shared vocab (enc_vocab_sz == dec_vocab_sz == tok.vocab_size).\n",
"# max_seq_len=256 mirrors the tokenizer's seq_len=256; keep the two in sync when changing either.\n",
"# All dropout rates are set to 0.0, PAD targets are passed as `ignore_index` to the loss,\n",
"# and the learner is switched to native fp16 training.\n",
"learn = Learner(dls, TransformerEncDec(enc_vocab_sz, dec_vocab_sz, d_model=d_model, heads=8, #n_enc_layers=2, n_dec_layers=1, \n",
" max_seq_len=256, pad_idx=tok.PAD_ID, tie_weights=True, shared_emb=True,\n",
" attn_dropout=0.0, ff_dropout=0.0, emb_dropout=0.0,\n",
" pos_enc='fixed'),\n",
" loss_func=CrossEntropyLossFlat(ignore_index=tok.PAD_ID), cbs=cbs, # opt_func=adafactor,\n",
" metrics=[accuracy, Perplexity(), CorpusBLEUMetric()]).to_native_fp16()\n",
"\n",
"# Display the result of total_params for the model as the cell output\n",
"total_params(learn.model)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.02089296132326126, lr_steep=0.17378008365631104)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEKCAYAAAAIO8L1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAApI0lEQVR4nO3deXxU9b3/8ddnspCFLCwJkLCFXQQBibigqNRa5bqgrda1tbVqe+1mf+1V2/5uvbet9ffrrb3V2mtttdpfXesVtYpVr1VBQSTsmwgECAmYBbKvJPn+/piJxjjBBDI5Zybv5+MxD2bOnDPncxjCO+d8v+f7NeccIiIiXQW8LkBERPxJASEiImEpIEREJCwFhIiIhKWAEBGRsBQQIiISVrzXBfSl4cOHu/Hjx3tdhohI1FizZk2Fcy4r3HsxFRDjx4+noKDA6zJERKKGme3t7j1dYhIRkbAUECIiEpYCQkREwlJAiIhIWAoIEREJSwEhIiJhKSBERKLY5pJqVuyqoL2976duUECIiESxh1fs4duPr8Os7z9bASEiEsU2FVczMzcDi0BCKCBERKJUQ0srO8pqOWF0ZkQ+XwEhIhKltuyvod3BCaMzIvL5CggRkSi1YV8VADMVECIi0tmmkmpGZSSRnZYUkc9XQIiIRKmNoQbqSFFAiIhEoerGw+yuqGfWmMyI7UMBISIShbaUVAPoDEJERD5uQ3EwICLVgwkUECIiUWlTSRVjh6aQmZIYsX0oIEREotCGfdURPXsABYSISNQ5WNdMSVWjAkJERD5u44cN1JkR3Y8CQkQkymwqrsYMZuSmR3Q/CggRkSizsbiaCcNTSUtKiOh+FBAiIlFmY3EVsyI0gmtnCggRkShSWtNEWW1zxAbo6yxiAWFmD5lZmZlt7rTsl2b2npltNLMlZpbZzbbnmdl2M9tpZrdFqkYRkWjTMYJrpOaA6CySZxAPA+d1WfYqMMM5dwLwPnB7143MLA64DzgfmA5caWbTI1iniEjU2FRSTVzAmD4qsg3UEMGAcM4tAw51WfaKc6419PIdYHSYTecBO51zhc65FuAJ4OJI1SkiEk02lVQzOXswyYlxEd+Xl20QXwVeCrM8F9jX6XVxaFlYZnajmRWYWUF5eXkflygi4h/OOTYVVzMjggP0deZJQJjZj4BW4NFwb4dZ5rr7LOfcA865fOdcflZWVl+VKCLiOweqmzhY3xLxO6g7xPfLXjoxsy8DFwCfcc6F+4+/GBjT6fVoYH9/1CYi4mebQndQx+QZhJmdB9wKXOSca+hmtdXAZDPLM7NE4Arg+f6qUUTErzb3YwM1RLab6+PASmCqmRWb2fXAb4E04FUzW29m94fWzTGzpQChRuxvAi8D24CnnHNbIlWniEi06GigTkqIfAM1RPASk3PuyjCLH+xm3f3Aok6vlwJLI1SaiEjU6WigPntadr/tU3dSi4hEgf5uoAYFhIhIVOjvBmpQQIiIRIX+bqAGBYSISFTYWNy/DdSggBAR8T3nHJtL+u8O6g4KCBERn/OigRoUECIivudFAzUoIEREfM+LBmpQQIiI+J4XDdSggBAR8TWvGqhBASEi4mteNVCDAkJExNe8aqAGBYSIiK9t2FdFvAcN1KCAEBHxtfX7qpg2Kq3fG6hBASEi4lvt7Y6NxdXMHpPpyf4VECIiPrWrvI665lZmjxniyf4VECIiPrVuXxWAziBEROTj1u+rIi0pngnDUz3ZvwJCRMSn1hdVMWt0JoGAebJ/BYSIiA81trSxvbTWs8tLoIAQEfGlzfuraWt3CggREfm49UVVAMxSQIiISGfr91WRm5lMVtogz2pQQIiI+ND6fVXMHpvpaQ0RCwgze8jMysxsc6dll5nZFjNrN7P8I2y7x8w2mdl6MyuIVI0iIn5UVttESVUjczy8vASRPYN4GDivy7LNwKXAsh5sf7ZzbrZzrtsgERGJRRv2BUdw9bKBGiA+Uh/snFtmZu
O7LNsGYOZNn14RkWiwfl8l8QHzZIjvzvzaBuGAV8xsjZndeKQVzexGMysws4Ly8vJ+Kk9EJHK8HMG1M78GxHzn3InA+cDNZraguxWdcw845/Kdc/lZWVn9V6GISAS0tzs27qtm1uhMr0vxZ0A45/aH/iwDlgDzvK1IRKR/7Cqvo7a51fP2B/BhQJhZqpmldTwHziXYuC0iEvO2HqgBYKYHc1B3Fcluro8DK4GpZlZsZteb2SVmVgycCrxoZi+H1s0xs6WhTUcAb5nZBuBd4EXn3N8jVaeIiJ8UltdjBuOHeTOCa2eR7MV0ZTdvLQmz7n5gUeh5ITArUnWJiPjZ7op6Rg9J9ryBGnx4iUlEZCDbXVFP3vDBXpcBKCBERHzDOcfuinrPJgjqSgEhIuIT5XXN1DW3kqeAEBGRzgrL6wEUECIi8nG7KxQQIiISxu6KehLjA+RmJntdCqCAEBHxjcLyevKGpRII+GNAUwWEiIhP7K6o883lJVBAiIj4QmtbO0WHGsjLUkCIiEgnJVWNHG5zOoMQEZGP6+ji6peb5EABISLiC4WhLq4TsvwxzAYoIEREfGF3RR0ZyQkMSUnwupQPKSBERHwgOEhfKmb+6OIKCggREV/YXe6fQfo6KCBERDzW2NLG/uomX/VgAgWEiIjn9hwMjcHko3sgQAEhIuK5j7q4+qcHEyggREQ8t7uiDoDxw1M8ruTjFBAiIh4rrKhnVEYSKYnxXpfyMQoIERGPdXRx9RsFhIiIxxQQIiLyCZX1LVQ1HFZAiIjIxxWGGqgn+KyLK0QwIMzsITMrM7PNnZZdZmZbzKzdzPKPsO15ZrbdzHaa2W2RqlFExGs7SoMBMTk7zeNKPimSZxAPA+d1WbYZuBRY1t1GZhYH3AecD0wHrjSz6RGqUUTEUzvK6khK8M881J1FLCCcc8uAQ12WbXPObf+UTecBO51zhc65FuAJ4OIIlSki4qkdZXVMyh7sm3moO/NjG0QusK/T6+LQsrDM7EYzKzCzgvLy8ogXJyLSl3aW1vry8hL4MyDCxajrbmXn3APOuXznXH5WVlYEyxIR6Vt1za3sr25iUra/htjo4MeAKAbGdHo9GtjvUS0iIhGzqyzYQK2A6LnVwGQzyzOzROAK4HmPaxIR6XM7yjp6MEVxQJhZqpkFQs+nmNlFZnbEefHM7HFgJTDVzIrN7Hozu8TMioFTgRfN7OXQujlmthTAOdcKfBN4GdgGPOWc23K0Bygi4lc7ympJjAswdqi/Bunr0NORoZYBZ5jZEOA1oAD4InB1dxs4567s5q0lYdbdDyzq9HopsLSHtYmIRKWdpXVMyEolPs6PF3N6fonJnHMNBO9huNc5dwnBexREROQo7SyvY6JPLy9BLwLCzE4leMbwYmiZv8alFRGJIk2H2yg61ODb9gfoeUB8F7gdWOKc22JmE4DXI1aViEiM21Veh3P+HGKjQ4/OApxzbwJvAoQaqyucc9+OZGEiIrFsZ0cPphFRfgZhZo+ZWbqZpQJbge1m9oPIliYiErt2lNYRFzDGD/PfKK4denqJabpzrgZYTLB30Vjg2kgVJSIS63aW1TFuWAqJ8f7swQQ9D4iE0H0Pi4HnnHOHOcLwFyIicmQ7ymp93UANPQ+I3wN7gFRgmZmNA2oiVZSISCxraW1nz8EGXzdQQ88bqe8B7um0aK+ZnR2ZkkREYtueg/W0tTtfN1BDzxupM8zs7o5htc3sVwTPJkREpJc6ZpGbmBUDAQE8BNQCl4ceNcCfIlWUiEgs21FWi5n/A6Knd0NPdM59vtPrfzOz9RGoR0Qk5u0sq2PMkBSSE+O8LuWIenoG0Whmp3e8MLP5QGNkShIRiW07y+p834MJen4G8XXgz2aWEXpdCXw5MiWJiMSu1rZ2CsvrOXOq/2fA7Gkvpg3ALDNLD72uMbPvAhsjWJuISMwpOtRAS1s7k3ze/gC9nFHOOVcTuqMa4HsRqEdEJKZ1zCI3ZYS/74GAY5ty1P
qsChGRAWJHaS2Ar+eB6HAsAaGhNkREemlHWR25mckMHuT/KXWOWKGZ1RI+CAxIjkhFIiIxbEdpHZOi4OwBPiUgnHP+v0gmIhIl2todu8rrmD9pmNel9Ih/x5kVEYkx+w410Nza7vtB+jooIERE+klHD6ZJPh+kr4MCQkSkn+woC/Zgioa7qKHnd1LHtH2HGoiPMxLiAiTGB4gPGAfrWvigpon9VY2U1zbT7hxxgQBxBoGA0drmaG1v53Cbo7XNcbitnZa2dlpa23HOkZ6cQEbo0e4cO8vqgo/yOg7WtRAwwwwCZiTGB0gbFE/qoHhSB8UxKD6OuIARMCMuABnJCWSnJZGdPojstCRyMpPIyUxmWGoiZuptLBItdpbWMSojibSkBK9L6ZGIBYSZPQRcAJQ552aElg0FngTGE5yA6HLnXGWYbfcQHD22DWh1zuVHqk6Ac3+9jMbDbcf0GQGDxPgAiXHBk7La5lZcp/5fifEBJgxPZfaYIYxIG4QD2p2jvd3R3NpOXXMr9c2t1DW3UtXQQptztLVDW3s7VQ2Hqahrpr1Lf7JB8QFyM5MZlZnEyPRkRmUkMTIjidwhyYzOTCZ3SDIpifodQMQv3i+rjZoeTBDZM4iHgd8Cf+607DbgNefcXWZ2W+j1rd1sf7ZzriKC9X3ozktn0Hz4ozOAw22OoakJjMxIJicjiez0JOICRlt78D/0NueIDxjxccGzjY7nnbW3O2qbW6lpPIxzkDskmbjA0f+239buOFjfTFlNM/urGoOP6iZKKhs5UN3Iyl0VlNY209YlRYakJDBuWCp5w1MZPyyVCVmpTB2ZRt7wVBLidIVRpL+0twevJFw1b5zXpfRYxALCObfMzMZ3WXwxcFbo+SPAG3QfEP3mkjmj+/wzAwH78BJTX4gLWPAyU1oSM3Izwq7T1u4or22mpKqB4spGSqoa2Xeokb0H61lVeJAl60o+XDcxLsDE7MFMHTGYiVmDmZgd/HP88BQGxft7CGKRaFRS1UjT4XamREkDNfR/G8QI59wBAOfcATPL7mY9B7xiZg74vXPuge4+0MxuBG4EGDt2bF/XG1XiAsbI0GWmuWF+SWk63EZheT3bS2t474Natn9Qy7u7D/Hs+v0frpMQZxw3Kp0TRmcwa3Qmc8YOYWJWqto6RI7Rhw3UCohjNt85tz8UIK+a2XvOuWXhVgyFxwMA+fn5Gv7jCJIS4piek870nPSPLW9oaaWwvJ5d5XVsO1DLhn1VPLtuP395pwiAYamJ5I8fwknjh5I/fijTR6WTGK/LUyK98X5omtFJUXIPBPR/QJSa2ajQ2cMooCzcSs65/aE/y8xsCTAPCBsQcuxSEuOZkZvBjNwMLp4dXNYeuuNzbVEl7+6uZPWeQ7y8pRQINrjPzM3gxLGZHJ+TweTQZaqkBF2aEunOjtI6RqQP6rPLzv2hvwPieYITDd0V+vO5riuYWSoQcM7Vhp6fC/x7v1YpBALG5BFpTB6RxhdPCl66K61pYs3eStYVVbK2qIpHVu6lpbUdADMYOzSFs6dmc1n+aI7PCd9OIjJQ7SyrjZo7qDtEspvr4wQbpIebWTHwE4LB8JSZXQ8UAZeF1s0B/uicWwSMAJaErnnHA4855/4eqTql50akJ7Fo5igWzRwFQEtrO3sO1rOjtI73S2vZsr+Gx1YV8fCKPczITefy/DEsnJbN6CEpHlcu4i3nHDvK6rg8f4zXpfSKORc7l+3z8/NdQUGB12UMaFUNLTy7roQnC4rZdiA4t1RuZjIn5w1lXt5Q5o4bwsSswQSOocuvSLQprmzg9P/zOj+/ZAZXn+yvbq5mtqa7e8382kgtUSozJZHr5udx3fw8tn9Qy8pdFby75xDLdpTzTKibbdqgeGaPzWTuuCF88aQxjMrQyPES26JpFrnOFBASMVNHpjF1ZBrXzc/DOUdhRT3riqpYV1TJuqIq7nltB/e9vpPFs3O56cwJUdW7Q6Q3dnb0YIqCeag7U0BIvz
Cz4A15WYP5wtzgjYn7DjXw4Fu7eWJ1EX9dU8w5x43gC3NHc/a0LN2sJzHl/dJahg8exJDURK9L6RUFhHhmzNAU7rjoeL61cBKPrNjDY+8W8T/bSklLimfRjFFccmIuJ+cN1U16EvV2lNVFzQiunSkgxHPDBg/ie+dO5dufmczbuw7y3LoSXti4nycL9jF1RBrXzR/P4tm5JCfqrEKiT8cYTJ8/MdfrUnpNASG+ER8X4MwpWZw5JYvGljb+tnE/f3p7D7c/s4m7XnqPK04awzWnjGPMUHWbleix9UANdc2tzBqT6XUpvaaAEF9KTozj8vwxXDZ3NKv3VPKnt3fzx7d288DyQj4zbQRfPm0cp08arstP4ntvvl8OwBmTszyupPcUEOJrZsa80D0U+6saeWxVEY+H2iomDE/lqpPHctncMWSkRM/wBTKwvLm9nBm56WSlDfK6lF7TiGsSNXIyk/n+56ay4vaF/PqLs8hMSeBnL27j5F/8Dz/46wbW76silm78lOhX03SYNUWVnDkl+s4eQGcQEoUGxcdxyZzRXDJnNFv31/CXVXt5dl0Jf11TzLSRaVxx0hgWz8klMyW6uhRK7Fmxs4K2dseZU7qb2cDfdAYhUW16Tjp3XjKTVT/8DD+/ZAYJcQHu+NtW5t35Grc8uZ41ew/prEI88+b7FaQNimfO2EyvSzkqOoOQmJCWlMDVJ4/j6pPHsbmkmidWF/Hsuv0sWVfCtJFpXH3KOC46IUdtFdJvnHMse7+c+ZOGR+30vtFZtcgRzMjN4GeLg2cVv7h0JnEB438/u5m5P3uVax9cxWOriqioa/a6TIlxu8rrKKlq5Myp0dn+ADqDkBiWOiieK+eN5YqTxrCppJoXNx3g75s/4IdLNvHjZzdxznEjuP70PObpbm2JgDe2B7u3LojSBmpQQMgAYGacMDqTE0Znctt509h2oJbnN+znidVFvLK1lBm56Vx/eh7/NDNHU6lKn3nz/XImZw8mNzN6RyvWT4MMKGbG9Jx0bjt/Gitv+ww/WzyDhpY2bnlyA6fd9Q/ufmU7H1Q3eV2mRLnGljZW7T4Utd1bO+gMQgas5MQ4rjllHFfNG8ubO8r5fyv3cu/rO7nvjV187vgRfO2MCZw4dojXZUoUemf3QVpa26O6/QEUECIEAsbZU7M5e2o2RQcb+MuqvTy5eh9LN33ASeOHcNOCiSyclq1Z8KTH3txeTlJCgJPGD/W6lGOiS0winYwdlsIPFx3HitsW8q8XTGd/VRNf+3MBn/31m/zlnb00tLR6XaJEgWU7yjl1wjCSEqJ7BGIFhEgYqYPi+erpebzxg7P4zRWzSU6M48fPbuaUO1/jzqXb2HeowesSxafKapsoLK/ntInDvS7lmOkSk8gRJMQFuHh2LhfNymHN3kr+9PYeHnxrN39cXsj5M0bxtTPymKN2Culk9e5KAE7Ki+7LS6CAEOkRMyN//FDyxwdHlX1k5R4eW1XEi5sOMHfcEL46P49zjx8RtXfMSt9ZvecQyQlxHJ+T7nUpx0wBIdJLOZnJ3H7+cXxr4WT+WrCPh97ezc2PrWX44EFcnj+aK04ay9hhmtRooFq95xBzxmbGxC8L0X8EIh4ZPCier8zP443vn81D1+Uze0wm97+5iwW/fJ1rH1zFa9tKaW/XQIEDSW3TYbYdqIn63ksdInYGYWYPARcAZc65GaFlQ4EngfHAHuBy51xlmG3PA34DxAF/dM7dFak6RY5VXMBYOG0EC6eN4EB1I0+tLuaxd/dy/SMFjBuWwpdOHc9l+aNJT9JAgbFuzd5K2h3Mi4H2B4jsGcTDwHldlt0GvOacmwy8Fnr9MWYWB9wHnA9MB640s+kRrFOkz4zKSOY750zmrVsXcu+Vcxg+eBA/fWErp975Gnc8v4U9FfVelygRtHrPIeIDFrXDe3cVsTMI59wyMxvfZfHFwFmh548AbwC3dllnHrDTOVcIYGZPhLbbGqlaRfpaQlyAC2flcOGsHD
YVV/Ont3fz6Kq9PLJyD5+Zls1NZ06MmcsQ8pHVuys5PjeDlMTYaN7t7zaIEc65AwChP8NNs5QL7Ov0uji0LCwzu9HMCsysoLy8vE+LFekLM0dncPcXZ/P2rQv51tmTWFtUxWX3r+TaB1exrugTV1glSjW3trG+uIqTxsVOt2c/NlKHG8+g25Y+59wDzrl851x+VlZ0j3sisS07PYnvnTuVt29dyA8XTWPL/hou+d0KvvrwatYqKKLepuJqWlrbY+L+hw79HRClZjYKIPRnWZh1ioExnV6PBvb3Q20i/SI5MY4bF0xk+b+czQ8+N5U1eyu59Hcr+Px/reClTQdoU8+nqPTunkMAMXXpsL8D4nngy6HnXwaeC7POamCymeWZWSJwRWg7kZiSOiiem8+exIrbFnLHhdMpq23iG4+u5az/eJ3HVhXR0trudYnSC6t3H2JS9mCGpiZ6XUqfiVhAmNnjwEpgqpkVm9n1wF3AZ81sB/DZ0GvMLMfMlgI451qBbwIvA9uAp5xzWyJVp4jXUgfFc13ofor7rzmRYamD+OGSTSz81Rs8tXofrW0KCr9ra3cU7K2MqbMHAHMudk5n8/PzXUFBgddliBwT5xxvvF/Or199n43F1YwblsIt50zhwlk5xGnIcV/aur+GRfcs5+7LZ3HpiaO9LqdXzGyNcy4/3Ht+bKQWGdDMgvNTPHfzfP74pXxSEuP57pPrWfSb5by6tZRY+qUuVhTsjb32B1BAiPiWmXHO9BG8+K3TuffKObS0tXPDnwu49L9WsDrUICr+8O7uQ4zKSGL0kOidfzocBYSIzwUCxoWzcnjllgXcdelM9lc1ctn9K7n50bUUHdS8FF5rb3es3HWQk/OGYhZblwAVECJRIiEuwBXzxvL698/ilnOm8I/3yjjn7jf5xdJt1DQd9rq8AWtDcRUH61s4e1q4+36jmwJCJMqkJMbznXMm88YPzuLi2Tk8sLyQhf/xBo+/W6R7KDzw+ntlBAzOnBJ7N+oqIESi1Ij0JH552Syev/l08oancvszm/ine5azctdBr0sbUP6xvYy544aQmRI79z90UECIRLmZozN46qZT+e1Vc6htauXKP7zDnUu36Ua7flBa08TmkpqYvLwECgiRmGBmXHBCDq/9rzO59pRxPLCskMt+v5J9h9SIHUmvvxccLWihAkJE/C4pIY6fLp7B764+kcLyOhbds5wXNx7wuqyY9Y/3ysjJSGLqiDSvS4kIBYRIDFo0cxRLv30GE7IGc/Nja7n50bWU1zZ7XVZMaW5t462dFSw8Ljvmurd2UECIxKgxQ1N4+uun8v1zp/Dq1lI+++s3WbKuWHdi95FVhYdoaGmL2ctLoIAQiWkJcQG+uXAyL3472NPplic38LVHCqio09nEsfrHe2UMig9w6oThXpcSMQoIkQFg8og0nv76afz4n45j+c4Kzv/Ncpa9rxkYj5Zzjn+8V8b8ScNJTozzupyIUUCIDBBxAeNrZ0zguZvnk5mcwJceelfdYY/SrvJ6ig41xGz31g4KCJEB5rhR6Tz/zdO55pSxwe6w969gf1Wj12VFlVjv3tpBASEyACUnxvGzxTO5/5q57Cqv58J73+KdQt2B3RPOOV7YuJ9pI9PIzYyt0Vu7UkCIDGDnzRjJszfPJyMlgav/uIo/vb1bvZw+RcHeSjYUV3P1yWO9LiXiFBAiA9yk7ME8d/N8zp6azb/9bSvf+MtaPqhu8ros3/rDskIyUxL4wtwxXpcScQoIESEtKYEHrp3LredN4/XtZXzmV2/w0Fu7NR92F7sr6nl1WynXnDwupnsvdVBAiAgQnJjoG2dN5JVbFpA/fij//sJWLvrt22wqrva6NN948K1CEgIBvnTaOK9L6RcKCBH5mHHDUnn4Kyfxu6tPpKKumUt+9zb3vb5zwM81UVnfwtNrilk8J4fstCSvy+kXCggR+QQzY9HMUbxyywI+N2Mkv3x5O1c+8A7FlQN3dNi/vLOXpsPtfO
2MCV6X0m8UECLSrcyURH575Rx+ddksth6o4fz/XM7TawbeeE5Nh9t4ZOVezpySxZQYHbk1HAWEiByRmfH5uaN56TtnMG1UGt//6wa++vBqDlQPnJvrnl1XQkVdMzcuGDhnD+BRQJjZd8xss5ltMbPvhnn/LDOrNrP1oce/elCmiHQyZmgKT954Kj+5cDorCw9y7t3LeHJ1UcyfTew71MCdS7cxe0wmp00c5nU5/arfA8LMZgA3APOAWcAFZjY5zKrLnXOzQ49/79ciRSSsQMD4yvw8Xv7uAqbnpHPrf29i8e9WsGJXhdelRURzaxv//OhaHHDPFXNidt6H7nhxBnEc8I5zrsE51wq8CVziQR0icpTGDUvl8RtO4f9+4QTKapq46g+ruPbBVWwuia0usT99YSubSqr51WWzGDssxety+p0XAbEZWGBmw8wsBVgEhLsl8VQz22BmL5nZ8d19mJndaGYFZlZQXq7hi0X6SyBgXJ4/hte/fxY/WnQcm0qqueDet/jDskKvS+sTz60v4S/vFHHjggmce/xIr8vxhHlx/dDMrgduBuqArUCjc+6WTu+nA+3OuTozWwT8xjkX7jLUx+Tn57uCgoJIlS0iR1DTdJjb/3sTL246wPWn5/GjRccRCETnJZmdZbVc9Nu3OT4nncduOIWEuNjtz2Nma5xz+eHe8+SonXMPOudOdM4tAA4BO7q8X+Ocqws9XwokmFnsTtskEgPSkxK498o5fGX+eB58azffeXI9za1tXpfVa02H27j50XUkJ8Rx75UnxnQ4fJp4L3ZqZtnOuTIzGwtcCpza5f2RQKlzzpnZPIJBprGIRXwuEDD+9YLpjExP4hcvvUdpTRPfPWcyp+QNi5qziZ++sJXtpbU8/JWTGJkxMO6Y7o4nAQH8t5kNAw4DNzvnKs3s6wDOufuBLwDfMLNWoBG4wsV6XzqRGGFm3HTmRLLTB/HjJZu56g+rGJmexIWzRnHx7FyOz0n3bW+gpZsO8OiqIm5aMIGzpsb2ZEA94UkbRKSoDULEXxpb2vifbaU8t76EN7aX09rumDoijUtPzGXxnFxGpPvnN/R9hxpYdM9yJmQN5q83nUpi/MC4tHSkNggFhIj0i8r6Fl7YdIBn1hazrqiKgMH8ScO59MRcPnf8SFISvbqgAYfb2rn89yvZWVrHi98+Y0B1aVVAiIivFJbX8czaEpasK6GkqpGUxDg+d/xIrjllLHPHDe23Olpa23l2fQkPLCtkZ1kd9145hwtn5fTb/v1AASEivtTe7li95xBL1pXw4qYD1Da18pX547n1vGkkJURmQp6mw23sKq/j7Z0VPPTWHj6oaeK4Uel8a+EkFs0cFZF9+pkCQkR8r7Gljbte2sYjK/cyOXsw/3nFbI7PyTimz2xta2frgRre3X2Igj2VbC+tZe/BejqmtjhlwlC+fuZEzpyS5duG80hTQIhI1Hhjexn/8vRGKhta+MLcMUwbmcbErMFMyEplZHrSJ7rLVjW0ULCnkrVFlVQ2tNDQ0kZjSxs1TYfZVFxNfUvwXoyxQ1OYPiqdKSMGM2VkGseNSmdi1mAvDtFXFBAiElUq61u4429beG1bGXXNrR8uT4gzstOSyE4fRNbgQew92MD20loA4gPGkNREUhLjSE6IIyUxjuNzMpiXN5R5eUN91WPKT44UEN51GxAR6caQ1ER+c8UcnHOU1Tazq7yOwvJ6SqoaKa1porSmicKKekZlJHHBCaM4KW8os8dkRqzdYqBSQIiIb5kZI9KTGJGexGkTNdpOfxsYd4KIiEivKSBERCQsBYSIiISlgBARkbAUECIiEpYCQkREwlJAiIhIWAoIEREJK6aG2jCzaj4+v3UGUN3D58OBiqPYbefP6u064ZZ3XfZpdXdeFu3HcLT1H6m+nqxzpHqP9Lov/x0dqb5Pe78vvoPOz6P1GPSzcOT6ultnnHMuK+xazrmYeQAPdPf6054DBX2xz96sE255b4+hy7KoPoajrb
+vj6Gnr/vy31FPjiGS30EsHIN+Fo79GLo+Yu0S09+O8Lonz/tin71ZJ9zy3h7Dsdbf088YSMfQ09d9+e+oJ58Rye+gJ/vvCS+PwW//jsIt8/sxfExMXWI6FmZW4LoZ0TBaRPsxRHv9oGPwi2g/Br/UH2tnEMfiAa8L6APRfgzRXj/oGPwi2o/BF/XrDEJERMLSGYSIiISlgBARkbAUECIiEpYCogfM7Awzu9/M/mhmK7yup7fMLGBmPzeze83sy17XczTM7CwzWx76Hs7yup6jZWapZrbGzC7wupajYWbHhb6Dp83sG17X01tmttjM/mBmz5nZuV7XczTMbIKZPWhmT0d6XzEfEGb2kJmVmdnmLsvPM7PtZrbTzG470mc455Y7574OvAA8Esl6u+qL+oGLgVzgMFAcqVq700fH4IA6IInoPQaAW4GnIlPlkfXRz8K20M/C5UC/dsPso/qfdc7dAFwHfDGC5YbVR8dQ6Jy7PrKVfrSzmH4AC4ATgc2dlsUBu4AJQCKwAZgOzCQYAp0f2Z22ewpIj7b6gduAm0LbPh2N3wEQCG03Ang0So/hHOAKgv85XRCNxxDa5iJgBXBVNNYf2u5XwInR+h2Etov4z3I8Mc45t8zMxndZPA/Y6ZwrBDCzJ4CLnXO/AMKe+pvZWKDaOVcTyXq76ov6zawYaAm9bItguWH11XcQUgkMikihR9BH38PZQCrBH/5GM1vqnGuPbOUf6avvwTn3PPC8mb0IPBbBkrvuty++AwPuAl5yzq2NcMmf0Mc/CxEX8wHRjVxgX6fXxcDJn7LN9cCfIlZR7/S2/meAe83sDGBZJAvrhV4dg5ldCnwOyAR+G9HKeq5Xx+Cc+xGAmV0HVPRnOBxBb7+Hs4BLCYb00kgW1kO9/Vn4FsEzuQwzm+Scuz+SxfVQb7+DYcDPgTlmdnsoSCJioAaEhVl2xDsGnXM/iVAtR6NX9TvnGggGnJ/09hieIRh0ftLrf0cAzrmH+76Uo9bb7+EN4I1IFXMUelv/PcA9kSvnqPT2GA4CX49cOR+J+UbqbhQDYzq9Hg3s96iWoxHt9YOOwS+i/RiivX7w8TEM1IBYDUw2szwzSyTYcPi8xzX1RrTXDzoGv4j2Y4j2+sHPx9Dfrfge9Bp4HDjAR108rw8tXwS8T7D3wI+8rjNW69cx+OcR7ccQ7fVH4zFosD4REQlroF5iEhGRT6GAEBGRsBQQIiISlgJCRETCUkCIiEhYCggREQlLASExzczq+nl/fTJfSGj+i2ozW2dm75nZf/Rgm8VmNr0v9i8CCgiRXjGzI45f5pw7rQ93t9w5NweYA1xgZvM/Zf3FBEeKFekTA3WwPhnAzGwicB+QBTQANzjn3jOzC4EfExyT/yBwtXOu1MzuAHKA8UCFmb0PjCU4fv9Y4D9dcBA4zKzOOTc4NOrpHUAFMANYA1zjnHNmtgi4O/TeWmCCc67bYZ2dc41mtp7gqJ+Y2Q3AjaE6dwLXArMJztNwppn9GPh8aPNPHOfR/r3JwKMzCBmIHgC+5ZybC3wf+F1o+VvAKaHf2p8A/qXTNnMJjtF/Vej1NILDj88DfmJmCWH2Mwf4LsHf6icA880sCfg9cL5z7nSC/3kfkZkNASbz0VDtzzjnTnLOzQK2ERyuYQXB8Xt+4Jyb7ZzbdYTjFOkRnUHIgGJmg4HTgL8G544BPpqAaDTwpJmNIvjb+e5Omz7vnGvs9PpF51wz0GxmZQRnuus6Feq7zrni0H7XEzwDqQMKnXMdn/04wbOBcM4ws43AVOAu59wHoeUzzOxnBOfGGAy83MvjFOkRBYQMNAGgyjk3O8x79wJ3O+ee73SJqEN9l3WbOz1vI/zPUrh1wo39353lzrkLzGwK8JaZLXHOrQceBhY75zaEJh86K8y2RzpOkR7RJSYZUFxwytjdZnYZBKegNLNZobczgJLQ8y9HqIT3gAmdpp384qdt4Jx7H/gFcGtoURpwIHRZ6+pOq9aG3vu04x
TpEQWExLoUMyvu9Pgewf9UrzezDcAW4OLQuncQvCSznGADcp8LXab6Z+DvZvYWUApU92DT+4EFZpYH/G9gFfAqwcDp8ATwg1DX2Il0f5wiPaLhvkX6mZkNds7VWbBx4D5gh3Pu117XJdKVziBE+t8NoUbrLQQva/3e23JEwtMZhIiIhKUzCBERCUsBISIiYSkgREQkLAWEiIiEpYAQEZGwFBAiIhLW/wcbVFlegaTLJgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmorgan\u001b[0m (use `wandb login --relogin` to force relogin)\n"
]
},
{
"data": {
"text/html": [
"\n",
" Tracking run with wandb version 0.10.12<br/>\n",
" Syncing run <strong style=\"color:#cdcd00\">triple_sharing_wmt_tiny</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://wandb.ai/fastai_community/reformer-fastai\" target=\"_blank\">https://wandb.ai/fastai_community/reformer-fastai</a><br/>\n",
" Run page: <a href=\"https://wandb.ai/fastai_community/reformer-fastai/runs/14eier3y\" target=\"_blank\">https://wandb.ai/fastai_community/reformer-fastai/runs/14eier3y</a><br/>\n",
" Run data is saved locally in <code>/home/morgan/ml/projects/reformer_fastai/nbs/exploration/wandb/run-20201221_223327-14eier3y</code><br/><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<h1>Run(14eier3y)</h1><p></p><iframe src=\"https://wandb.ai/fastai_community/reformer-fastai/runs/14eier3y\" style=\"border:none;width:100%;height:400px\"></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x7f3ce4a8ed50>"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import wandb\n",
"from fastai.callback.wandb import *\n",
"\n",
"WANDB_NAME = 'triple_sharing_wmt_tiny'\n",
"GROUP = 'TEST'\n",
"NOTES = 'Tripe weight sharing with the WMT_TINY dataset, fixed positional embeddings'\n",
"CONFIG = {}\n",
"TAGS =['enc-dec','test','wmt14_tiny']\n",
"\n",
"wandb.init(reinit=True, project=\"reformer-fastai\", entity=\"fastai_community\", \n",
" name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='0' class='' max='3' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 0.00% [0/3 00:00<00:00]\n",
" </div>\n",
" \n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>corpus_bleu</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>\n",
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='1620' class='' max='2870' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 56.45% [1620/2870 02:57<02:16 6.9144]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(3, 1e-4, div=5, cbs=WandbCallback(log_preds=False, log_model=False))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# # exports\n",
"# class myRemoveEOSCallback(Callback):\n",
"# \"\"\"\n",
"# Shift the target presented to the model during training to remove the \"eos\" token as \n",
"# we don't want the model to learn to translate EOS when it sees EOS.\n",
" \n",
"# In practice we actually mask the EOS token as due to batching the last token will often be a <pad> token,\n",
"# not EOS\n",
"# \"\"\"\n",
"# def __init__(self, eos_idx): self.eos_idx=eos_idx\n",
"# def before_batch(self): \n",
"# eos_mask=(self.learn.xb[1]!=self.eos_idx)\n",
"# sz=torch.tensor(self.learn.xb[1].size())\n",
"# # If ids contain eos token ids, do masking\n",
"# if eos_mask.sum() < sz[0]*sz[1]: \n",
"# sz[1]=sz[1]-1\n",
"# self.learn.xb = (self.learn.xb[0], self.learn.xb[1][eos_mask].view((sz[0],sz[1])))\n",
"# return\n",
"# else: return"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(125104044, True)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# cbs = [CombineInputOutputCallback(), LossTargetShiftCallback(), RemoveEOSCallback(eos_idx=tok.EOS_ID)]\n",
"# # cbs = [CombineInputOutputCallback(), LossTargetShiftCallback()] #, RemoveEOSCallback(eos_idx=tok.EOS_ID)]\n",
"\n",
"# learn = Learner(dls, TransformerEncDec(enc_vocab_sz, dec_vocab_sz, d_model=d_model, heads=8, \n",
"# max_seq_len=256, pad_idx=tok.PAD_ID, tie_weights=True, shared_emb=True,\n",
"# attn_dropout=0.0, ff_dropout=0.0, emb_dropout=0.0,\n",
"# pos_enc='fixed'),\n",
"# loss_func=CrossEntropyLossFlat(ignore_index=tok.PAD_ID), opt_func=adafactor, cbs=cbs,\n",
"# metrics=[accuracy, Perplexity(), CorpusBLEUMetric()]).to_native_fp16()\n",
"\n",
"# total_params(learn.model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment