@morganmcg1
Created February 21, 2021 22:08
Testing out ADMIN with BTE, trained SentencePieceUnigram and PreTrained T5 Tokenizers
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#default_exp xtransformer"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from nbdev.showdoc import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"#!pip install datasets"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from fastai.basics import *\n",
"from fastai.text.all import *\n",
"from fastai.callback import *\n",
"from fastai.callback.wandb import *\n",
"\n",
"from transformers_sandbox.core import *\n",
"from transformers_sandbox.layers import *\n",
"from transformers_sandbox.attention import *\n",
"from transformers_sandbox.tokenizers import ByteTextTokenizer, SPUniTokenizer\n",
"from transformers_sandbox.metrics import bpc\n",
"from transformers_sandbox.transformer import LMMixin, EncDecMixin, TransformerLM\n",
"from transformers_sandbox.config import ConfigBase, update_sig\n",
"from transformers_sandbox.xtransformer import *\n",
"\n",
"import wandb\n",
"from datasets import load_dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ADMIN Init\n",
"\n",
"> ADMIN Initialisation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"📕 **Paper**: https://arxiv.org/abs/2004.08249\n",
"\n",
"Transformers have proved effective in many NLP tasks. However, their training requires non-trivial efforts regarding designing cutting-edge optimizers and learning rate schedulers carefully (e.g., conventional SGD fails to train Transformers effectively). Our objective here is to understand what complicates Transformer training from both empirical and theoretical perspectives. \n",
"\n",
"Our analysis reveals that unbalanced gradients are not the root cause of the instability of training. Instead, we identify an amplification effect that influences training substantially -- for each layer in a multi-layer Transformer model, heavy dependency on its residual branch makes training unstable, since it amplifies small parameter perturbations (e.g., parameter updates) and results in significant disturbances in the model output. \n",
"\n",
"Yet we observe that a light dependency limits the model potential and leads to inferior trained models. Inspired by our analysis, we propose Admin (**Ad**aptive **m**odel **in**itialization) to stabilize the early stage's training and unleash its full potential in the late stage. Extensive experiments show that Admin is more stable, converges faster, and leads to better performance. Implementations are released at: [LiyuanLucasLiu/Transforemr-Clinic](https://github.com/LiyuanLucasLiu/Transforemr-Clinic)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class AdminResidual(Module):\n",
"    def __init__(self, sublayer, d_model):\n",
"        self.sublayer = sublayer\n",
"        self.w = torch.nn.Parameter(torch.ones(d_model))\n",
"    def forward(self, x, *args, **kwargs):\n",
"        return x*self.w + self.sublayer(x, *args, **kwargs)"
]
},
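{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal, self-contained sketch of the same residual in plain PyTorch, assuming `torch.nn.Module` in place of fastai's `Module` (so `super().__init__()` has to be called explicitly). The class name `AdminResidualSketch` and the `nn.Linear` stand-in sublayer are illustrative only and are not part of `transformers_sandbox`.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"class AdminResidualSketch(nn.Module):\n",
"    # Residual connection whose skip path is scaled element-wise by a learnable vector w\n",
"    def __init__(self, sublayer, d_model):\n",
"        super().__init__()\n",
"        self.sublayer = sublayer\n",
"        # w starts at ones, so the block is initially a plain residual connection\n",
"        self.w = nn.Parameter(torch.ones(d_model))\n",
"\n",
"    def forward(self, x, *args, **kwargs):\n",
"        # w broadcasts over the batch and sequence dimensions\n",
"        return x * self.w + self.sublayer(x, *args, **kwargs)\n",
"\n",
"x = torch.randn(2, 5, 8)  # [bs, sl, d_model]\n",
"res = AdminResidualSketch(nn.Linear(8, 8), d_model=8)\n",
"assert res(x).shape == x.shape\n",
"```\n",
"\n",
"Because `w` is a per-feature vector rather than a scalar, rescaling it later changes how strongly each layer depends on its skip path without touching the sublayer weights."
]
},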
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class TransformerEncoderBlockAdmin(Module):\n",
"    \"\"\"\n",
"    Basic transformer encoder block. Consists of multi-head attention and position-wise \n",
"    feed-forward layers, each wrapped in an AdminResidual followed by PostNorm\n",
"    \"\"\"\n",
"    def __init__(self,\n",
"                 d_model:int, \n",
"                 n_heads:int = 8, \n",
"                 d_ff:int = None, \n",
"                 attn_dropout:float = 0.1,\n",
"                 ff_dropout:float = 0.1,\n",
"                 causal:bool = False, \n",
"                 attn_bias:bool = False, \n",
"                 prenorm:bool=False,\n",
"                 shared_qk:bool=False):\n",
"        store_attr('attn_dropout') # mb separate argument attn_post_dropout\n",
"        \n",
"        self.attn = PostNorm(d_model, AdminResidual(Attention(d_model, n_heads=n_heads, causal=causal, dropout=attn_dropout, bias=attn_bias, shared_qk=shared_qk), d_model))\n",
"        self.ff = PostNorm(d_model, AdminResidual(FeedForward(d_model, d_ff=d_ff, dropout=ff_dropout), d_model))\n",
"        \n",
"    def forward(self, x, mask=None):\n",
"        out = self.attn(x, mask=mask)\n",
"        return self.ff(out)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# bs = 4\n",
"# sl = 128\n",
"# d = 64\n",
"# x = torch.randn(bs, sl, d)\n",
"# m = TransformerEncoderBlockAdmin(d)\n",
"# out = m(x)\n",
"# assert (out.size() == (bs, sl, d))\n",
"# out.shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class TransformerEncoderAdmin(Module):\n",
"    \"\"\"Stack of TransformerEncoderBlockAdmin blocks\"\"\"\n",
"    def __init__(self, \n",
"                 d_model, \n",
"                 n_layers=6, \n",
"                 n_heads=8, \n",
"                 d_ff=None,\n",
"                 ff_dropout=0.1, \n",
"                 attn_dropout=0.1,\n",
"                 attn_bias=False,\n",
"                 causal=False, \n",
"                 prenorm=False,\n",
"                 shared_qk:bool=False,\n",
"                 final_norm=None):\n",
"        store_attr('d_model')\n",
"        self.layers = nn.ModuleList([]) \n",
"        for _ in range(n_layers):\n",
"            self.layers.append(TransformerEncoderBlockAdmin(d_model, n_heads, causal=causal, \n",
"                                    d_ff=d_ff, attn_dropout=attn_dropout, ff_dropout=ff_dropout, \n",
"                                    prenorm=prenorm, attn_bias=attn_bias, shared_qk=shared_qk))\n",
"        self.norm = None if final_norm is None else final_norm(d_model)\n",
"        \n",
"    def forward(self, x, mask=None):\n",
"        for layer in self.layers: x = layer(x, mask=mask)\n",
"        if self.norm is not None: x = self.norm(x)\n",
"        return x"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class TransformerLMAdmin(Module, LMMixin):\n",
"    \"\"\"\n",
"    Basic Transformer for language modelling, using ADMIN residual connections\n",
"    \n",
"    Parameters:\n",
"        * vocab_sz: int\n",
"        * d_model: int - inner dimension of the model\n",
"        * n_layers: int (default: 6) \n",
"        * n_heads: int (default: 8)\n",
"        * d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model\n",
"        * attn_dropout: float - attention dropout\n",
"        * ff_dropout: float - feed-forward dropout\n",
"        * emb_dropout: float - embedding dropout\n",
"        * causal: bool (default: True) - if True does causal masking automatically\n",
"        * max_seq_len: int (default: 512)\n",
"        * tie_weights: bool - if True, the embedding weights are reused for the output projection\n",
"        * prenorm: bool - whether to use PreNorm or PostNorm (accepted for config compatibility; the Admin blocks here always use PostNorm)\n",
"        * attn_bias: bool - whether to allow biases in attention projection layers\n",
"        * pad_idx: int - padding token id, required for autogeneration of padding mask\n",
"        * pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use\n",
"        * axial_shape: tuple - [optional] should be factors of max_seq_len\n",
"        * axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model\n",
"    Inputs:\n",
"        * x - input ids, shape [bs, sl]\n",
"        * mask - optional boolean mask, shape [bs, sl]\n",
"    Returns:\n",
"        * logits - target token logits, shape [bs, sl, vocab_sz]\n",
"    \"\"\"\n",
"    def __init__(self, \n",
"                 vocab_sz:int, \n",
"                 d_model:int, \n",
"                 n_layers:int=6,\n",
"                 n_heads:int=8,\n",
"                 d_ff:int=None,\n",
"                 attn_dropout:float=0.1,\n",
"                 ff_dropout:float=0.1,\n",
"                 emb_dropout:float=0.1,\n",
"                 tie_weights:bool=True,\n",
"                 causal:bool=True,\n",
"                 pos_enc:str='absolute',\n",
"                 max_seq_len:int=512,\n",
"                 axial_shape:tuple=None,\n",
"                 axial_emb_dims:tuple=None,\n",
"                 pad_idx:int=None,\n",
"                 prenorm:bool=False,\n",
"                 attn_bias:bool=False,\n",
"                 shared_qk:bool=False):\n",
"        store_attr()\n",
"        self.emb = TransformerEmbedding(vocab_sz, d_model, max_seq_len, dropout=emb_dropout, \n",
"                                        pos_enc=pos_enc, axial_shape=axial_shape, \n",
"                                        axial_emb_dims=axial_emb_dims)\n",
"        final_norm = None\n",
"        self.encoder = TransformerEncoderAdmin(d_model, n_layers, n_heads, causal=causal, d_ff=d_ff,\n",
"                                               attn_dropout=attn_dropout, ff_dropout=ff_dropout,\n",
"                                               prenorm=prenorm, attn_bias=attn_bias,\n",
"                                               shared_qk=shared_qk, final_norm=final_norm)\n",
"        self.proj = nn.Linear(d_model, vocab_sz)\n",
"        if tie_weights: self.proj.weight = self.emb.emb.weight\n",
"        \n",
"    def forward(self, x, mask=None):\n",
"        x = self.emb(x)\n",
"        x = self.encoder(x, mask=mask)\n",
"        return self.proj(x)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"@patch(cls_method=True)\n",
"def from_config(cls:TransformerLMAdmin, config):\n",
"    return cls(**config.dict())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset wikitext (/home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)\n",
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-0b8034a134b9567b.arrow\n",
"Reusing dataset wikitext (/home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)\n",
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-d23112255b6f8a24.arrow\n"
]
}
],
"source": [
"train_ds = load_dataset('wikitext', name='wikitext-2-raw-v1', split='train')\n",
"\n",
"train_ds = train_ds.filter(lambda x: x['text'] != '')\n",
"train_df = train_ds.data.to_pandas()\n",
"\n",
"valid_ds = load_dataset('wikitext', name='wikitext-2-raw-v1', split='validation')\n",
"\n",
"valid_ds = valid_ds.filter(lambda x: x['text'] != '')\n",
"valid_df = valid_ds.data.to_pandas()\n",
"\n",
"df = pd.concat([train_df, valid_df])\n",
"df['is_valid'] = [False]*len(train_df) + [True]*len(valid_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenizer + DataLoaders"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ByteTextTokenizer"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"tok = ByteTextTokenizer(is_lm=True, add_bos=False, add_eos=False)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"#Run setup\n",
"n_epochs = 10\n",
"bs = 16\n",
"sl = 512\n",
"n_layers = 6\n",
"pad_id = tok.pad_token_id\n",
"vocab_size = tok.vocab_size"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"df['toks'] = df['text'].apply(tok)\n",
"df['lens'] = df['toks'].apply(len)\n",
"splits = ColSplitter()(df)\n",
"tfms = [attrgetter(\"text\"), tok]\n",
"dsets = Datasets(df, [tfms], splits=splits, dl_type=LMDataLoader)\n",
"dl_kwargs = [{'lens':df['lens'].values[splits[0]]},\n",
" {'val_lens':df['lens'].values[splits[1]]}]\n",
"dls = dsets.dataloaders(bs=bs, seq_len=sl, dl_kwargs=dl_kwargs, shuffle_train=True, n_workers=2)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>text_</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>The film is set in 2034 , two years after the events of 2nd GIG . Togusa is now the team leader for Public Security Section 9 , which has increased considerably in size . Section 9 deals with a series of complicated incidents , including the assassination of Ka Rum , a former dictator of the Siak Republic , which leads to a terrorist plot using children as vectors for a cybernetic virus . Investigations reveal that a hacker nicknamed the \" The Puppeteer \" is behind the entire series of events . \\n Mozambiqu</td>\n",
" <td>The film is set in 2034 , two years after the events of 2nd GIG . Togusa is now the team leader for Public Security Section 9 , which has increased considerably in size . Section 9 deals with a series of complicated incidents , including the assassination of Ka Rum , a former dictator of the Siak Republic , which leads to a terrorist plot using children as vectors for a cybernetic virus . Investigations reveal that a hacker nicknamed the \" The Puppeteer \" is behind the entire series of events . \\n Mozambique</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>združenje ) its association with Yugoslavia on 25 June 1991 . The European Economic Community and the Conference on Security and Cooperation in Europe urged Croatian authorities to place a three @-@ month moratorium on the decision . Croatia agreed to freeze its independence declaration for three months , initially easing tensions . Nonetheless , the Croatian War of Independence escalated further . On 7 October , the eve of expiration of the moratorium , the Yugoslav Air Force attacked Banski dvori , the m</td>\n",
" <td>druženje ) its association with Yugoslavia on 25 June 1991 . The European Economic Community and the Conference on Security and Cooperation in Europe urged Croatian authorities to place a three @-@ month moratorium on the decision . Croatia agreed to freeze its independence declaration for three months , initially easing tensions . Nonetheless , the Croatian War of Independence escalated further . On 7 October , the eve of expiration of the moratorium , the Yugoslav Air Force attacked Banski dvori , the ma</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(16, None)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#collapse_output\n",
"dls.bs, dls.show_batch(max_n=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"#config = XConfig(vocab_sz=vocab_size, d_model=512, n_layers=n_layers, max_seq_len=512, pad_idx=pad_id)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"class CharLMConfig(ConfigBase):\n",
"    _model = TransformerLM\n",
"    _d = {\n",
"        'vocab_sz':256,\n",
"        'd_model':512,\n",
"        'n_layers':6,\n",
"        'n_heads':8,\n",
"        'd_ff':4096,\n",
"        'attn_dropout':0.1,\n",
"        'ff_dropout':0.1,\n",
"        'emb_dropout':0.1,\n",
"        'tie_weights':True,\n",
"        'causal':True,\n",
"        'pos_enc':'absolute',\n",
"        'max_seq_len':512,\n",
"        'axial_shape':None,\n",
"        'axial_emb_dims':None,\n",
"        'pad_idx':None,\n",
"        'prenorm':False,\n",
"        'attn_bias':False,\n",
"        'shared_qk':False,\n",
"    }\n",
"    @update_sig(_d)\n",
"    def __init__(self, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"\n",
"config = CharLMConfig(vocab_sz=vocab_size, d_model=512, n_layers=n_layers, \n",
"                      max_seq_len=512, pad_idx=pad_id)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learner"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, TransformerLMAdmin.from_config(config),\n",
" loss_func=CrossEntropyLossFlat(ignore_index=pad_id), #opt_func=Adam,\n",
" cbs = [\n",
" #GradientClip(1.0),\n",
" SaveModelCallback(with_opt=True)],\n",
" metrics=[accuracy, perplexity, bpc]).to_fp16()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"class BreakFitCallback(Callback):\n",
"    \"Cancels fit after one batch before weight update\"\n",
"    order=-1\n",
"    def before_step(self):\n",
"        self.model.zero_grad(set_to_none=True)\n",
"        raise CancelStepException\n",
"    def after_step(self):\n",
"        raise CancelBatchException\n",
"    def after_batch(self):\n",
"        print('Fit canceled')\n",
"        raise CancelFitException"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def res_submodules(model):\n",
"    return [m.sublayer for m in model.modules() if isinstance(m, AdminResidual)]\n",
"\n",
"def res_modules(model):\n",
"    return [m for m in model.modules() if isinstance(m, AdminResidual)]"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"learn.add_cb(ActivationStats(modules=res_submodules(learn.model)));"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"12"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(learn.activation_stats.modules)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_eval -10\n",
"recorder 50\n",
"progress 60\n",
"save_model 60\n",
"mixed_precision 10\n",
"activation_stats -20\n"
]
}
],
"source": [
"for cb in learn.cbs:\n",
"    print(cb.name, cb.order)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>bpc</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fit canceled\n"
]
}
],
"source": [
"with learn.added_cbs(BreakFitCallback()), learn.removed_cbs(SaveModelCallback):\n",
"    learn.fit(1, 1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#12) [{'mean': 0.00481663690879941, 'std': 0.3442692756652832, 'near_zero': 0.6015446980794271},{'mean': -0.0025428251828998327, 'std': 0.3987179100513458, 'near_zero': 0.5950142542521158},{'mean': 0.004401894751936197, 'std': 0.19933182001113892, 'near_zero': 0.6423584620157877},{'mean': 0.003396979533135891, 'std': 0.4013838768005371, 'near_zero': 0.5894660949707031},{'mean': 0.009360421448946, 'std': 0.26095426082611084, 'near_zero': 0.6073331832885742},{'mean': -0.0018820768455043435, 'std': 0.3994253873825073, 'near_zero': 0.5938955942789713},{'mean': -0.02019776962697506, 'std': 0.3356610834598541, 'near_zero': 0.6348047256469727},{'mean': 0.013347254134714603, 'std': 0.40092015266418457, 'near_zero': 0.5801426569620768},{'mean': -0.024423303082585335, 'std': 0.3910963237285614, 'near_zero': 0.621729850769043},{'mean': 0.005235261749476194, 'std': 0.3965536057949066, 'near_zero': 0.587321917215983}...]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.activation_stats.stats[0]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.11852133, 0.15897597, 0.03973317, 0.16110902, 0.06809713,\n",
" 0.15954064, 0.11266836, 0.16073697, 0.15295633, 0.15725476,\n",
" 0.24887388, 0.15945325])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def variances(learn):\n",
"    return np.array([stat['std']**2 for stat in learn.activation_stats.stats[0]])\n",
"\n",
"variances(learn)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.34426928, 0.52678013, 0.56323217, 0.69162092, 0.73921352,\n",
" 0.84022453, 0.90479038, 0.98963761, 1.06411415, 1.13560279,\n",
" 1.24034978, 1.30304291])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def _init_scales(vars):\n",
"    return np.sqrt(np.cumsum(vars))\n",
"scales = _init_scales(variances(learn))\n",
"scales"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define ADMIN Initialisation"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"def admin_init(model, scales):\n",
"    ms = res_modules(model)\n",
"    for m, s in zip(ms, scales):\n",
"        m.w.data *= s"
]
},
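{
"cell_type": "markdown",
"metadata": {},
"source": [
"As computed by `_init_scales` in this notebook, the scale applied to the i-th residual connection is the square root of the running sum of the sublayer output variances recorded during the single profiling batch. A small NumPy sketch of that computation follows; it reuses the first three variances printed earlier in this notebook, and the variable names are illustrative only.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# First three sublayer output variances from the profiling pass in this notebook\n",
"layer_vars = np.array([0.11852133, 0.15897597, 0.03973317])\n",
"\n",
"# The i-th scale is the square root of the summed variances of sublayers 1..i\n",
"layer_scales = np.sqrt(np.cumsum(layer_vars))\n",
"print(layer_scales)\n",
"```\n",
"\n",
"Since each `w` starts at ones, multiplying it in place by its scale, as `admin_init` does, should leave every entry of `w` equal to that layer's scale, so deeper layers weight their skip path more heavily."
]
},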
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LR Find"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.03019951581954956, lr_steep=0.25118863582611084)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAr2UlEQVR4nO3dd5xU5b3H8c9vZ3uFZQtlka6CdFZR7DEaexeN/V4VjRpNzDXRmKvJTdEUNVFjRZPYNYiGqBhji70sHQQElra0LbTt9bl/7CxZyQLLsmfOlO/79ZoXM3POzPnuEfztc85TzDmHiIjErji/A4iIiL9UCEREYpwKgYhIjFMhEBGJcSoEIiIxToVARCTGxfsdYG/l5OS4gQMH+h1DRCSizJo1q9w5l9vRtogrBAMHDqSoqMjvGCIiEcXMVu9qmy4NiYjEOBUCEZEYp0IgIhLjVAhERGKcCoGISIxTIRARiXExUwi21zXy6vz1NDS1+B1FRCSsRNw4gq6auWADP3ppATnpSZx/cAEXHLwf/bNTAXDOUVnfRF1D8479HeAcOBzOgRn0zkzGzHz6CUREvBEzheDcCf3Jy0zmmU/X8NB7K3jwvRUckJ9BZV0TZVX1nWop9M5M5rjheXxzRD6HDe5FckIgBMlFRLxlkbZCWWFhodvXkcXrt9by/BdrWVCyley0JHLSE8lJTyIlMUD7X/gNwwziDOqbWvh4eQXvLyujpqGZxEAc+VlJ5KYnkZuRRM/URAJxRnycERdnpCYG6JmaSK/0RLLTkijomcKA7FTiAzFzNU5EwoiZzXLOFXa4LRYLwb6oa2zmk+IKPl1RwabtdZRV1VO6vZ6ttY20tDiaWhwtLY6axmaaW75+bhMCxsBeaQzJTSczJZ7E+DiS4gOkJcUzpiCLgwdlk5mc4NNPJiLRbHeFIGYuDXWX5IQAxx6Qx7EH5O12v5YWx/a6RiqqG9hc3cDqihqWl1axvLSKZaWVVNc309DcQn1jM7WNzbS41pbHiL6ZFA7IZlh+OsPyMhial052WmKIfjoRiUUqBB6JizN6pCbSIzWRIblw8MDsXe5b19jM7DVb+Kx4M58WV/DCF2upbfz3jeteaYkMy09n//wM9s/PYHifTIb3ySA1Uf/5RGTf6f8kYSA5IcCkITlMGpIDtLYm1m+r3dGCWF5axVebKpk+ex1V9U1Aa+thcG46I/tmMnFwL44YmrOjF5SIyN5QIQhDcXFGQc9UCnqmcky7S1DOOdZvq2Px+u0sXL+Nheu28/GKCl6Zux6AQTlpHDUsh28d1JtDBmXrxrSIdIpuFkc45xwryqr4YFk5Hywr55MVFdQ2NpOTnsi3DurNqaP7MnFQNnFxGv8gEsvUayiG1DY0897SUl5dsIF3FpdS29hM/+wUJk/oz7mFBfTJSvE7ooj4QIUgRtU0NPHmok288MVaPimuIM7gmAPyuOSwARw9LFetBJEYokIgrK6o5q9FJbxQtJayynoG9ErlkkMHcF5hf7JSNHZBJNqpEMgODU0tvLFoI09+vIqi1VvITI7n6qOHcPmkgaQlqe+ASLRSIZAOLVy3jd+/9RVvLS6lV1oi3zlmCBcfOkBzKIlEIRUC2a3Za7Zwz5tf8eHycvpkJXPjccM4d0KBup+KRJHdFQL9SxfG79eTp6+cyLNXTSQ/M5lbpi/g+Hvf5+/z1hNpvyiIyN5TIZAdJg3J4eVrJ/HYpYUkBuL47nNzuGjqZ6wqr/Y7moh4SIVAvsbMOH5EPq/feCS/OmsUC9Zt41u/f58H31tOY7NWdxOJRioE0qFAnHHhxP1466ajOfaAPH7zxlLOeOAjisuq/I4mIt1MhUB2Kz8zmYcvmcDDF09gw7ZaTn/gI15fsMHvWCLSjVQIpFNOHNmbV284kqF56Vz7zGx+/uqXulQkEiVUCKTT+vVI4cWrD+Oywwbw+IcrufCxT6moqvc7lojsIxUC2SuJ8XH87IyR/OGCsc
wr2cZZD37M8lLdNxDx2herNlO6vc6T71YhkC45Y2w/np9yKDUNTZz94Ed8tLzc70giUau5xXHR1M+Y+uFKT75fhUC6bPx+PXn52sPpnZXMZU98zl+L1vodSSQqrd9aS0NTC4Ny0jz5fhUC2Sf9s1OZ9p1JHDq4FzdPm88zn632O5JI1CkODuocHImFwMx6mNk0M1tiZovN7LCdtpuZ3Wdmy81svpmN9zKPeCMzOYGplxXyjQPzuO3lhfzl41V+RxKJKiuD43cG5UZgIQD+ALzhnDsQGAMs3mn7ScCw4GMK8JDHecQjyQkBHr54AieMyOeOGYuY+kGx35FEokZxeTUZSfHkpid58v2eFQIzywKOAh4HcM41OOe27rTbGcCTrtWnQA8z6+NVJvFWYnwcf7xoPCeP6s0vXlusYiDSTYrLqhmUm4aZN6sKetkiGASUAX8yszlmNtXMdm7X9APa32EsCb4nESohEMd9F4zbUQxe/EI3kEX21cryas/uD4C3hSAeGA885JwbB1QDt3Tli8xsipkVmVlRWVlZd2YUD8QH4rj3/LEcOSyHW6bPZ6ampBDpstqGZtZtrWVwbrpnx/CyEJQAJc65z4Kvp9FaGNpbB/Rv97og+N7XOOcedc4VOucKc3NzPQkr3SspPsAjl0xg3H49ufH5uXywTAVcpCtWVbT2GPKq6yh4WAiccxuBtWZ2QPCt44Avd9ptBnBpsPfQocA255x+fYwSqYnxPHHZwQzOTWPKk7OYX7LV70giEae4LNh11KMeQ+B9r6HvAs+Y2XxgLPArM7vGzK4Jbn8dKAaWA48B13qcR0IsKzWBJ684hOy0RKY8OcuzIfIi0WplebDraCS2CACcc3ODl3RGO+fOdM5tcc497Jx7OLjdOeeuc84Ncc6Ncs5pMeIolJeRzGOXFrKttpGrn55FXWOz35FEIkZxWTV9spJJTYz37BgaWSwhMaJvJvdMHsOcNVv5ySsLtRaySCcVl1d72hoAFQIJoZNG9eF73xzGtFklPO7R5Fki0cQ5R3FZlaf3B0CFQELshm8M46SRvfnV64v5WDOWiuxWRXUD2+uaGJzjXddRUCGQEIuLM3533hgG56Zzw/Nz2LhNN49FdmVlcLI5r+YYaqNCICGXlhTPwxePp6ahmeufna0lL0V2oTg42dwQtQgkGg3Ny+Cuc0ZTtHoLv565xO84ImGpuLyaxEAc/XqmeHocFQLxzelj+nLZYQOY+uFKTUMh0oHismoG9EolEOfNZHNtVAjEVz8+ZThj+/fg5mnzWVNR43cckbCyMgRdR0GFQHyWFB/g/m+PwwxueH6O7heIBDU1t7C6otrTyebaqBCI7/pnp3LX2aOZu3Yrd7/5ld9xRMLCuq21NDY7z8cQgAqBhIlTRvfh24f05+F/rdBMpSK0m2xOl4Ykltx+6kEMy0vn+y/Mo7yq3u84Ir5aEew6qktDElNSEgPcf+E4Kusa+Z+/ztN8RBLTVpZXk5WSQM/UBM+PpUIgYeXA3pn8+OThvLe0jKc/Xe13HBHfFJe19hjyap3i9lQIJOxcetgAjto/l1++vnhH81gk1qyu8Had4vZUCCTsmBm/PXc0yQkBvv/CXHUplZhT19jMhu117NcrNSTHUyGQsJSfmcydZ41ifsk27nt7md9xREKqZEsNzsHAXmoRSIw7aVQfzhlfwB/fXc6s1Vv8jiMSMqvKW0fZD1CLQAR+evoI+mSlcPO0eVriUmLG6s1thUAtAhEykhO465xRFJdV8wddIpIYsbqimozk+JB0HQUVAokARw7LZXJhAY++X8yCkm1+xxHx3OqKGgb2Ck3XUVAhkAhx2ykj6JWWyM3T5tHQpF5EEt1WV1SHrMcQqBBIhMhKSeCXZ41iycZKHnpvhd9xRDzT1NxCyZZaBqoQiPyn40fkc/qYvjzw7jKWbNzudxwRT6zfWkdTi2NAdmhuFIMKgUSYO04bQUZyAj+aNp8mDTSTKLSqonXW0VB1HQ
UVAokwvdKT+OnpBzGvZBtPfLTS7zgi3a6t6+jAEE0vASoEEoFOG92Hbw7P5+43v2JlebXfcUS61eryapIT4sjLSArZMVUIJOKYGb88aySJ8XHc8tJ8Wlo0XbVEj1UVNQzIDl3XUVAhkAiVn5nMT04ZzmcrN/Ps52v8jiPSbdZsrg7p/QFQIZAINrmwP0cMzeGumUtYv7XW7zgi+6ylxbG6okaFQKSzzIw7zx5Fc4vjf19ZqBXNJOJtqqyjvqklZHMMtVEhkIjWPzuVH5ywP28vKeW1BRv8jiOyT1ZXhHbW0TYqBBLxLp80kNEFWfx0xiK21jT4HUeky1YHxxCEah2CNioEEvHiA3HcdfZottQ08qvXF/sdR6TLVlfUkBAw+mQlh/S4nhYCM1tlZgvMbK6ZFXWw/Rgz2xbcPtfMbvcyj0SvEX0zmXLUYF4sKuHj5eV+xxHpktUVNRT0TCU+ENrf0UNxtGOdc2Odc4W72P5BcPtY59z/hSCPRKkbjxvGwF6p3PryAi1iIxFpVUXou46CLg1JFElOCPCrs0axuqKGB95Z7ncckb3inGNNcB2CUPO6EDjgTTObZWZTdrHPYWY2z8xmmtlBHueRKDdpaA5nj+/HI++v4KtNlX7HEem0zdUNVNY3sV929LUIjnDOjQdOAq4zs6N22j4bGOCcGwPcD7zS0ZeY2RQzKzKzorKyMk8DS+S77eThpCfFc9vLCzT9hESMf082F2WFwDm3LvhnKfAycMhO27c756qCz18HEswsp4PvedQ5V+icK8zNzfUyskSBXulJ/Pjk4XyxagsvFq31O45Ip6zeMf10FF0aMrM0M8toew6cACzcaZ/eFpxZycwOCeap8CqTxI5zJxQwcVA2v3p9MWWV9X7HEdmjVeU1mEFBz5SQH9vLFkE+8KGZzQM+B15zzr1hZteY2TXBfc4FFgb3uQ+4wGmeAOkGrTOUjqKusYWfv/ql33FE9mjt5hr6ZqWQFB8I+bHjvfpi51wxMKaD9x9u9/wB4AGvMkhsG5qXznXHDuXet77ijLF9OW54vt+RRHZpzeYa+meHvjUA6j4qUe47xwzhgPwMbnt5IZV1jX7HEdmlNZtrfOkxBCoEEuUS4+P49bmjKa2s466ZS/yOI9KhusZmSivr6d9ThUDEE2P79+CKIwbxzGdr+LRYfREk/JRsae06up8Po4pBhUBixE3HH8B+2anc8tJ8TT8hYWdNcAxBf10aEvFOSmKAu84ZxaqKGu596yu/44h8zdrNrSvs6dKQiMcmDcnh/ML+TP1gJYvWb/M7jsgOazbXkJIQICc90ZfjqxBITLn15APpmZrArdMX0KzpJyRMtPUYCo6vDTkVAokpPVITueO0g5hfso0/f7zK7zgiQOtgMr/GEIAKgcSgU0f34dgDcrn7zaU7emuI+MU5FywE/twfABUCiUFmxs/PHAnA/76yEM1qIn7aXN1AdUOzb4PJQIVAYlRBz1R+cMIBvLu0jBnz1vsdR2LYjq6jPvUYAhUCiWGXTxrI2P49+OmMRZqhVHyzdktr11G/BpOBCoHEsECc8bvzRlPd0KxLROKbtZHSIgiuLRAXfL6/mZ1uZgneRhPx3tC8DL73zWG8sWgjry3Y4HcciUFrKmrISU8iJTH000+36WyL4H0g2cz6AW8ClwB/9iqUSChNOXIwowuyuP1vi6io0iUiCa21W2rYz8euo9D5QmDOuRrgbOBB59x5gBaal6gQH4jjt+eOobKukdtnLPI7jsQYP6efbtPpQmBmhwEXAa8F3/OvHSPSzQ7oncGNxw3jtfkbeHW+ehFJaDQ2t7B+a62vYwig84Xge8CtwMvOuUVmNhh417NUIj645ughjCnI4ievLKR0e53fcSQGbNhaR4vzb9bRNp0qBM65fznnTnfO/Tp407jcOXeDx9lEQio+EMfdk8dS29DMLdMXqBeReK5tDEFEXBoys2fNLNPM0oCFwJdmdrO30URCb2heOj888UDeWV
LKi0Vr/Y4jUc7vdQjadPbS0Ajn3HbgTGAmMIjWnkMiUee/Jg3k0MHZ/N/fv9zRx1vEC2u31JAQMHpnJvuao7OFICE4buBMYIZzrhFQu1miUlyc8dtzx2Bm/M9f59Gi6arFI2s211DQM5VAnD/TT7fpbCF4BFgFpAHvm9kAYLtXoUT81j87ldtPHcFnKzfzxEcr/Y4jUWrt5hoKevo7hgA6f7P4PudcP+fcya7VauBYj7OJ+Oq8wgK+OTyP3/xjKcs2VfodR6LQ2jAYQwCdv1mcZWb3mFlR8HE3ra0DkahlZtx59mjSk+L5/otzaWxu8TuSRJHtdY1sqWmMnEIAPAFUApODj+3An7wKJRIucjOS+NVZI1m4bjv3v7Pc7zgSRdaGSY8h6HwhGOKcu8M5Vxx8/AwY7GUwkXBx4sg+nD2+H398dzlz1271O45EibVhMoYAOl8Ias3siLYXZnY4UOtNJJHwc8dpB5GfkcRNL8yltqHZ7zgSBVaWBwuBj+sQtOlsIbgG+KOZrTKzVcADwNWepRIJM1kpCfz2vDEUl1fz6zeW+B1HosCKsiryMpLITPZ/Rv/O9hqa55wbA4wGRjvnxgHf8DSZSJg5fGgOl08ayJ8/XsWHy8r9jiMRbkVZFUNy0/2OAezlCmXOue3BEcYAN3mQRySs/ejEAxmcm8bN0+axrbbR7zgSoZxzrCitYkheeHS+3JelKv0dCifig5TEAPdMHktpZT0/09oF0kXlVQ1sr2uKzBbBTjTuXmLS2P49uO7YoUyfs46ZWt5SumBFWRVAZBQCM6s0s+0dPCqBviHKKBJ2vvuNoYzql8Ut0xewYZs60Mne2VEI8iKgEDjnMpxzmR08Mpxz8aEKKRJuEgJx/OGCsTQ0tXDTC/No1sR0shdWlFaTkhCgj8+zjrbZl0tDexTsbrrAzOaaWVEH283M7jOz5WY238zGe5lHpDsNzk3nZ6cfxCfFFTz6frHfcSSCrCirYnBuGnE+zzraxtNCEHSsc26sc66wg20nAcOCjynAQyHII9Jtziss4ORRvbn7zaXML9nqdxyJEOHUdRRCUwh25wzgyeCMpp8CPcysj8+ZRDrNzLjzrNHkZSRxw3NzqK5v8juShLnahmbWba2NqULggDfNbJaZTelgez+g/XqAJcH3vsbMprTNfFpWVuZRVJGuyUpN4N7zx7Jmcw23aq1j2YOV5dU4R9iMIQDvC8ERzrnxtF4Cus7MjurKlzjnHnXOFTrnCnNzc7s3oUg3mDi4Fz844QBmzFvPU5+u9juOhLFw6zoKHhcC59y64J+lwMvAITvtsg7o3+51QfA9kYjznaOHcNyBefz81S+ZvWaL33EkTK0oq8IMBuXEQIvAzNLMLKPtOXACsHCn3WYAlwZ7Dx0KbHPOaYSORKS4OOOeyWPpnZXMdc/MpqKq3u9IEoZWlFVT0DOF5ISA31F28LJFkA98aGbzgM+B15xzb5jZNWZ2TXCf14FiYDnwGHCth3lEPJeVmsBDF02gorqBG5+fq/EF8h9WlIZXjyEAzwaFOeeKgTEdvP9wu+cOuM6rDCJ+GNkvi1+cMZIfvjSfX7+xhB+fPNzvSBImWlocxeVVHDakl99Rvkajg0U8MPng/ixYt41H3y/mwN4ZnD2+wO9IEgbWb6ulrrEl7FoEfo8jEIlat582gkMHZ3PL9AVa4lKA1vsDAEPDZI6hNioEIh5JCMTx4EUTyMtI4uqniijdXud3JPHZitK2rqPh02MIVAhEPJWdlshjlxZSWdfElKdmUdeo9Y5j2YqyKnqkJpCdluh3lK9RIRDx2PA+mdwzeSzzSrby/Rfm0qKeRDGrbY4hs/CYbK6NCoFICJw4sje3nTycmQs3ctcbS/yOIz5ZUVYddpeFQL2GRELmiiMGsXZzDY++X0z/nilccthAvyNJCG2rbaSssj7segyBCoFIyJgZt592EOu21nLHjEX07ZHCccPz/Y4lIbK8NPzmGGqjS0MiIRSIM+
779jgO6pvF9c/O0RoGMWTpxkoADuid4XOS/6RCIBJiqYnxPH55Ib3SE/nvP3/BmooavyNJCCzduJ20xAD9eqT4HeU/qBCI+CAvI5m//PchNLU4LvvT52yubvA7knhsycZKDuidETbLU7anQiDikyG56Uy9tJD1W2u58i9fUNugMQbRyjkXLASZfkfpkAqBiI8KB2bzhwvGMmftVq5/djaNzS1+RxIPbNpez7baRg4Mw/sDoEIg4rsTR/bh52eM5O0lpdz04jxNXR2FlmzcDoTnjWJQ91GRsHDxoQOoqm/irplLSEsMcOfZo8Ju9Kl0XVuPoXBtEagQiISJa44eQlVdEw+8u5y0pHh+cspwFYMosXRjJfmZSfRIDa85htqoEIiEkR+csD9V9U08/uFK0pLiuen4/f2OJN0gnG8UgwqBSFgxM24/dQTV9U3c9/YyUhMDXHP0EL9jyT5oam5heVkVRwzL8TvKLqkQiISZuDjjrnNGU9fUwl0zl5CSEOCySQP9jiVdtKqimoamFg7ID8/7A6BCIBKWAnHGPZPHUNfYzB0zFpGSEGDywf39jiVdsKTtRnGf8C0E6j4qEqYSAnE8cOE4jhyWw4+mz+dvc9f5HUm6YOnGSgJxFnbLU7anQiASxpLiAzx6SSETB2Vz04vzeH3BBr8jyV5asrGSQTlpJMUH/I6ySyoEImEuJTHA45cdzNj+PbjhuTm89eUmvyPJXlganGMonKkQiESAtKR4/vRfB3NQ30yufWY27y0t9TuSdEJVfRNrNtdwYBjfKAYVApGIkZmcwJP/PZGheelMeWoW//qqzO9IsgdfbQrfNQjaUyEQiSBZqQk8c+VEhuSmc9WTRWoZhLl/Ty0RvoPJQIVAJOL0TEvk2SsnMiwvnSlPzuLdJSoG4WrpxkpSEwMU9Ay/xWjaUyEQiUA90xJ55sqJ7N87naufmsU7S3QDORwt2bid/fPDczGa9lQIRCJUj9REnrniUA7oncHVT81iprqWhhXnHEs3VjI8jAeStVEhEIlgWakJPH3lREb1y+L65+bwyhwNOgsXs9dsYUtNI4UDsv2OskcqBCIRLislgaeumMghA7P5/otzee7zNX5HEuCVOetJTojjhIPy/Y6yRyoEIlGgbZzB0fvncuv0Bfzpo5V+R4ppjc0tvLZgA98cnk9GcoLfcfZIhUAkSiQnBHjkkgl866B8fvb3L3nwveV+R4pZHywrY3N1A2eO7ed3lE5RIRCJIknxAf544XjOGNuX37yxlHveXIpzWgM51F6es54eqQkctX+u31E6xfNCYGYBM5tjZq92sO1yMyszs7nBx5Ve5xGJdvGBOO6ZPJbzC/tz3zvL+eVri1UMQqiqvol/frmRU0f3ITE+Mn7XDsV6BDcCi4FdDa17wTl3fQhyiMSMQJxx59mjSEkMMPXDlVQ3NPGLM0cRCPP+7NHgzUUbqWtsiZjLQuBxi8DMCoBTgKleHkdE/lNcnHHHaSO47tghPPf5Wq5/djb1Tc1+x4p6r8xdT0HPFCYM6Ol3lE7zut3ye+CHQMtu9jnHzOab2TQz63AJJjObYmZFZlZUVqaJtkQ6y8y4+VsH8pNThjNz4Ub+609fUFXf5HesqFVaWceHy8o4Y2xfzCKn9eVZITCzU4FS59ys3ez2d2Cgc2408E/gLx3t5Jx71DlX6JwrzM2NjJsvIuHkyiMHc/d5Y/hs5WYufOxTyqvq/Y4UlV6dt4EWR0RdFgJvWwSHA6eb2SrgeeAbZvZ0+x2ccxXOuba/kVOBCR7mEYlp50wo4NFLJvDVpkrOfvBjVpRV+R0p6rwydx0j+mQyLMzXH9iZZ4XAOXerc67AOTcQuAB4xzl3cft9zKxPu5en03pTWUQ8ctzwfJ6fchjV9U2c89DHfLFqs9+RosayTZXML9nG2eMjqzUAPowjMLP/M7PTgy9vMLNFZjYPuAG4PNR5RGLN2P49ePnaw8lOTeSiqZ/x6vz1fkeKCi/NXkcgzj
\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.remove_cb(ActivationStats)\n",
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.04365158379077912, lr_steep=0.3630780577659607)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"admin_init(learn.model, scales)\n",
"learn.lr_find()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tracking"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"wandb: Currently logged in as: morgan (use `wandb login --relogin` to force relogin)\n"
]
}
],
"source": [
"# hide\n",
"!wandb login"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up experiment tracking with Weights & Biases:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"WANDB_NAME = f'admin_init_bte2'\n",
"GROUP = 'admin'\n",
"NOTES = 'Early ADMIN testing'\n",
"CONFIG = {}\n",
"TAGS = ['admin']"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"wandb: Currently logged in as: morgan (use `wandb login --relogin` to force relogin)\n"
]
},
{
"data": {
"text/html": [
"\n",
" Tracking run with wandb version 0.10.19<br/>\n",
" Syncing run <strong style=\"color:#cdcd00\">admin_init_bte2</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox</a><br/>\n",
" Run page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/2o27v7d4\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox/runs/2o27v7d4</a><br/>\n",
" Run data is saved locally in <code>/home/morgan/ml/projects/transformers_sandbox/nbs/exploration/wandb/run-20210220_234952-2o27v7d4</code><br/><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<h1>Run(2o27v7d4)</h1><iframe src=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/2o27v7d4\" style=\"border:none;width:100%;height:400px\"></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x7fb00ef616a0>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#hide_output\n",
"wandb.init(reinit=True, project=\"transformers_sandbox\", entity=\"fastai_community\", \n",
" name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>bpc</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.117492</td>\n",
" <td>2.035515</td>\n",
" <td>0.404338</td>\n",
" <td>7.656197</td>\n",
" <td>2.936628</td>\n",
" <td>07:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.601835</td>\n",
" <td>1.523893</td>\n",
" <td>0.554345</td>\n",
" <td>4.590059</td>\n",
" <td>2.198513</td>\n",
" <td>07:29</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.452392</td>\n",
" <td>1.390810</td>\n",
" <td>0.589804</td>\n",
" <td>4.018106</td>\n",
" <td>2.006516</td>\n",
" <td>07:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.347416</td>\n",
" <td>1.315846</td>\n",
" <td>0.613917</td>\n",
" <td>3.727902</td>\n",
" <td>1.898364</td>\n",
" <td>07:29</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1.275978</td>\n",
" <td>1.247816</td>\n",
" <td>0.633069</td>\n",
" <td>3.482728</td>\n",
" <td>1.800218</td>\n",
" <td>07:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1.227183</td>\n",
" <td>1.200533</td>\n",
" <td>0.647163</td>\n",
" <td>3.321887</td>\n",
" <td>1.732003</td>\n",
" <td>07:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1.188224</td>\n",
" <td>1.171363</td>\n",
" <td>0.655419</td>\n",
" <td>3.226386</td>\n",
" <td>1.689919</td>\n",
" <td>07:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>1.147837</td>\n",
" <td>1.153088</td>\n",
" <td>0.661562</td>\n",
" <td>3.167962</td>\n",
" <td>1.663555</td>\n",
" <td>07:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>1.123723</td>\n",
" <td>1.142603</td>\n",
" <td>0.665061</td>\n",
" <td>3.134917</td>\n",
" <td>1.648427</td>\n",
" <td>07:29</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1.113917</td>\n",
" <td>1.141676</td>\n",
" <td>0.665857</td>\n",
" <td>3.132012</td>\n",
" <td>1.647090</td>\n",
" <td>07:28</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Better model found at epoch 0 with valid_loss value: 2.035515308380127.\n",
"Better model found at epoch 1 with valid_loss value: 1.523892879486084.\n",
"Better model found at epoch 2 with valid_loss value: 1.390810489654541.\n",
"Better model found at epoch 3 with valid_loss value: 1.3158457279205322.\n",
"Better model found at epoch 4 with valid_loss value: 1.2478158473968506.\n",
"Better model found at epoch 5 with valid_loss value: 1.2005329132080078.\n",
"Better model found at epoch 6 with valid_loss value: 1.1713625192642212.\n",
"Better model found at epoch 7 with valid_loss value: 1.1530883312225342.\n",
"Better model found at epoch 8 with valid_loss value: 1.142602562904358.\n",
"Better model found at epoch 9 with valid_loss value: 1.1416757106781006.\n"
]
}
],
"source": [
"admin_init(learn.model, scales)\n",
"learn.fit_one_cycle(n_epochs, 1e-3, cbs=[WandbCallback(log_model=False, log_preds=False)])"
]
},
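{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `perplexity` and `bpc` columns in the table above follow directly from `valid_loss`. A minimal sketch, assuming `valid_loss` is the mean cross-entropy in nats: perplexity is `exp(loss)` and `bpc` is `loss / ln(2)`. The helper names below are illustrative and are not taken from `transformers_sandbox`; the input is the epoch 9 `valid_loss` from the table, so the printed values should be close to that row."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Illustrative helpers (assumed names, not the transformers_sandbox implementations)\n",
"def loss_to_perplexity(loss):\n",
"    # Perplexity is the exponential of the mean cross-entropy (in nats)\n",
"    return math.exp(loss)\n",
"\n",
"def loss_to_bits(loss):\n",
"    # Convert nats to bits by dividing by ln(2)\n",
"    return loss / math.log(2)\n",
"\n",
"valid_loss = 1.141676  # epoch 9 of the byte-level run above\n",
"print(loss_to_perplexity(valid_loss), loss_to_bits(valid_loss))"
]
},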
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train - SPUniTokenizer\n",
"\n",
"Train a SentencePiece Unigram tokenizer with the same vocabulary size as the T5 tokenizer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"SentencePiece Unigram Tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset wikitext (/home/morgan/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)\n",
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-e9eccb98202e10a0.arrow\n"
]
}
],
"source": [
"big_train_ds = load_dataset('wikitext', name='wikitext-103-raw-v1', split='train')\n",
"big_train_ds = big_train_ds.filter(lambda x: x['text'] != '')\n",
"\n",
"spu_tok = SPUniTokenizer(vcb_sz=32093, add_eos=True, add_bos=True)\n",
"spu_tok.train(big_train_ds['text'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Save Trained Tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"def save_obj(obj, name):\n",
"    # Pickle `obj` to `<name>.pkl`\n",
"    with open(name + '.pkl', 'wb') as f:\n",
"        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)\n",
"\n",
"def load_obj(name):\n",
"    # Load a pickled object from `<name>.pkl`\n",
"    with open(name + '.pkl', 'rb') as f:\n",
"        return pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# save_obj(spu_tok, 'sentencepieceunigram_wiki103_32093')"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"spu_tok = load_obj('sentencepieceunigram_wiki103_32093')"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"tok = spu_tok"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"#Run setup\n",
"n_epochs = 10\n",
"bs = 12\n",
"sl = 512\n",
"n_layers = 6\n",
"pad_id = tok.pad_token_id\n",
"vocab_size = tok.vocab_size"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tokenize Data"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-4267024924ef223b.arrow\n",
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-916f06ba6ec602c9.arrow\n"
]
}
],
"source": [
"def encode(examples): return {'token_ids' : spu_tok(examples['text'])}\n",
"\n",
"tok_ds = train_ds.map(encode, batched=False)\n",
"val_tok_ds = valid_ds.map(encode, batched=False)\n",
"all_toks = tok_ds['token_ids'] + val_tok_ds['token_ids']\n",
"df['toks'] = all_toks "
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>text_</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Coulthard and Button collided on lap 18 when Button attempted to pass the Red Bull on the inside at turn eight ; the Honda lost its front wing and retired a lap later after two pit stops . Hamilton continued his climb back through the field ; he moved from 18th , passing Piquet , Davidson , Sutil and Bourdais in separate manoeuvres , to sit in 14th by the time he pitted on lap 31 . Piquet retired on lap 42 with transmission failure , requiring a gearbox change before the next race . Anderson was involved in filmmaking at a young age and never really had an alternative plan to directing films . He made his first movie when he was eight years old and started making movies on a Betamax video camera which his dad bought in 1982 when he was twelve years old . He</td>\n",
" <td>Coulthard and Button collided on lap 18 when Button attempted to pass the Red Bull on the inside at turn eight ; the Honda lost its front wing and retired a lap later after two pit stops . Hamilton continued his climb back through the field ; he moved from 18th , passing Piquet , Davidson , Sutil and Bourdais in separate manoeuvres , to sit in 14th by the time he pitted on lap 31 . Piquet retired on lap 42 with transmission failure , requiring a gearbox change before the next race . Anderson was involved in filmmaking at a young age and never really had an alternative plan to directing films . He made his first movie when he was eight years old and started making movies on a Betamax video camera which his dad bought in 1982 when he was twelve years old . He</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>subdue and molest his victims . The song ends with the narrator turning inward with the lyrics : \" And in my best behavior , I am really just like him / Look beneath the floorboards for the secrets I have hid . \" Stevens stated in a 2009 interview with Paste that \" we 're all capable of what [ Gacy ] did . \" Technically , it is known as a speaker that easily reveals poor quality in recordings . Recording engineers sought to dull its treble response by hanging tissue paper in front of it , resulting in what became known as the \" tissue paper effect \" , a type of comb filtering . The NS @-@ 10 has been used to monitor a large number of successful recordings by numerous artists , leading Gizmodo to refer to it as \" the most important loudspeaker</td>\n",
" <td>and molest his victims . The song ends with the narrator turning inward with the lyrics : \" And in my best behavior , I am really just like him / Look beneath the floorboards for the secrets I have hid . \" Stevens stated in a 2009 interview with Paste that \" we 're all capable of what [ Gacy ] did . \" Technically , it is known as a speaker that easily reveals poor quality in recordings . Recording engineers sought to dull its treble response by hanging tissue paper in front of it , resulting in what became known as the \" tissue paper effect \" , a type of comb filtering . The NS @-@ 10 has been used to monitor a large number of successful recordings by numerous artists , leading Gizmodo to refer to it as \" the most important loudspeaker you</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df['lens'] = df['toks'].apply(len)\n",
"splits = ColSplitter()(df)\n",
"tfms = [attrgetter(\"text\"), spu_tok]\n",
"# tfms = [attrgetter(\"toks\")]\n",
"dsets = Datasets(df, [tfms], splits=splits, dl_type=LMDataLoader)\n",
"dl_kwargs = [{'lens':df['lens'].values[splits[0]]},\n",
" {'val_lens':df['lens'].values[splits[1]]}]\n",
"dls = dsets.dataloaders(bs=bs, seq_len=sl, dl_kwargs=dl_kwargs, shuffle_train=True)\n",
"dls.show_batch(max_n=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tracking"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"WANDB_NAME = f'admin_init_spu'\n",
"GROUP = 'admin'\n",
"NOTES = 'SentencePieceUnigram tokenizer trained on wiki 103, 32100 vcb_sz, Early ADMIN testing'\n",
"CONFIG = {}\n",
"TAGS = ['admin','spu']"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"wandb: Currently logged in as: morgan (use `wandb login --relogin` to force relogin)\n"
]
},
{
"data": {
"text/html": [
"\n",
" Tracking run with wandb version 0.10.19<br/>\n",
" Syncing run <strong style=\"color:#cdcd00\">admin_init_spu</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox</a><br/>\n",
" Run page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/896aogrk\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox/runs/896aogrk</a><br/>\n",
" Run data is saved locally in <code>/home/morgan/ml/projects/transformers_sandbox/nbs/exploration/wandb/run-20210221_183500-896aogrk</code><br/><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<h1>Run(896aogrk)</h1><iframe src=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/896aogrk\" style=\"border:none;width:100%;height:400px\"></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x7f5b92f19e50>"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#hide_output\n",
"wandb.init(reinit=True, project=\"transformers_sandbox\", entity=\"fastai_community\", \n",
" name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train SPU"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>bpc</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>5.331179</td>\n",
" <td>5.275365</td>\n",
" <td>0.253242</td>\n",
" <td>195.461777</td>\n",
" <td>7.610743</td>\n",
" <td>02:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>4.771585</td>\n",
" <td>4.862295</td>\n",
" <td>0.290732</td>\n",
" <td>129.320679</td>\n",
" <td>7.014809</td>\n",
" <td>02:43</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>4.419650</td>\n",
" <td>4.540506</td>\n",
" <td>0.322093</td>\n",
" <td>93.738213</td>\n",
" <td>6.550565</td>\n",
" <td>02:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>4.108190</td>\n",
" <td>4.304454</td>\n",
" <td>0.344261</td>\n",
" <td>74.028770</td>\n",
" <td>6.210014</td>\n",
" <td>02:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>3.865303</td>\n",
" <td>4.159061</td>\n",
" <td>0.357382</td>\n",
" <td>64.011414</td>\n",
" <td>6.000257</td>\n",
" <td>02:45</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>3.655769</td>\n",
" <td>4.056620</td>\n",
" <td>0.368161</td>\n",
" <td>57.778694</td>\n",
" <td>5.852466</td>\n",
" <td>02:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>3.459788</td>\n",
" <td>3.989344</td>\n",
" <td>0.376501</td>\n",
" <td>54.019436</td>\n",
" <td>5.755407</td>\n",
" <td>02:45</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>3.292417</td>\n",
" <td>3.953279</td>\n",
" <td>0.382756</td>\n",
" <td>52.105968</td>\n",
" <td>5.703377</td>\n",
" <td>02:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>3.169972</td>\n",
" <td>3.935873</td>\n",
" <td>0.385500</td>\n",
" <td>51.206848</td>\n",
" <td>5.678265</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>3.099260</td>\n",
" <td>3.935462</td>\n",
" <td>0.385990</td>\n",
" <td>51.185768</td>\n",
" <td>5.677671</td>\n",
" <td>02:44</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Better model found at epoch 0 with valid_loss value: 5.275364875793457.\n",
"Better model found at epoch 1 with valid_loss value: 4.862295150756836.\n",
"Better model found at epoch 2 with valid_loss value: 4.540505886077881.\n",
"Better model found at epoch 3 with valid_loss value: 4.3044538497924805.\n",
"Better model found at epoch 4 with valid_loss value: 4.159061431884766.\n",
"Better model found at epoch 5 with valid_loss value: 4.056620121002197.\n",
"Better model found at epoch 6 with valid_loss value: 3.9893438816070557.\n",
"Better model found at epoch 7 with valid_loss value: 3.953279495239258.\n",
"Better model found at epoch 8 with valid_loss value: 3.93587327003479.\n",
"Better model found at epoch 9 with valid_loss value: 3.9354615211486816.\n"
]
}
],
"source": [
"admin_init(learn.model, scales)\n",
"learn.fit_one_cycle(n_epochs, 1e-3, cbs=[WandbCallback(log_model=False, log_preds=False)])"
]
},
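{
"cell_type": "markdown",
"metadata": {},
"source": [
"With a subword tokenizer the `bpc` column is still `valid_loss / ln(2)`, but the loss is averaged per token, so the value is better read as bits per token than bits per character (for epoch 9 above, 3.935462 / ln 2 is roughly 5.678). A rough sketch of converting it to bits per character, assuming the average number of characters per token has been measured on the validation text. `chars_per_token` below is a hypothetical placeholder, not a measurement from this run, and `bits_per_char` is an illustrative helper rather than part of `transformers_sandbox`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def bits_per_char(loss_per_token, chars_per_token):\n",
"    # Mean cross-entropy per token (nats) -> bits per token -> bits per character\n",
"    return loss_per_token / math.log(2) / chars_per_token\n",
"\n",
"chars_per_token = 4.0  # hypothetical placeholder; measure this on your own data\n",
"print(bits_per_char(3.935462, chars_per_token))"
]
},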
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train SPU - 20e"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.04365158379077912, lr_steep=0.3630780577659607)"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.remove_cb(ActivationStats)\n",
"admin_init(learn.model, scales)\n",
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"WANDB_NAME = f'admin_init_spu_20e'\n",
"GROUP = 'admin'\n",
"NOTES = 'SentencePieceUnigram tokenizer trained on wiki 103, 20e, 32100 vcb_sz, Early ADMIN testing'\n",
"CONFIG = {}\n",
"TAGS = ['admin','spu']"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"wandb: Currently logged in as: morgan (use `wandb login --relogin` to force relogin)\n"
]
},
{
"data": {
"text/html": [
"\n",
" Tracking run with wandb version 0.10.19<br/>\n",
" Syncing run <strong style=\"color:#cdcd00\">admin_init_spu_20e</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox</a><br/>\n",
" Run page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/3ks1ewq1\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox/runs/3ks1ewq1</a><br/>\n",
" Run data is saved locally in <code>/home/morgan/ml/projects/transformers_sandbox/nbs/exploration/wandb/run-20210221_193133-3ks1ewq1</code><br/><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<h1>Run(3ks1ewq1)</h1><iframe src=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/3ks1ewq1\" style=\"border:none;width:100%;height:400px\"></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x7fc180055d30>"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#hide_output\n",
"wandb.init(reinit=True, project=\"transformers_sandbox\", entity=\"fastai_community\", \n",
" name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>bpc</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>5.516002</td>\n",
" <td>5.488830</td>\n",
" <td>0.231003</td>\n",
" <td>241.973953</td>\n",
" <td>7.918708</td>\n",
" <td>01:41</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>4.964672</td>\n",
" <td>4.980829</td>\n",
" <td>0.285725</td>\n",
" <td>145.595062</td>\n",
" <td>7.185818</td>\n",
" <td>01:43</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>4.570927</td>\n",
" <td>4.636693</td>\n",
" <td>0.315706</td>\n",
" <td>103.202446</td>\n",
" <td>6.689333</td>\n",
" <td>01:45</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>4.294974</td>\n",
" <td>4.476145</td>\n",
" <td>0.328804</td>\n",
" <td>87.895203</td>\n",
" <td>6.457713</td>\n",
" <td>01:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>4.072177</td>\n",
" <td>4.286242</td>\n",
" <td>0.347858</td>\n",
" <td>72.692741</td>\n",
" <td>6.183739</td>\n",
" <td>01:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>3.882438</td>\n",
" <td>4.174093</td>\n",
" <td>0.354857</td>\n",
" <td>64.980888</td>\n",
" <td>6.021944</td>\n",
" <td>01:43</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>3.688582</td>\n",
" <td>4.089933</td>\n",
" <td>0.362383</td>\n",
" <td>59.735886</td>\n",
" <td>5.900526</td>\n",
" <td>01:42</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>3.555619</td>\n",
" <td>4.045366</td>\n",
" <td>0.370683</td>\n",
" <td>57.132111</td>\n",
" <td>5.836230</td>\n",
" <td>01:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>3.410418</td>\n",
" <td>4.018274</td>\n",
" <td>0.375051</td>\n",
" <td>55.605038</td>\n",
" <td>5.797144</td>\n",
" <td>01:42</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>3.273482</td>\n",
" <td>3.989733</td>\n",
" <td>0.379958</td>\n",
" <td>54.040470</td>\n",
" <td>5.755968</td>\n",
" <td>01:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>3.123741</td>\n",
" <td>3.992993</td>\n",
" <td>0.383907</td>\n",
" <td>54.216911</td>\n",
" <td>5.760671</td>\n",
" <td>01:41</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2.994096</td>\n",
" <td>3.990759</td>\n",
" <td>0.385993</td>\n",
" <td>54.095951</td>\n",
" <td>5.757449</td>\n",
" <td>01:42</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2.877162</td>\n",
" <td>4.005277</td>\n",
" <td>0.387474</td>\n",
" <td>54.887009</td>\n",
" <td>5.778393</td>\n",
" <td>01:42</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>2.746851</td>\n",
" <td>4.023293</td>\n",
" <td>0.390281</td>\n",
" <td>55.884834</td>\n",
" <td>5.804385</td>\n",
" <td>01:45</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2.615041</td>\n",
" <td>4.046219</td>\n",
" <td>0.392509</td>\n",
" <td>57.180866</td>\n",
" <td>5.837461</td>\n",
" <td>01:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>2.524932</td>\n",
" <td>4.072429</td>\n",
" <td>0.393079</td>\n",
" <td>58.699383</td>\n",
" <td>5.875273</td>\n",
" <td>01:45</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>2.434323</td>\n",
" <td>4.087971</td>\n",
" <td>0.392887</td>\n",
" <td>59.618816</td>\n",
" <td>5.897696</td>\n",
" <td>01:42</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>2.361165</td>\n",
" <td>4.097982</td>\n",
" <td>0.393126</td>\n",
" <td>60.218670</td>\n",
" <td>5.912139</td>\n",
" <td>01:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>2.310880</td>\n",
" <td>4.103701</td>\n",
" <td>0.393512</td>\n",
" <td>60.564026</td>\n",
" <td>5.920389</td>\n",
" <td>01:43</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>2.289599</td>\n",
" <td>4.104567</td>\n",
" <td>0.393593</td>\n",
" <td>60.616467</td>\n",
" <td>5.921638</td>\n",
" <td>01:44</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Better model found at epoch 0 with valid_loss value: 5.488830089569092.\n",
"Better model found at epoch 1 with valid_loss value: 4.980829238891602.\n",
"Better model found at epoch 2 with valid_loss value: 4.636692523956299.\n",
"Better model found at epoch 3 with valid_loss value: 4.476145267486572.\n",
"Better model found at epoch 4 with valid_loss value: 4.28624153137207.\n",
"Better model found at epoch 5 with valid_loss value: 4.174093246459961.\n",
"Better model found at epoch 6 with valid_loss value: 4.089932918548584.\n",
"Better model found at epoch 7 with valid_loss value: 4.045366287231445.\n",
"Better model found at epoch 8 with valid_loss value: 4.018273830413818.\n",
"Better model found at epoch 9 with valid_loss value: 3.9897332191467285.\n"
]
}
],
"source": [
"learn.remove_cb(ActivationStats)\n",
"admin_init(learn.model, scales)\n",
"learn.fit_one_cycle(20, 1e-3, cbs=[WandbCallback(log_model=False, log_preds=False)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train - T5 Tokenizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pretrained T5 SentencePiece (Unigram) tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from transformers import T5TokenizerFast\n",
"\n",
"# Wrap the pretrained `t5-small` tokenizer as a fastai Transform\n",
"class T5TokenizerTransform(Transform):\n",
"    def __init__(self):\n",
"        self.tok = T5TokenizerFast.from_pretrained('t5-small')\n",
"        self.pad_token_id = self.tok.pad_token_id\n",
"        self.eos_token_id = self.tok.eos_token_id\n",
"        # T5 defines no BOS token, so this is expected to be None\n",
"        self.bos_token_id = self.tok.bos_token_id\n",
"        self.unk_token_id = self.tok.unk_token_id\n",
"        self.vcb_size = self.tok.vocab_size\n",
"\n",
"    # Allow the transform to be called directly on raw text\n",
"    def __call__(self, o, **kwargs):\n",
"        return LMTensorText(self.tok(o)['input_ids'])\n",
"\n",
"    # Text -> token ids\n",
"    def encodes(self, o):\n",
"        return LMTensorText(self.tok(o)['input_ids'])\n",
"\n",
"    # Token ids -> text\n",
"    def decodes(self, o):\n",
"        return TitledStr(self.tok.decode(o.numpy()))\n",
"\n",
"    @property\n",
"    def vocab_size(self): return self.vcb_size"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"t5_tok = T5TokenizerTransform()\n",
"tok = t5_tok"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-bd8ef24658a2bf53.arrow\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4aba2cec8c844241ae3221804c011c4a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=2461.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Token indices sequence length is longer than the specified maximum sequence length for this model (585 > 512). Running this sequence through the model will result in indexing errors\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"def encode(examples): return {'token_ids' : t5_tok(examples['text'])}\n",
"\n",
"tok_ds = train_ds.map(encode, batched=False)\n",
"val_tok_ds = valid_ds.map(encode, batched=False)\n",
"all_toks = tok_ds['token_ids'] + val_tok_ds['token_ids']\n",
"df['toks'] = all_toks "
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"#Run setup\n",
"n_epochs = 10\n",
"bs = 12\n",
"sl = 512\n",
"n_layers = 6\n",
"pad_id = tok.pad_token_id\n",
"vocab_size = tok.vocab_size"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>text_</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>It has been estimated that average @-@ sized bacteria contain about 2 million proteins per cell ( e.g. E. coli and Staphylococcus aureus ). Smaller bacteria, such as Mycoplasma or spirochetes contain fewer molecules, namely on the order of 50 @,@ 000 to 1 million. By contrast, eukaryotic cells are larger and thus contain much more protein. For instance, yeast cells were estimated to contain about 50 million proteins and human cells on the order of 1 to 3 billion. Note that bacterial genomes encode about 10 times fewer proteins than humans ( e.g. small bacteria &lt;unk&gt; 1 @,@ 000, E. coli : &lt;unk&gt; 4 @,@ 000, yeast : &lt;unk&gt; 6 @,@ 000, human : &lt;unk&gt; 20 @,@ 000 ).&lt;/s&gt; Even though cadmium and its compounds are toxic in certain forms and concentrations, the British Pharmaceutical Codex from 1907 states that cadmium iodide was used as a medication to treat</td>\n",
" <td>has been estimated that average @-@ sized bacteria contain about 2 million proteins per cell ( e.g. E. coli and Staphylococcus aureus ). Smaller bacteria, such as Mycoplasma or spirochetes contain fewer molecules, namely on the order of 50 @,@ 000 to 1 million. By contrast, eukaryotic cells are larger and thus contain much more protein. For instance, yeast cells were estimated to contain about 50 million proteins and human cells on the order of 1 to 3 billion. Note that bacterial genomes encode about 10 times fewer proteins than humans ( e.g. small bacteria &lt;unk&gt; 1 @,@ 000, E. coli : &lt;unk&gt; 4 @,@ 000, yeast : &lt;unk&gt; 6 @,@ 000, human : &lt;unk&gt; 20 @,@ 000 ).&lt;/s&gt; Even though cadmium and its compounds are toxic in certain forms and concentrations, the British Pharmaceutical Codex from 1907 states that cadmium iodide was used as a medication to treat \"</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>, collided with the back of the Renault, knocking the McLaren's front wing off the car. Suffering handling difficulties, Hamilton returned to the pit @-@ lane for a new nose section, and rejoined in 18th place. Räikkönen took second place when he passed Kubica on lap three ; Heidfeld took fourth when he passed Trulli and Kovalainen in separate manoeuvres. Further down the field, Vettel retired from the race on the first lap after twice colliding with other cars ; Button, Sutil and Coulthard pitted to repair early damage.&lt;/s&gt; = = Wheelchair basketball = =&lt;/s&gt; Jardine then ordered his team to move to bodyline positions immediately after Woodfull's injury. Jardine wrote that Larwood had asked for the field, while Larwood said that it was Jardine's decision. The capacity Saturday afternoon crowd viewed this as hitting a man when he was down. Journalist – cricketer Dick Whitington wrote that Jardine's actions</td>\n",
" <td>collided with the back of the Renault, knocking the McLaren's front wing off the car. Suffering handling difficulties, Hamilton returned to the pit @-@ lane for a new nose section, and rejoined in 18th place. Räikkönen took second place when he passed Kubica on lap three ; Heidfeld took fourth when he passed Trulli and Kovalainen in separate manoeuvres. Further down the field, Vettel retired from the race on the first lap after twice colliding with other cars ; Button, Sutil and Coulthard pitted to repair early damage.&lt;/s&gt; = = Wheelchair basketball = =&lt;/s&gt; Jardine then ordered his team to move to bodyline positions immediately after Woodfull's injury. Jardine wrote that Larwood had asked for the field, while Larwood said that it was Jardine's decision. The capacity Saturday afternoon crowd viewed this as hitting a man when he was down. Journalist – cricketer Dick Whitington wrote that Jardine's actions were</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df['lens'] = df['toks'].apply(len)\n",
"splits = ColSplitter()(df)\n",
"tfms = [attrgetter(\"text\"), t5_tok]\n",
"dsets = Datasets(df, [tfms], splits=splits, dl_type=LMDataLoader)\n",
"dl_kwargs = [{'lens':df['lens'].values[splits[0]]},\n",
" {'val_lens':df['lens'].values[splits[1]]}]\n",
"dls = dsets.dataloaders(bs=bs, seq_len=sl, dl_kwargs=dl_kwargs, shuffle_train=True)\n",
"dls.show_batch(max_n=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LR Find"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.05248074531555176, lr_steep=0.43651583790779114)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.remove_cb(ActivationStats)\n",
"admin_init(learn.model, scales)\n",
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"WANDB_NAME = f'admin_init_t5_20e'\n",
"GROUP = 'admin'\n",
"NOTES = 'T5Tokenizer tokenizer, 20e, 32100 vcb_sz, Early ADMIN testing'\n",
"CONFIG = {}\n",
"TAGS = ['admin','t5']"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmorgan\u001b[0m (use `wandb login --relogin` to force relogin)\n"
]
},
{
"data": {
"text/html": [
"\n",
" Tracking run with wandb version 0.10.19<br/>\n",
" Syncing run <strong style=\"color:#cdcd00\">admin_init_t5_20e</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox</a><br/>\n",
" Run page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/2k8d2xge\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox/runs/2k8d2xge</a><br/>\n",
" Run data is saved locally in <code>/home/morgan/ml/projects/transformers_sandbox/nbs/exploration/wandb/run-20210221_220517-2k8d2xge</code><br/><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<h1>Run(2k8d2xge)</h1><iframe src=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/2k8d2xge\" style=\"border:none;width:100%;height:400px\"></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x7f58280433a0>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#hide_output\n",
"wandb.init(reinit=True, project=\"transformers_sandbox\", entity=\"fastai_community\", \n",
" name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='0' class='' max='20' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 0.00% [0/20 00:00<00:00]\n",
" </div>\n",
" \n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>bpc</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>\n",
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='8' class='' max='466' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 1.72% [8/466 00:01<01:39 9.5600]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.remove_cb(ActivationStats)\n",
"admin_init(learn.model, scales)\n",
"learn.fit_one_cycle(20, 1e-3, cbs=[WandbCallback(log_model=False, log_preds=False)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted 00_core.ipynb.\n",
"Converted 01_layers.ipynb.\n",
"Converted 02_attention.ipynb.\n",
"Converted 03_models.transformer.ipynb.\n",
"Converted 04a_models.reformer.ipynb.\n",
"Converted 04x_models.xtransformer.ipynb.\n",
"Converted 05_tokenizers.ipynb.\n",
"Converted 06_data.ipynb.\n",
"Converted 07_metrics.ipynb.\n",
"Converted 08_optimizers.ipynb.\n",
"Converted 09_tracking.ipynb.\n",
"Converted 10_config.ipynb.\n",
"Converted index.ipynb.\n"
]
}
],
"source": [
"#hide\n",
"from nbdev.export import notebook2script; notebook2script()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#default_exp xtransformer"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from nbdev.showdoc import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"#!pip install datasets"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from fastai.basics import *\n",
"from fastai.text.all import *\n",
"from fastai.callback import *\n",
"from fastai.callback.wandb import *\n",
"\n",
"from transformers_sandbox.core import *\n",
"from transformers_sandbox.layers import *\n",
"from transformers_sandbox.attention import *\n",
"from transformers_sandbox.tokenizers import ByteTextTokenizer, SPUniTokenizer\n",
"from transformers_sandbox.metrics import bpc\n",
"from transformers_sandbox.transformer import LMMixin, EncDecMixin, TransformerLM\n",
"from transformers_sandbox.config import ConfigBase, update_sig\n",
"from transformers_sandbox.xtransformer import *\n",
"\n",
"import wandb\n",
"from datasets import load_dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ADMIN Init\n",
"\n",
"> ADMIN Initialisation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"📕 **Paper**: https://arxiv.org/abs/2004.08249\n",
"\n",
"Transformers have proved effective in many NLP tasks. However, their training requires non-trivial efforts regarding designing cutting-edge optimizers and learning rate schedulers carefully (e.g., conventional SGD fails to train Transformers effectively). Our objective here is to understand what complicates Transformer training from both empirical and theoretical perspectives. \n",
"\n",
"Our analysis reveals that unbalanced gradients are not the root cause of the instability of training. Instead, we identify an amplification effect that influences training substantially -- for each layer in a multi-layer Transformer model, heavy dependency on its residual branch makes training unstable, since it amplifies small parameter perturbations (e.g., parameter updates) and results in significant disturbances in the model output. \n",
"\n",
"Yet we observe that a light dependency limits the model potential and leads to inferior trained models. Inspired by our analysis, we propose Admin (**Ad**aptive **m**odel **in**itialization) to stabilize the early stage's training and unleash its full potential in the late stage. Extensive experiments show that Admin is more stable, converges faster, and leads to better performance. Implementations are released at: [https://github.com/LiyuanLucasLiu/Transforemr-Clinic](https://github.com/LiyuanLucasLiu/Transforemr-Clinic)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class AdminResidual(Module):\n",
"    \"Residual connection whose skip path is scaled element-wise by a learnable vector `w`\"\n",
"    def __init__(self, sublayer, d_model):\n",
"        self.sublayer = sublayer\n",
"        self.w = torch.nn.Parameter(torch.ones(d_model))\n",
"    def forward(self, x, *args, **kwargs):\n",
"        return x*self.w + self.sublayer(x, *args, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class TransformerEncoderBlockAdmin(Module):\n",
"    \"\"\"\n",
"    Basic transformer encoder block with ADMIN residual connections. Consists of multi-head\n",
"    attention and position-wise feedforward layers, each wrapped in `AdminResidual` and `PostNorm`\n",
"    \"\"\"\n",
"    def __init__(self,\n",
"                 d_model:int,\n",
"                 n_heads:int = 8,\n",
"                 d_ff:int = None,\n",
"                 attn_dropout:float = 0.1,\n",
"                 ff_dropout:float = 0.1,\n",
"                 causal:bool = False,\n",
"                 attn_bias:bool = False,\n",
"                 prenorm:bool=False,\n",
"                 shared_qk:bool=False):\n",
"        # `prenorm` is accepted but unused: this block always uses PostNorm\n",
"        store_attr('attn_dropout') # maybe add a separate argument attn_post_dropout\n",
"\n",
"        self.attn = PostNorm(d_model, AdminResidual(Attention(d_model, n_heads=n_heads, causal=causal, dropout=attn_dropout, bias=attn_bias, shared_qk=shared_qk), d_model))\n",
"        self.ff = PostNorm(d_model, AdminResidual(FeedForward(d_model, d_ff=d_ff, dropout=ff_dropout), d_model))\n",
"\n",
"    def forward(self, x, mask=None):\n",
"        out = self.attn(x, mask=mask)\n",
"        return self.ff(out)"
]
},
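{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reading aid, each half of the block above can be summarised in one formula. This is a sketch of what the code computes, assuming `PostNorm` (imported from `transformers_sandbox.layers`, not defined in this notebook) applies a `LayerNorm` to the output of the module it wraps. Here $f$ stands for the wrapped sublayer (attention or feed-forward) and $w \\in \\mathbb{R}^{d_{model}}$ is the learnable vector in `AdminResidual`, initialised to ones:\n",
"\n",
"$$\\mathrm{out} = \\mathrm{LayerNorm}\\big(x \\odot w + f(x)\\big)$$\n",
"\n",
"With $w$ left at ones this reduces to a standard Post-LN residual block."
]
},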
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# bs = 4\n",
"# sl = 128\n",
"# d = 64\n",
"# x = torch.randn(bs, sl, d)\n",
"# m = TransformerEncoderBlockAdmin(d)\n",
"# out = m(x)\n",
"# assert (out.size() == (bs, sl, d))\n",
"# out.shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class TransformerEncoderAdmin(Module):\n",
"    \"\"\"Stack of `TransformerEncoderBlockAdmin` layers\"\"\"\n",
"    def __init__(self,\n",
"                 d_model,\n",
"                 n_layers=6,\n",
"                 n_heads=8,\n",
"                 d_ff=None,\n",
"                 ff_dropout=0.1,\n",
"                 attn_dropout=0.1,\n",
"                 attn_bias=False,\n",
"                 causal=False,\n",
"                 prenorm=False,\n",
"                 shared_qk:bool=False,\n",
"                 final_norm=None):\n",
"        store_attr('d_model')\n",
"        self.layers = nn.ModuleList([])\n",
"        for _ in range(n_layers):\n",
"            self.layers.append(TransformerEncoderBlockAdmin(d_model, n_heads, causal=causal,\n",
"                                    d_ff=d_ff, attn_dropout=attn_dropout, ff_dropout=ff_dropout,\n",
"                                    prenorm=prenorm, attn_bias=attn_bias, shared_qk=shared_qk))\n",
"        self.norm = None if final_norm is None else final_norm(d_model)\n",
"\n",
"    def forward(self, x, mask=None):\n",
"        for layer in self.layers: x = layer(x, mask=mask)\n",
"        if self.norm is not None: x = self.norm(x)\n",
"        return x"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class TransformerLMAdmin(Module, LMMixin):\n",
"    \"\"\"\n",
"    Basic Transformer for language modelling, built from ADMIN residual blocks\n",
"\n",
"    Parameters:\n",
"        * vocab_sz: int\n",
"        * d_model: int - inner dimension of the model\n",
"        * n_layers: int (default: 6)\n",
"        * n_heads: int (default: 8)\n",
"        * d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model\n",
"        * attn_dropout: float - attention dropout\n",
"        * ff_dropout: float - feed-forward dropout\n",
"        * emb_dropout: float - embedding dropout\n",
"        * causal: bool (default: True) - if True does causal masking automatically\n",
"        * max_seq_len: int (default: 512)\n",
"        * tie_weights: bool - if True, the embedding weights are reused for the output projection\n",
"        * prenorm: bool - whether to use PreNorm or PostNorm (the ADMIN blocks here always use PostNorm)\n",
"        * attn_bias: bool - whether to allow biases in attention projection layers\n",
"        * pad_idx: int - padding token id, required for autogeneration of padding mask\n",
"        * pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use\n",
"        * axial_shape: tuple - [optional] should be factors of max_seq_len\n",
"        * axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model\n",
"    Inputs:\n",
"        * x - input ids, shape [bs, sl]\n",
"        * mask - optional boolean mask, shape [bs, sl]\n",
"    Returns:\n",
"        * logits - target token logits, shape [bs, sl, vocab_sz]\n",
"    \"\"\"\n",
"    def __init__(self,\n",
"                 vocab_sz:int,\n",
"                 d_model:int,\n",
"                 n_layers:int=6,\n",
"                 n_heads:int=8,\n",
"                 d_ff:int=None,\n",
"                 attn_dropout:float=0.1,\n",
"                 ff_dropout:float=0.1,\n",
"                 emb_dropout:float=0.1,\n",
"                 tie_weights:bool=True,\n",
"                 causal:bool=True,\n",
"                 pos_enc:str='absolute',\n",
"                 max_seq_len:int=512,\n",
"                 axial_shape:tuple=None,\n",
"                 axial_emb_dims:tuple=None,\n",
"                 pad_idx:int=None,\n",
"                 prenorm:bool=False,\n",
"                 attn_bias:bool=False,\n",
"                 shared_qk:bool=False):\n",
"        store_attr()\n",
"        self.emb = TransformerEmbedding(vocab_sz, d_model, max_seq_len, dropout=emb_dropout,\n",
"                                        pos_enc=pos_enc, axial_shape=axial_shape,\n",
"                                        axial_emb_dims=axial_emb_dims)\n",
"        final_norm = None\n",
"        self.encoder = TransformerEncoderAdmin(d_model, n_layers, n_heads, causal=causal, d_ff=d_ff,\n",
"                                               attn_dropout=attn_dropout, ff_dropout=ff_dropout,\n",
"                                               prenorm=prenorm, attn_bias=attn_bias,\n",
"                                               shared_qk=shared_qk, final_norm=final_norm)\n",
"        self.proj = nn.Linear(d_model, vocab_sz)\n",
"        if tie_weights: self.proj.weight = self.emb.emb.weight\n",
"\n",
"    def forward(self, x, mask=None):\n",
"        x = self.emb(x)\n",
"        x = self.encoder(x, mask=mask)\n",
"        return self.proj(x)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"@patch(cls_method=True)\n",
"def from_config(cls:TransformerLMAdmin, config):\n",
"    return cls(**config.dict())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset wikitext (/home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)\n",
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-0b8034a134b9567b.arrow\n",
"Reusing dataset wikitext (/home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)\n",
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-d23112255b6f8a24.arrow\n"
]
}
],
"source": [
"train_ds = load_dataset('wikitext', name='wikitext-2-raw-v1', split='train')\n",
"\n",
"train_ds = train_ds.filter(lambda x: x['text'] != '')\n",
"train_df = train_ds.data.to_pandas()\n",
"\n",
"valid_ds = load_dataset('wikitext', name='wikitext-2-raw-v1', split='validation')\n",
"\n",
"valid_ds = valid_ds.filter(lambda x: x['text'] != '')\n",
"valid_df = valid_ds.data.to_pandas()\n",
"\n",
"df = pd.concat([train_df, valid_df])\n",
"df['is_valid'] = [False]*len(train_df) + [True]*len(valid_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenizer + DataLoaders"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ByteTextTokenizer"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"tok = ByteTextTokenizer(is_lm=True, add_bos=False, add_eos=False)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"#Run setup\n",
"n_epochs = 10\n",
"bs = 16\n",
"sl = 512\n",
"n_layers = 6\n",
"pad_id = tok.pad_token_id\n",
"vocab_size = tok.vocab_size"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"df['toks'] = df['text'].apply(tok)\n",
"df['lens'] = df['toks'].apply(len)\n",
"splits = ColSplitter()(df)\n",
"tfms = [attrgetter(\"text\"), tok]\n",
"dsets = Datasets(df, [tfms], splits=splits, dl_type=LMDataLoader)\n",
"dl_kwargs = [{'lens':df['lens'].values[splits[0]]},\n",
" {'val_lens':df['lens'].values[splits[1]]}]\n",
"dls = dsets.dataloaders(bs=bs, seq_len=sl, dl_kwargs=dl_kwargs, shuffle_train=True, n_workers=2)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>text_</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>The film is set in 2034 , two years after the events of 2nd GIG . Togusa is now the team leader for Public Security Section 9 , which has increased considerably in size . Section 9 deals with a series of complicated incidents , including the assassination of Ka Rum , a former dictator of the Siak Republic , which leads to a terrorist plot using children as vectors for a cybernetic virus . Investigations reveal that a hacker nicknamed the \" The Puppeteer \" is behind the entire series of events . \\n Mozambiqu</td>\n",
" <td>The film is set in 2034 , two years after the events of 2nd GIG . Togusa is now the team leader for Public Security Section 9 , which has increased considerably in size . Section 9 deals with a series of complicated incidents , including the assassination of Ka Rum , a former dictator of the Siak Republic , which leads to a terrorist plot using children as vectors for a cybernetic virus . Investigations reveal that a hacker nicknamed the \" The Puppeteer \" is behind the entire series of events . \\n Mozambique</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>združenje ) its association with Yugoslavia on 25 June 1991 . The European Economic Community and the Conference on Security and Cooperation in Europe urged Croatian authorities to place a three @-@ month moratorium on the decision . Croatia agreed to freeze its independence declaration for three months , initially easing tensions . Nonetheless , the Croatian War of Independence escalated further . On 7 October , the eve of expiration of the moratorium , the Yugoslav Air Force attacked Banski dvori , the m</td>\n",
" <td>druženje ) its association with Yugoslavia on 25 June 1991 . The European Economic Community and the Conference on Security and Cooperation in Europe urged Croatian authorities to place a three @-@ month moratorium on the decision . Croatia agreed to freeze its independence declaration for three months , initially easing tensions . Nonetheless , the Croatian War of Independence escalated further . On 7 October , the eve of expiration of the moratorium , the Yugoslav Air Force attacked Banski dvori , the ma</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(16, None)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#collapse_output\n",
"dls.bs, dls.show_batch(max_n=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"#config = XConfig(vocab_sz=vocab_size, d_model=512, n_layers=n_layers, max_seq_len=512, pad_idx=pad_id)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"class CharLMConfig(ConfigBase):\n",
"    _model = TransformerLM\n",
"    _d = {\n",
"        'vocab_sz':256,\n",
"        'd_model':512,\n",
"        'n_layers':6,\n",
"        'n_heads':8,\n",
"        'd_ff':4096,\n",
"        'attn_dropout':0.1,\n",
"        'ff_dropout':0.1,\n",
"        'emb_dropout':0.1,\n",
"        'tie_weights':True,\n",
"        'causal':True,\n",
"        'pos_enc':'absolute',\n",
"        'max_seq_len':512,\n",
"        'axial_shape':None,\n",
"        'axial_emb_dims':None,\n",
"        'pad_idx':None,\n",
"        'prenorm':False,\n",
"        'attn_bias':False,\n",
"        'shared_qk':False,\n",
"    }\n",
"    @update_sig(_d)\n",
"    def __init__(self, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"\n",
"config = CharLMConfig(vocab_sz=vocab_size, d_model=512, n_layers=n_layers,\n",
"                      max_seq_len=512, pad_idx=pad_id)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learner"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, TransformerLMAdmin.from_config(config),\n",
" loss_func=CrossEntropyLossFlat(ignore_index=pad_id), #opt_func=Adam,\n",
" cbs = [\n",
" #GradientClip(1.0),\n",
" SaveModelCallback(with_opt=True)],\n",
" metrics=[accuracy, perplexity, bpc]).to_fp16()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"class BreakFitCallback(Callback):\n",
"    \"Cancels fit after one batch, before the weight update\"\n",
"    order=-1\n",
"    def before_step(self):\n",
"        self.model.zero_grad(set_to_none=True)\n",
"        raise CancelStepException\n",
"    def after_step(self):\n",
"        raise CancelBatchException\n",
"    def after_batch(self):\n",
"        print('Fit canceled')\n",
"        raise CancelFitException"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def res_submodules(model):\n",
"    # Sublayers wrapped by each AdminResidual, in module order\n",
"    return [m.sublayer for m in model.modules() if isinstance(m, AdminResidual)]\n",
"\n",
"def res_modules(model):\n",
"    # The AdminResidual wrappers themselves, in module order\n",
"    return [m for m in model.modules() if isinstance(m, AdminResidual)]"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"learn.add_cb(ActivationStats(modules=res_submodules(learn.model)));"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"12"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(learn.activation_stats.modules)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_eval -10\n",
"recorder 50\n",
"progress 60\n",
"save_model 60\n",
"mixed_precision 10\n",
"activation_stats -20\n"
]
}
],
"source": [
"for cb in learn.cbs:\n",
"    print(cb.name, cb.order)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>bpc</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fit canceled\n"
]
}
],
"source": [
"with learn.added_cbs(BreakFitCallback()), learn.removed_cbs(SaveModelCallback):\n",
"    learn.fit(1, 1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#12) [{'mean': 0.00481663690879941, 'std': 0.3442692756652832, 'near_zero': 0.6015446980794271},{'mean': -0.0025428251828998327, 'std': 0.3987179100513458, 'near_zero': 0.5950142542521158},{'mean': 0.004401894751936197, 'std': 0.19933182001113892, 'near_zero': 0.6423584620157877},{'mean': 0.003396979533135891, 'std': 0.4013838768005371, 'near_zero': 0.5894660949707031},{'mean': 0.009360421448946, 'std': 0.26095426082611084, 'near_zero': 0.6073331832885742},{'mean': -0.0018820768455043435, 'std': 0.3994253873825073, 'near_zero': 0.5938955942789713},{'mean': -0.02019776962697506, 'std': 0.3356610834598541, 'near_zero': 0.6348047256469727},{'mean': 0.013347254134714603, 'std': 0.40092015266418457, 'near_zero': 0.5801426569620768},{'mean': -0.024423303082585335, 'std': 0.3910963237285614, 'near_zero': 0.621729850769043},{'mean': 0.005235261749476194, 'std': 0.3965536057949066, 'near_zero': 0.587321917215983}...]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.activation_stats.stats[0]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.11852133, 0.15897597, 0.03973317, 0.16110902, 0.06809713,\n",
" 0.15954064, 0.11266836, 0.16073697, 0.15295633, 0.15725476,\n",
" 0.24887388, 0.15945325])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def variances(learn):\n",
"    # Output variance of each residual sublayer, from the first recorded batch\n",
"    return np.array([stat['std']**2 for stat in learn.activation_stats.stats[0]])\n",
"\n",
"variances(learn)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.34426928, 0.52678013, 0.56323217, 0.69162092, 0.73921352,\n",
" 0.84022453, 0.90479038, 0.98963761, 1.06411415, 1.13560279,\n",
" 1.24034978, 1.30304291])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def _init_scales(layer_vars):\n",
"    # Scale for residual block i: sqrt of the summed output variances of blocks 1..i\n",
"    return np.sqrt(np.cumsum(layer_vars))\n",
"scales = _init_scales(variances(learn))\n",
"scales"
]
},
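{
"cell_type": "markdown",
"metadata": {},
"source": [
"The scale computation above, written out as a formula. With $\\sigma_j$ the output standard deviation recorded by `ActivationStats` for the $j$-th residual sublayer, `_init_scales` computes the following (the summation range follows `np.cumsum` as used in this notebook, which includes block $i$ itself; the paper's exact indexing may differ):\n",
"\n",
"$$s_i = \\sqrt{\\sum_{j=1}^{i} \\sigma_j^2}$$\n",
"\n",
"Checking the first three values against the outputs above:\n",
"\n",
"$$s_1 = \\sqrt{0.11852} \\approx 0.3443, \\qquad s_2 = \\sqrt{0.11852 + 0.15898} \\approx 0.5268, \\qquad s_3 = \\sqrt{0.11852 + 0.15898 + 0.03973} \\approx 0.5632$$"
]
},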
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define ADMIN Initialisation"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"def admin_init(model, scales):\n",
" ms = res_modules(model)\n",
" for m, s in zip(ms, scales):\n",
" m.w.data *= s"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LR Find"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.03019951581954956, lr_steep=0.25118863582611084)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.remove_cb(ActivationStats)\n",
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.04365158379077912, lr_steep=0.3630780577659607)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"admin_init(learn.model, scales)\n",
"learn.lr_find()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tracking"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmorgan\u001b[0m (use `wandb login --relogin` to force relogin)\n"
]
}
],
"source": [
"# hide\n",
"!wandb login"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load Experiment Tracking with Weights & Biases:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"WANDB_NAME = f'admin_init_bte2'\n",
"GROUP = 'admin'\n",
"NOTES = 'Early ADMIN testing'\n",
"CONFIG = {}\n",
"TAGS = ['admin']"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmorgan\u001b[0m (use `wandb login --relogin` to force relogin)\n"
]
},
{
"data": {
"text/html": [
"\n",
" Tracking run with wandb version 0.10.19<br/>\n",
" Syncing run <strong style=\"color:#cdcd00\">admin_init_bte2</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox</a><br/>\n",
" Run page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/2o27v7d4\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox/runs/2o27v7d4</a><br/>\n",
" Run data is saved locally in <code>/home/morgan/ml/projects/transformers_sandbox/nbs/exploration/wandb/run-20210220_234952-2o27v7d4</code><br/><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<h1>Run(2o27v7d4)</h1><iframe src=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/2o27v7d4\" style=\"border:none;width:100%;height:400px\"></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x7fb00ef616a0>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#hide_output\n",
"wandb.init(reinit=True, project=\"transformers_sandbox\", entity=\"fastai_community\", \n",
" name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>bpc</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.117492</td>\n",
" <td>2.035515</td>\n",
" <td>0.404338</td>\n",
" <td>7.656197</td>\n",
" <td>2.936628</td>\n",
" <td>07:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.601835</td>\n",
" <td>1.523893</td>\n",
" <td>0.554345</td>\n",
" <td>4.590059</td>\n",
" <td>2.198513</td>\n",
" <td>07:29</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.452392</td>\n",
" <td>1.390810</td>\n",
" <td>0.589804</td>\n",
" <td>4.018106</td>\n",
" <td>2.006516</td>\n",
" <td>07:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.347416</td>\n",
" <td>1.315846</td>\n",
" <td>0.613917</td>\n",
" <td>3.727902</td>\n",
" <td>1.898364</td>\n",
" <td>07:29</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1.275978</td>\n",
" <td>1.247816</td>\n",
" <td>0.633069</td>\n",
" <td>3.482728</td>\n",
" <td>1.800218</td>\n",
" <td>07:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1.227183</td>\n",
" <td>1.200533</td>\n",
" <td>0.647163</td>\n",
" <td>3.321887</td>\n",
" <td>1.732003</td>\n",
" <td>07:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1.188224</td>\n",
" <td>1.171363</td>\n",
" <td>0.655419</td>\n",
" <td>3.226386</td>\n",
" <td>1.689919</td>\n",
" <td>07:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>1.147837</td>\n",
" <td>1.153088</td>\n",
" <td>0.661562</td>\n",
" <td>3.167962</td>\n",
" <td>1.663555</td>\n",
" <td>07:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>1.123723</td>\n",
" <td>1.142603</td>\n",
" <td>0.665061</td>\n",
" <td>3.134917</td>\n",
" <td>1.648427</td>\n",
" <td>07:29</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1.113917</td>\n",
" <td>1.141676</td>\n",
" <td>0.665857</td>\n",
" <td>3.132012</td>\n",
" <td>1.647090</td>\n",
" <td>07:28</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Better model found at epoch 0 with valid_loss value: 2.035515308380127.\n",
"Better model found at epoch 1 with valid_loss value: 1.523892879486084.\n",
"Better model found at epoch 2 with valid_loss value: 1.390810489654541.\n",
"Better model found at epoch 3 with valid_loss value: 1.3158457279205322.\n",
"Better model found at epoch 4 with valid_loss value: 1.2478158473968506.\n",
"Better model found at epoch 5 with valid_loss value: 1.2005329132080078.\n",
"Better model found at epoch 6 with valid_loss value: 1.1713625192642212.\n",
"Better model found at epoch 7 with valid_loss value: 1.1530883312225342.\n",
"Better model found at epoch 8 with valid_loss value: 1.142602562904358.\n",
"Better model found at epoch 9 with valid_loss value: 1.1416757106781006.\n"
]
}
],
"source": [
"admin_init(learn.model, scales)\n",
"learn.fit_one_cycle(n_epochs, 1e-3, cbs=[WandbCallback(log_model=False, log_preds=False)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train - SPUniTokenizer\n",
"\n",
"Train a tokenizer with the same dictionary length as the T5 tokenizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"SentencePiece Unigram Tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset wikitext (/home/morgan/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)\n",
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-e9eccb98202e10a0.arrow\n"
]
}
],
"source": [
"big_train_ds = load_dataset('wikitext', name='wikitext-103-raw-v1', split='train')\n",
"big_train_ds = big_train_ds.filter(lambda x: x['text'] != '')\n",
"\n",
"spu_tok = SPUniTokenizer(vcb_sz=32093, add_eos=True, add_bos=True)\n",
"spu_tok.train(big_train_ds['text'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Save Trained Tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def save_obj(obj, name ):\n",
" with open(name + '.pkl', 'wb') as f:\n",
" pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)\n",
"\n",
"def load_obj(name ):\n",
" with open(name + '.pkl', 'rb') as f:\n",
" return pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# save_obj(spu_tok)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"spu_tok = load_obj('sentencepieceunigram_wiki103_32093')"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"tok = spu_tok"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"#Run setup\n",
"n_epochs = 10\n",
"bs = 12\n",
"sl = 512\n",
"n_layers = 6\n",
"pad_id = tok.pad_token_id\n",
"vocab_size = tok.vocab_size"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tokenize Data"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-4267024924ef223b.arrow\n",
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-916f06ba6ec602c9.arrow\n"
]
}
],
"source": [
"def encode(examples): return {'token_ids' : spu_tok(examples['text'])}\n",
"\n",
"tok_ds = train_ds.map(encode, batched=False)\n",
"val_tok_ds = valid_ds.map(encode, batched=False)\n",
"all_toks = tok_ds['token_ids'] + val_tok_ds['token_ids']\n",
"df['toks'] = all_toks "
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>text_</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Coulthard and Button collided on lap 18 when Button attempted to pass the Red Bull on the inside at turn eight ; the Honda lost its front wing and retired a lap later after two pit stops . Hamilton continued his climb back through the field ; he moved from 18th , passing Piquet , Davidson , Sutil and Bourdais in separate manoeuvres , to sit in 14th by the time he pitted on lap 31 . Piquet retired on lap 42 with transmission failure , requiring a gearbox change before the next race . Anderson was involved in filmmaking at a young age and never really had an alternative plan to directing films . He made his first movie when he was eight years old and started making movies on a Betamax video camera which his dad bought in 1982 when he was twelve years old . He</td>\n",
" <td>Coulthard and Button collided on lap 18 when Button attempted to pass the Red Bull on the inside at turn eight ; the Honda lost its front wing and retired a lap later after two pit stops . Hamilton continued his climb back through the field ; he moved from 18th , passing Piquet , Davidson , Sutil and Bourdais in separate manoeuvres , to sit in 14th by the time he pitted on lap 31 . Piquet retired on lap 42 with transmission failure , requiring a gearbox change before the next race . Anderson was involved in filmmaking at a young age and never really had an alternative plan to directing films . He made his first movie when he was eight years old and started making movies on a Betamax video camera which his dad bought in 1982 when he was twelve years old . He</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>subdue and molest his victims . The song ends with the narrator turning inward with the lyrics : \" And in my best behavior , I am really just like him / Look beneath the floorboards for the secrets I have hid . \" Stevens stated in a 2009 interview with Paste that \" we 're all capable of what [ Gacy ] did . \" Technically , it is known as a speaker that easily reveals poor quality in recordings . Recording engineers sought to dull its treble response by hanging tissue paper in front of it , resulting in what became known as the \" tissue paper effect \" , a type of comb filtering . The NS @-@ 10 has been used to monitor a large number of successful recordings by numerous artists , leading Gizmodo to refer to it as \" the most important loudspeaker</td>\n",
" <td>and molest his victims . The song ends with the narrator turning inward with the lyrics : \" And in my best behavior , I am really just like him / Look beneath the floorboards for the secrets I have hid . \" Stevens stated in a 2009 interview with Paste that \" we 're all capable of what [ Gacy ] did . \" Technically , it is known as a speaker that easily reveals poor quality in recordings . Recording engineers sought to dull its treble response by hanging tissue paper in front of it , resulting in what became known as the \" tissue paper effect \" , a type of comb filtering . The NS @-@ 10 has been used to monitor a large number of successful recordings by numerous artists , leading Gizmodo to refer to it as \" the most important loudspeaker you</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df['lens'] = df['toks'].apply(len)\n",
"splits = ColSplitter()(df)\n",
"tfms = [attrgetter(\"text\"), spu_tok]\n",
"# tfms = [attrgetter(\"toks\")]\n",
"dsets = Datasets(df, [tfms], splits=splits, dl_type=LMDataLoader)\n",
"dl_kwargs = [{'lens':df['lens'].values[splits[0]]},\n",
" {'val_lens':df['lens'].values[splits[1]]}]\n",
"dls = dsets.dataloaders(bs=bs, seq_len=sl, dl_kwargs=dl_kwargs, shuffle_train=True)\n",
"dls.show_batch(max_n=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tracking"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"WANDB_NAME = f'admin_init_spu'\n",
"GROUP = 'admin'\n",
"NOTES = 'SentencePieceUnigram tokenizer trained on wiki 103, 32100 vcb_sz, Early ADMIN testing'\n",
"CONFIG = {}\n",
"TAGS = ['admin','spu']"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmorgan\u001b[0m (use `wandb login --relogin` to force relogin)\n"
]
},
{
"data": {
"text/html": [
"\n",
" Tracking run with wandb version 0.10.19<br/>\n",
" Syncing run <strong style=\"color:#cdcd00\">admin_init_spu</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox</a><br/>\n",
" Run page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/896aogrk\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox/runs/896aogrk</a><br/>\n",
" Run data is saved locally in <code>/home/morgan/ml/projects/transformers_sandbox/nbs/exploration/wandb/run-20210221_183500-896aogrk</code><br/><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<h1>Run(896aogrk)</h1><iframe src=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/896aogrk\" style=\"border:none;width:100%;height:400px\"></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x7f5b92f19e50>"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#hide_output\n",
"wandb.init(reinit=True, project=\"transformers_sandbox\", entity=\"fastai_community\", \n",
" name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train SPU"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>bpc</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>5.331179</td>\n",
" <td>5.275365</td>\n",
" <td>0.253242</td>\n",
" <td>195.461777</td>\n",
" <td>7.610743</td>\n",
" <td>02:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>4.771585</td>\n",
" <td>4.862295</td>\n",
" <td>0.290732</td>\n",
" <td>129.320679</td>\n",
" <td>7.014809</td>\n",
" <td>02:43</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>4.419650</td>\n",
" <td>4.540506</td>\n",
" <td>0.322093</td>\n",
" <td>93.738213</td>\n",
" <td>6.550565</td>\n",
" <td>02:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>4.108190</td>\n",
" <td>4.304454</td>\n",
" <td>0.344261</td>\n",
" <td>74.028770</td>\n",
" <td>6.210014</td>\n",
" <td>02:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>3.865303</td>\n",
" <td>4.159061</td>\n",
" <td>0.357382</td>\n",
" <td>64.011414</td>\n",
" <td>6.000257</td>\n",
" <td>02:45</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>3.655769</td>\n",
" <td>4.056620</td>\n",
" <td>0.368161</td>\n",
" <td>57.778694</td>\n",
" <td>5.852466</td>\n",
" <td>02:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>3.459788</td>\n",
" <td>3.989344</td>\n",
" <td>0.376501</td>\n",
" <td>54.019436</td>\n",
" <td>5.755407</td>\n",
" <td>02:45</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>3.292417</td>\n",
" <td>3.953279</td>\n",
" <td>0.382756</td>\n",
" <td>52.105968</td>\n",
" <td>5.703377</td>\n",
" <td>02:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>3.169972</td>\n",
" <td>3.935873</td>\n",
" <td>0.385500</td>\n",
" <td>51.206848</td>\n",
" <td>5.678265</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>3.099260</td>\n",
" <td>3.935462</td>\n",
" <td>0.385990</td>\n",
" <td>51.185768</td>\n",
" <td>5.677671</td>\n",
" <td>02:44</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Better model found at epoch 0 with valid_loss value: 5.275364875793457.\n",
"Better model found at epoch 1 with valid_loss value: 4.862295150756836.\n",
"Better model found at epoch 2 with valid_loss value: 4.540505886077881.\n",
"Better model found at epoch 3 with valid_loss value: 4.3044538497924805.\n",
"Better model found at epoch 4 with valid_loss value: 4.159061431884766.\n",
"Better model found at epoch 5 with valid_loss value: 4.056620121002197.\n",
"Better model found at epoch 6 with valid_loss value: 3.9893438816070557.\n",
"Better model found at epoch 7 with valid_loss value: 3.953279495239258.\n",
"Better model found at epoch 8 with valid_loss value: 3.93587327003479.\n",
"Better model found at epoch 9 with valid_loss value: 3.9354615211486816.\n"
]
}
],
"source": [
"admin_init(learn.model, scales)\n",
"learn.fit_one_cycle(n_epochs, 1e-3, cbs=[WandbCallback(log_model=False, log_preds=False)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train SPU - 20e"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.04365158379077912, lr_steep=0.3630780577659607)"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.remove_cb(ActivationStats)\n",
"admin_init(learn.model, scales)\n",
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"WANDB_NAME = f'admin_init_spu_20e'\n",
"GROUP = 'admin'\n",
"NOTES = 'SentencePieceUnigram tokenizer trained on wiki 103, 20e, 32100 vcb_sz, Early ADMIN testing'\n",
"CONFIG = {}\n",
"TAGS = ['admin','spu']"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmorgan\u001b[0m (use `wandb login --relogin` to force relogin)\n"
]
},
{
"data": {
"text/html": [
"\n",
" Tracking run with wandb version 0.10.19<br/>\n",
" Syncing run <strong style=\"color:#cdcd00\">admin_init_spu_20e</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox</a><br/>\n",
" Run page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/3ks1ewq1\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox/runs/3ks1ewq1</a><br/>\n",
" Run data is saved locally in <code>/home/morgan/ml/projects/transformers_sandbox/nbs/exploration/wandb/run-20210221_193133-3ks1ewq1</code><br/><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<h1>Run(3ks1ewq1)</h1><iframe src=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/3ks1ewq1\" style=\"border:none;width:100%;height:400px\"></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x7fc180055d30>"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#hide_output\n",
"wandb.init(reinit=True, project=\"transformers_sandbox\", entity=\"fastai_community\", \n",
" name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>bpc</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>5.516002</td>\n",
" <td>5.488830</td>\n",
" <td>0.231003</td>\n",
" <td>241.973953</td>\n",
" <td>7.918708</td>\n",
" <td>01:41</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>4.964672</td>\n",
" <td>4.980829</td>\n",
" <td>0.285725</td>\n",
" <td>145.595062</td>\n",
" <td>7.185818</td>\n",
" <td>01:43</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>4.570927</td>\n",
" <td>4.636693</td>\n",
" <td>0.315706</td>\n",
" <td>103.202446</td>\n",
" <td>6.689333</td>\n",
" <td>01:45</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>4.294974</td>\n",
" <td>4.476145</td>\n",
" <td>0.328804</td>\n",
" <td>87.895203</td>\n",
" <td>6.457713</td>\n",
" <td>01:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>4.072177</td>\n",
" <td>4.286242</td>\n",
" <td>0.347858</td>\n",
" <td>72.692741</td>\n",
" <td>6.183739</td>\n",
" <td>01:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>3.882438</td>\n",
" <td>4.174093</td>\n",
" <td>0.354857</td>\n",
" <td>64.980888</td>\n",
" <td>6.021944</td>\n",
" <td>01:43</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>3.688582</td>\n",
" <td>4.089933</td>\n",
" <td>0.362383</td>\n",
" <td>59.735886</td>\n",
" <td>5.900526</td>\n",
" <td>01:42</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>3.555619</td>\n",
" <td>4.045366</td>\n",
" <td>0.370683</td>\n",
" <td>57.132111</td>\n",
" <td>5.836230</td>\n",
" <td>01:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>3.410418</td>\n",
" <td>4.018274</td>\n",
" <td>0.375051</td>\n",
" <td>55.605038</td>\n",
" <td>5.797144</td>\n",
" <td>01:42</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>3.273482</td>\n",
" <td>3.989733</td>\n",
" <td>0.379958</td>\n",
" <td>54.040470</td>\n",
" <td>5.755968</td>\n",
" <td>01:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>3.123741</td>\n",
" <td>3.992993</td>\n",
" <td>0.383907</td>\n",
" <td>54.216911</td>\n",
" <td>5.760671</td>\n",
" <td>01:41</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2.994096</td>\n",
" <td>3.990759</td>\n",
" <td>0.385993</td>\n",
" <td>54.095951</td>\n",
" <td>5.757449</td>\n",
" <td>01:42</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2.877162</td>\n",
" <td>4.005277</td>\n",
" <td>0.387474</td>\n",
" <td>54.887009</td>\n",
" <td>5.778393</td>\n",
" <td>01:42</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>2.746851</td>\n",
" <td>4.023293</td>\n",
" <td>0.390281</td>\n",
" <td>55.884834</td>\n",
" <td>5.804385</td>\n",
" <td>01:45</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2.615041</td>\n",
" <td>4.046219</td>\n",
" <td>0.392509</td>\n",
" <td>57.180866</td>\n",
" <td>5.837461</td>\n",
" <td>01:44</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>2.524932</td>\n",
" <td>4.072429</td>\n",
" <td>0.393079</td>\n",
" <td>58.699383</td>\n",
" <td>5.875273</td>\n",
" <td>01:45</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>2.434323</td>\n",
" <td>4.087971</td>\n",
" <td>0.392887</td>\n",
" <td>59.618816</td>\n",
" <td>5.897696</td>\n",
" <td>01:42</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>2.361165</td>\n",
" <td>4.097982</td>\n",
" <td>0.393126</td>\n",
" <td>60.218670</td>\n",
" <td>5.912139</td>\n",
" <td>01:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>2.310880</td>\n",
" <td>4.103701</td>\n",
" <td>0.393512</td>\n",
" <td>60.564026</td>\n",
" <td>5.920389</td>\n",
" <td>01:43</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>2.289599</td>\n",
" <td>4.104567</td>\n",
" <td>0.393593</td>\n",
" <td>60.616467</td>\n",
" <td>5.921638</td>\n",
" <td>01:44</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Better model found at epoch 0 with valid_loss value: 5.488830089569092.\n",
"Better model found at epoch 1 with valid_loss value: 4.980829238891602.\n",
"Better model found at epoch 2 with valid_loss value: 4.636692523956299.\n",
"Better model found at epoch 3 with valid_loss value: 4.476145267486572.\n",
"Better model found at epoch 4 with valid_loss value: 4.28624153137207.\n",
"Better model found at epoch 5 with valid_loss value: 4.174093246459961.\n",
"Better model found at epoch 6 with valid_loss value: 4.089932918548584.\n",
"Better model found at epoch 7 with valid_loss value: 4.045366287231445.\n",
"Better model found at epoch 8 with valid_loss value: 4.018273830413818.\n",
"Better model found at epoch 9 with valid_loss value: 3.9897332191467285.\n"
]
}
],
"source": [
"learn.remove_cb(ActivationStats)\n",
"admin_init(learn.model, scales)\n",
"learn.fit_one_cycle(20, 1e-3, cbs=[WandbCallback(log_model=False, log_preds=False)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train - T5 Tokenizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"T5 SentencePiece BPE Pretrained Tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from transformers import T5TokenizerFast\n",
"\n",
"class T5TokenizerTransform(Transform):\n",
" def __init__(self):\n",
" self.tok = T5TokenizerFast.from_pretrained('t5-small')\n",
" self.pad_token_id = self.tok.pad_token_id\n",
" self.eos_token_id = self.tok.eos_token_id\n",
" self.bos_token_id = self.tok.bos_token_id\n",
" self.unk_token_id = self.tok.unk_token_id\n",
" self.vcb_size = self.tok.vocab_size\n",
" \n",
" def __call__(self, o, **kwargs):\n",
" return LMTensorText(self.tok(o)['input_ids'])\n",
" \n",
" def encodes(self, o):\n",
" return LMTensorText(self.tok(o)['input_ids'])\n",
" \n",
" def decodes(self, o):\n",
" return TitledStr(self.tok.decode(o.numpy()))\n",
" \n",
" @property\n",
" def vocab_size(self): return self.vcb_size"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"t5_tok = T5TokenizerTransform()\n",
"tok = t5_tok"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /home/morgan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-bd8ef24658a2bf53.arrow\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4aba2cec8c844241ae3221804c011c4a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=2461.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Token indices sequence length is longer than the specified maximum sequence length for this model (585 > 512). Running this sequence through the model will result in indexing errors\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"def encode(examples): return {'token_ids' : t5_tok(examples['text'])}\n",
"\n",
"tok_ds = train_ds.map(encode, batched=False)\n",
"val_tok_ds = valid_ds.map(encode, batched=False)\n",
"all_toks = tok_ds['token_ids'] + val_tok_ds['token_ids']\n",
"df['toks'] = all_toks "
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"#Run setup\n",
"n_epochs = 10\n",
"bs = 12\n",
"sl = 512\n",
"n_layers = 6\n",
"pad_id = tok.pad_token_id\n",
"vocab_size = tok.vocab_size"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>text_</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>It has been estimated that average @-@ sized bacteria contain about 2 million proteins per cell ( e.g. E. coli and Staphylococcus aureus ). Smaller bacteria, such as Mycoplasma or spirochetes contain fewer molecules, namely on the order of 50 @,@ 000 to 1 million. By contrast, eukaryotic cells are larger and thus contain much more protein. For instance, yeast cells were estimated to contain about 50 million proteins and human cells on the order of 1 to 3 billion. Note that bacterial genomes encode about 10 times fewer proteins than humans ( e.g. small bacteria &lt;unk&gt; 1 @,@ 000, E. coli : &lt;unk&gt; 4 @,@ 000, yeast : &lt;unk&gt; 6 @,@ 000, human : &lt;unk&gt; 20 @,@ 000 ).&lt;/s&gt; Even though cadmium and its compounds are toxic in certain forms and concentrations, the British Pharmaceutical Codex from 1907 states that cadmium iodide was used as a medication to treat</td>\n",
" <td>has been estimated that average @-@ sized bacteria contain about 2 million proteins per cell ( e.g. E. coli and Staphylococcus aureus ). Smaller bacteria, such as Mycoplasma or spirochetes contain fewer molecules, namely on the order of 50 @,@ 000 to 1 million. By contrast, eukaryotic cells are larger and thus contain much more protein. For instance, yeast cells were estimated to contain about 50 million proteins and human cells on the order of 1 to 3 billion. Note that bacterial genomes encode about 10 times fewer proteins than humans ( e.g. small bacteria &lt;unk&gt; 1 @,@ 000, E. coli : &lt;unk&gt; 4 @,@ 000, yeast : &lt;unk&gt; 6 @,@ 000, human : &lt;unk&gt; 20 @,@ 000 ).&lt;/s&gt; Even though cadmium and its compounds are toxic in certain forms and concentrations, the British Pharmaceutical Codex from 1907 states that cadmium iodide was used as a medication to treat \"</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>, collided with the back of the Renault, knocking the McLaren's front wing off the car. Suffering handling difficulties, Hamilton returned to the pit @-@ lane for a new nose section, and rejoined in 18th place. Räikkönen took second place when he passed Kubica on lap three ; Heidfeld took fourth when he passed Trulli and Kovalainen in separate manoeuvres. Further down the field, Vettel retired from the race on the first lap after twice colliding with other cars ; Button, Sutil and Coulthard pitted to repair early damage.&lt;/s&gt; = = Wheelchair basketball = =&lt;/s&gt; Jardine then ordered his team to move to bodyline positions immediately after Woodfull's injury. Jardine wrote that Larwood had asked for the field, while Larwood said that it was Jardine's decision. The capacity Saturday afternoon crowd viewed this as hitting a man when he was down. Journalist – cricketer Dick Whitington wrote that Jardine's actions</td>\n",
" <td>collided with the back of the Renault, knocking the McLaren's front wing off the car. Suffering handling difficulties, Hamilton returned to the pit @-@ lane for a new nose section, and rejoined in 18th place. Räikkönen took second place when he passed Kubica on lap three ; Heidfeld took fourth when he passed Trulli and Kovalainen in separate manoeuvres. Further down the field, Vettel retired from the race on the first lap after twice colliding with other cars ; Button, Sutil and Coulthard pitted to repair early damage.&lt;/s&gt; = = Wheelchair basketball = =&lt;/s&gt; Jardine then ordered his team to move to bodyline positions immediately after Woodfull's injury. Jardine wrote that Larwood had asked for the field, while Larwood said that it was Jardine's decision. The capacity Saturday afternoon crowd viewed this as hitting a man when he was down. Journalist – cricketer Dick Whitington wrote that Jardine's actions were</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df['lens'] = df['toks'].apply(len)\n",
"splits = ColSplitter()(df)\n",
"tfms = [attrgetter(\"text\"), t5_tok]\n",
"dsets = Datasets(df, [tfms], splits=splits, dl_type=LMDataLoader)\n",
"dl_kwargs = [{'lens':df['lens'].values[splits[0]]},\n",
" {'val_lens':df['lens'].values[splits[1]]}]\n",
"dls = dsets.dataloaders(bs=bs, seq_len=sl, dl_kwargs=dl_kwargs, shuffle_train=True)\n",
"dls.show_batch(max_n=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LR Find"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.05248074531555176, lr_steep=0.43651583790779114)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.remove_cb(ActivationStats)\n",
"admin_init(learn.model, scales)\n",
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"WANDB_NAME = 'admin_init_t5_20e'\n",
"GROUP = 'admin'\n",
"NOTES = 'T5Tokenizer tokenizer, 20e, 32100 vcb_sz, Early ADMIN testing'\n",
"CONFIG = {}\n",
"TAGS = ['admin','t5']"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmorgan\u001b[0m (use `wandb login --relogin` to force relogin)\n"
]
},
{
"data": {
"text/html": [
"\n",
" Tracking run with wandb version 0.10.19<br/>\n",
" Syncing run <strong style=\"color:#cdcd00\">admin_init_t5_20e</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox</a><br/>\n",
" Run page: <a href=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/2k8d2xge\" target=\"_blank\">https://wandb.ai/fastai_community/transformers_sandbox/runs/2k8d2xge</a><br/>\n",
" Run data is saved locally in <code>/home/morgan/ml/projects/transformers_sandbox/nbs/exploration/wandb/run-20210221_220517-2k8d2xge</code><br/><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<h1>Run(2k8d2xge)</h1><iframe src=\"https://wandb.ai/fastai_community/transformers_sandbox/runs/2k8d2xge\" style=\"border:none;width:100%;height:400px\"></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x7f58280433a0>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#hide_output\n",
"wandb.init(reinit=True, project=\"transformers_sandbox\", entity=\"fastai_community\", \n",
" name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='0' class='' max='20' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 0.00% [0/20 00:00<00:00]\n",
" </div>\n",
" \n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>perplexity</th>\n",
" <th>bpc</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>\n",
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='8' class='' max='466' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 1.72% [8/466 00:01<01:39 9.5600]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.remove_cb(ActivationStats)\n",
"admin_init(learn.model, scales)\n",
"learn.fit_one_cycle(20, 1e-3, cbs=[WandbCallback(log_model=False, log_preds=False)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted 00_core.ipynb.\n",
"Converted 01_layers.ipynb.\n",
"Converted 02_attention.ipynb.\n",
"Converted 03_models.transformer.ipynb.\n",
"Converted 04a_models.reformer.ipynb.\n",
"Converted 04x_models.xtransformer.ipynb.\n",
"Converted 05_tokenizers.ipynb.\n",
"Converted 06_data.ipynb.\n",
"Converted 07_metrics.ipynb.\n",
"Converted 08_optimizers.ipynb.\n",
"Converted 09_tracking.ipynb.\n",
"Converted 10_config.ipynb.\n",
"Converted index.ipynb.\n"
]
}
],
"source": [
"#hide\n",
"from nbdev.export import notebook2script; notebook2script()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}