{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"\n",
"from torchtext.datasets import TranslationDataset, Multi30k\n",
"from torchtext.data import Field, BucketIterator\n",
"\n",
"import spacy\n",
"\n",
"import random\n",
"import math\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"SEED = 1234\n",
"\n",
"random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"spacy_de = spacy.load('de')\n",
"spacy_en = spacy.load('en')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def tokenize_de(text):\n",
" \"\"\"\n",
" Tokenizes German text from a string into a list of strings\n",
" \"\"\"\n",
" return [tok.text for tok in spacy_de.tokenizer(text)]\n",
"\n",
"def tokenize_en(text):\n",
" \"\"\"\n",
" Tokenizes English text from a string into a list of strings\n",
" \"\"\"\n",
" return [tok.text for tok in spacy_en.tokenizer(text)]"
]
},
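{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check we can run the tokenizers on a made-up sentence (the inputs below are illustrative, not taken from the dataset):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(tokenize_en('Two young people are walking down the street.'))\n",
"print(tokenize_de('Zwei junge Leute gehen die Straße entlang.'))"
]
},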
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"SRC = Field(tokenize = tokenize_de, \n",
" init_token = '<sos>', \n",
" eos_token = '<eos>', \n",
" lower = True)\n",
"\n",
"TRG = Field(tokenize = tokenize_en, \n",
" init_token = '<sos>', \n",
" eos_token = '<eos>', \n",
" lower = True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), \n",
" fields = (SRC, TRG))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"SRC.build_vocab(train_data, min_freq = 2)\n",
"TRG.build_vocab(train_data, min_freq = 2)"
]
},
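{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick look at the vocabularies (a minimal check; the exact sizes depend on the min_freq = 2 threshold above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f'Unique tokens in source (de) vocabulary: {len(SRC.vocab)}')\n",
"print(f'Unique tokens in target (en) vocabulary: {len(TRG.vocab)}')\n",
"\n",
"#indices of the special tokens added by the Fields\n",
"print(TRG.vocab.stoi['<sos>'], TRG.vocab.stoi['<eos>'], TRG.vocab.stoi['<pad>'])"
]
},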
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 128\n",
"\n",
"train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n",
" (train_data, valid_data, test_data), \n",
" batch_size = BATCH_SIZE,\n",
" device = device)"
]
},
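{
"cell_type": "markdown",
"metadata": {},
"source": [
"Peeking at a single batch shows the tensor layout the model expects: the sequence dimension comes first and the batch dimension second (a quick check, assuming the iterators above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch = next(iter(train_iterator))\n",
"\n",
"print(batch.src.shape) #[src sent len, batch size]\n",
"print(batch.trg.shape) #[trg sent len, batch size]"
]
},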
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class Encoder(nn.Module):\n",
" def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):\n",
" super().__init__()\n",
" \n",
" self.embedding = nn.Embedding(input_dim, emb_dim)\n",
" \n",
" self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)\n",
" \n",
" self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, src):\n",
" \n",
" #src = [src sent len, batch size]\n",
" \n",
" embedded = self.dropout(self.embedding(src))\n",
" \n",
" #embedded = [src sent len, batch size, emb dim]\n",
" \n",
" outputs, hidden = self.rnn(embedded)\n",
" \n",
" #outputs = [src sent len, batch size, hid dim * num directions]\n",
" #hidden = [n layers * num directions, batch size, hid dim]\n",
" \n",
" #hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]\n",
" #outputs are always from the last layer\n",
" \n",
" #hidden [-2, :, : ] is the last of the forwards RNN \n",
" #hidden [-1, :, : ] is the last of the backwards RNN\n",
" \n",
" #initial decoder hidden is final hidden state of the forwards and backwards \n",
" # encoder RNNs fed through a linear layer\n",
" hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))\n",
" \n",
" #outputs = [src sent len, batch size, enc hid dim * 2]\n",
" #hidden = [batch size, dec hid dim]\n",
" \n",
" return outputs, hidden"
]
},
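{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small shape check of the encoder with toy dimensions (the values below are hypothetical and independent of the real model built later): the outputs should be [src sent len, batch size, enc hid dim * 2] and the hidden state [batch size, dec hid dim]."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"enc_test = Encoder(input_dim = 10, emb_dim = 4, enc_hid_dim = 6, dec_hid_dim = 8, dropout = 0.5)\n",
"\n",
"src_test = torch.randint(0, 10, (7, 2)) #[src sent len = 7, batch size = 2]\n",
"\n",
"outputs_test, hidden_test = enc_test(src_test)\n",
"\n",
"print(outputs_test.shape) #expect [7, 2, 12]\n",
"print(hidden_test.shape) #expect [2, 8]"
]
},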
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"class Attention(nn.Module):\n",
" def __init__(self, enc_hid_dim, dec_hid_dim):\n",
" super().__init__()\n",
" \n",
" self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)\n",
" self.v = nn.Parameter(torch.rand(dec_hid_dim))\n",
" \n",
" def forward(self, hidden, encoder_outputs):\n",
" \n",
" #hidden = [batch size, dec hid dim]\n",
" #encoder_outputs = [src sent len, batch size, enc hid dim * 2]\n",
" \n",
" batch_size = encoder_outputs.shape[1]\n",
" src_len = encoder_outputs.shape[0]\n",
" \n",
" #repeat encoder hidden state src_len times\n",
" hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)\n",
" \n",
" encoder_outputs = encoder_outputs.permute(1, 0, 2)\n",
" \n",
" #hidden = [batch size, src sent len, dec hid dim]\n",
" #encoder_outputs = [batch size, src sent len, enc hid dim * 2]\n",
" \n",
" energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) \n",
" \n",
" #energy = [batch size, src sent len, dec hid dim]\n",
" \n",
" energy = energy.permute(0, 2, 1)\n",
" \n",
" #energy = [batch size, dec hid dim, src sent len]\n",
" \n",
" #v = [dec hid dim]\n",
" \n",
" v = self.v.repeat(batch_size, 1).unsqueeze(1)\n",
" \n",
" #v = [batch size, 1, dec hid dim]\n",
" \n",
" attention = torch.bmm(v, energy).squeeze(1)\n",
" \n",
" #attention= [batch size, src len]\n",
" \n",
" return F.softmax(attention, dim=1)"
]
},
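{
"cell_type": "markdown",
"metadata": {},
"source": [
"The attention weights form a distribution over the source positions, so each row should sum to one. A toy check with random tensors (dimensions are hypothetical and match the encoder sketch above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"attn_test = Attention(enc_hid_dim = 6, dec_hid_dim = 8)\n",
"\n",
"a_test = attn_test(torch.randn(2, 8), torch.randn(7, 2, 12))\n",
"\n",
"print(a_test.shape) #[batch size = 2, src len = 7]\n",
"print(a_test.sum(dim = 1)) #each row sums to ~1"
]
},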
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class Decoder(nn.Module):\n",
" def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):\n",
" super().__init__()\n",
"\n",
" self.output_dim = output_dim\n",
" self.attention = attention\n",
" \n",
" self.embedding = nn.Embedding(output_dim, emb_dim)\n",
" \n",
" self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)\n",
" \n",
" self.out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, input, hidden, encoder_outputs):\n",
" \n",
" #input = [batch size]\n",
" #hidden = [batch size, dec hid dim]\n",
" #encoder_outputs = [src sent len, batch size, enc hid dim * 2]\n",
" \n",
" input = input.unsqueeze(0)\n",
" \n",
" #input = [1, batch size]\n",
" \n",
" embedded = self.dropout(self.embedding(input))\n",
" \n",
" #embedded = [1, batch size, emb dim]\n",
" \n",
" a = self.attention(hidden, encoder_outputs)\n",
" \n",
" #a = [batch size, src len]\n",
" \n",
" a = a.unsqueeze(1)\n",
" \n",
" #a = [batch size, 1, src len]\n",
" \n",
" encoder_outputs = encoder_outputs.permute(1, 0, 2)\n",
" \n",
" #encoder_outputs = [batch size, src sent len, enc hid dim * 2]\n",
" \n",
" weighted = torch.bmm(a, encoder_outputs)\n",
" \n",
" #weighted = [batch size, 1, enc hid dim * 2]\n",
" \n",
" weighted = weighted.permute(1, 0, 2)\n",
" \n",
" #weighted = [1, batch size, enc hid dim * 2]\n",
" \n",
" rnn_input = torch.cat((embedded, weighted), dim = 2)\n",
" \n",
" #rnn_input = [1, batch size, (enc hid dim * 2) + emb dim]\n",
" \n",
" output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))\n",
" \n",
" #output = [sent len, batch size, dec hid dim * n directions]\n",
" #hidden = [n layers * n directions, batch size, dec hid dim]\n",
" \n",
" #sent len, n layers and n directions will always be 1 in this decoder, therefore:\n",
" #output = [1, batch size, dec hid dim]\n",
" #hidden = [1, batch size, dec hid dim]\n",
" #this also means that output == hidden\n",
" assert (output == hidden).all()\n",
" \n",
" embedded = embedded.squeeze(0)\n",
" output = output.squeeze(0)\n",
" weighted = weighted.squeeze(0)\n",
" \n",
" output = self.out(torch.cat((output, weighted, embedded), dim = 1))\n",
" \n",
" #output = [bsz, output dim]\n",
" \n",
" return output, hidden.squeeze(0)"
]
},
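{
"cell_type": "markdown",
"metadata": {},
"source": [
"One decoding step with the same toy dimensions (a hedged sketch, separate from the tutorial's real pipeline): feed a batch of token indices, a decoder hidden state and the encoder outputs, then check the prediction and hidden state shapes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dec_test = Decoder(output_dim = 10, emb_dim = 4, enc_hid_dim = 6, dec_hid_dim = 8, dropout = 0.5, attention = Attention(6, 8))\n",
"\n",
"input_test = torch.randint(0, 10, (2,)) #[batch size = 2]\n",
"\n",
"output_test, hidden_test = dec_test(input_test, torch.randn(2, 8), torch.randn(7, 2, 12))\n",
"\n",
"print(output_test.shape) #expect [2, 10]\n",
"print(hidden_test.shape) #expect [2, 8]"
]
},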
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"class Seq2Seq(nn.Module):\n",
" def __init__(self, encoder, decoder, device):\n",
" super().__init__()\n",
" \n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" self.device = device\n",
" \n",
" def forward(self, src, trg, teacher_forcing_ratio = 0.5):\n",
" \n",
" #src = [src sent len, batch size]\n",
" #trg = [trg sent len, batch size]\n",
" #teacher_forcing_ratio is probability to use teacher forcing\n",
" #e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time\n",
" \n",
" batch_size = src.shape[1]\n",
" max_len = trg.shape[0]\n",
" trg_vocab_size = self.decoder.output_dim\n",
" \n",
" #tensor to store decoder outputs\n",
" outputs = torch.zeros(max_len, batch_size, trg_vocab_size).to(self.device)\n",
" \n",
" #encoder_outputs is all hidden states of the input sequence, back and forwards\n",
" #hidden is the final forward and backward hidden states, passed through a linear layer\n",
" encoder_outputs, hidden = self.encoder(src)\n",
" \n",
" #first input to the decoder is the <sos> tokens\n",
" input = trg[0,:]\n",
" \n",
" for t in range(1, max_len):\n",
" \n",
" #insert input token embedding, previous hidden state and all encoder hidden states\n",
" #receive output tensor (predictions) and new hidden state\n",
" output, hidden = self.decoder(input, hidden, encoder_outputs)\n",
" \n",
" #place predictions in a tensor holding predictions for each token\n",
" outputs[t] = output\n",
" \n",
" #decide if we are going to use teacher forcing or not\n",
" teacher_force = random.random() < teacher_forcing_ratio\n",
" \n",
" #get the highest predicted token from our predictions\n",
" top1 = output.argmax(1) \n",
" \n",
" #if teacher forcing, use actual next token as next input\n",
" #if not, use predicted token\n",
" input = trg[t] if teacher_force else top1\n",
"\n",
" return outputs"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"INPUT_DIM = len(SRC.vocab)\n",
"OUTPUT_DIM = len(TRG.vocab)\n",
"ENC_EMB_DIM = 256\n",
"DEC_EMB_DIM = 256\n",
"ENC_HID_DIM = 512\n",
"DEC_HID_DIM = 512\n",
"ENC_DROPOUT = 0.5\n",
"DEC_DROPOUT = 0.5\n",
"\n",
"attn = Attention(ENC_HID_DIM, DEC_HID_DIM)\n",
"enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)\n",
"dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)\n",
"\n",
"model = Seq2Seq(enc, dec, device).to(device)"
]
},
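{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, a single forward pass on one batch confirms the end-to-end shapes (a smoke test, assuming the iterators above; the model is still randomly initialised at this point):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch = next(iter(train_iterator))\n",
"\n",
"with torch.no_grad():\n",
"    smoke_output = model(batch.src, batch.trg)\n",
"\n",
"print(smoke_output.shape) #[trg sent len, batch size, output dim]"
]
},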
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Seq2Seq(\n",
" (encoder): Encoder(\n",
" (embedding): Embedding(7855, 256)\n",
" (rnn): GRU(256, 512, bidirectional=True)\n",
" (fc): Linear(in_features=1024, out_features=512, bias=True)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" )\n",
" (decoder): Decoder(\n",
" (attention): Attention(\n",
" (attn): Linear(in_features=1536, out_features=512, bias=True)\n",
" )\n",
" (embedding): Embedding(5893, 256)\n",
" (rnn): GRU(1280, 512)\n",
" (out): Linear(in_features=1792, out_features=5893, bias=True)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" )\n",
")"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def init_weights(m):\n",
" for name, param in m.named_parameters():\n",
" if 'weight' in name:\n",
" nn.init.normal_(param.data, mean=0, std=0.01)\n",
" else:\n",
" nn.init.constant_(param.data, 0)\n",
" \n",
"model.apply(init_weights)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calculate the number of parameters. We get an increase of almost 50% in the amount of parameters from the last model. "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 20,518,917 trainable parameters\n"
]
}
],
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"optimizer = optim.Adam(model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"PAD_IDX = TRG.vocab.stoi['<pad>']\n",
"\n",
"criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX, reduction='none')"
]
},
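{
"cell_type": "markdown",
"metadata": {},
"source": [
"With reduction='none' the criterion returns one loss value per target token instead of a single scalar; the training loop below relies on this to compute both a per-sentence and a per-token loss from the same tensor. A toy illustration (5-word vocabulary and made-up targets; the position equal to PAD_IDX contributes zero):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"toy_logits = torch.randn(4, 5) #4 target tokens, vocabulary of 5\n",
"\n",
"toy_targets = torch.tensor([2, 3, PAD_IDX, 4])\n",
"\n",
"print(criterion(toy_logits, toy_targets)) #4 losses; the padded position is 0"
]
},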
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def train(model, iterator, optimizer, criterion, clip):\n",
" \n",
" model.train()\n",
" \n",
" epoch_loss_per_sent = 0\n",
" epoch_loss_per_tok = 0\n",
" \n",
" for i, batch in enumerate(iterator):\n",
" \n",
" src = batch.src\n",
" trg = batch.trg\n",
" \n",
" trg_sent_len = trg.shape[0]\n",
" \n",
" optimizer.zero_grad()\n",
" \n",
" output = model(src, trg)\n",
" \n",
" #trg = [trg sent len, batch size]\n",
" #output = [trg sent len, batch size, output dim]\n",
" \n",
" output = output[1:].view(-1, output.shape[-1])\n",
" trg = trg[1:].view(-1)\n",
" \n",
" #trg = [(trg sent len - 1) * batch size]\n",
" #output = [(trg sent len - 1) * batch size, output dim]\n",
" \n",
" loss = criterion(output, trg)\n",
" \n",
" #reshape to [trg sent len - 1, batch size]\n",
" loss = loss.view(trg_sent_len-1, -1)\n",
" \n",
" with torch.no_grad():\n",
" loss_per_tok = loss.mean()\n",
" \n",
" #sum loss across trg sent len dimension\n",
" loss = loss.sum(0)\n",
" \n",
" #average across batch to get average loss per sentence\n",
" loss = loss.mean(0)\n",
" \n",
" loss.backward()\n",
" \n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
" \n",
" optimizer.step()\n",
" \n",
" epoch_loss_per_sent += loss.item()\n",
" epoch_loss_per_tok += loss_per_tok.item()\n",
" \n",
" return epoch_loss_per_sent / len(iterator), epoch_loss_per_tok / len(iterator)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model, iterator, criterion):\n",
" \n",
" model.eval()\n",
" \n",
" epoch_loss_per_sent = 0\n",
" epoch_loss_per_tok = 0\n",
" \n",
" with torch.no_grad():\n",
" \n",
" for i, batch in enumerate(iterator):\n",
"\n",
" src = batch.src\n",
" trg = batch.trg\n",
"\n",
" trg_sent_len = trg.shape[0]\n",
" \n",
" output = model(src, trg, 0) #turn off teacher forcing\n",
"\n",
" #trg = [trg sent len, batch size]\n",
" #output = [trg sent len, batch size, output dim]\n",
"\n",
" output = output[1:].view(-1, output.shape[-1])\n",
" trg = trg[1:].view(-1)\n",
"\n",
" #trg = [(trg sent len - 1) * batch size]\n",
" #output = [(trg sent len - 1) * batch size, output dim]\n",
"\n",
" loss = criterion(output, trg)\n",
"\n",
" #reshape to [trg sent len - 1, batch size]\n",
" loss = loss.view(trg_sent_len - 1, -1)\n",
"\n",
" with torch.no_grad():\n",
" loss_per_tok = loss.mean()\n",
" \n",
" #sum loss across trg sent len dimension\n",
" loss = loss.sum(0)\n",
"\n",
" #average across batch to get average loss per sentence\n",
" loss = loss.mean(0)\n",
" \n",
" epoch_loss_per_sent += loss.item()\n",
" epoch_loss_per_tok += loss_per_tok.item()\n",
" \n",
" return epoch_loss_per_sent / len(iterator), epoch_loss_per_tok / len(iterator)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def epoch_time(start_time, end_time):\n",
" elapsed_time = end_time - start_time\n",
" elapsed_mins = int(elapsed_time / 60)\n",
" elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
" return elapsed_mins, elapsed_secs"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 01 | Time: 0m 57s\n",
"\tTrain Loss Sen: 70.891 | Train PPL Sen: 6128921342779803030655955107840.000\n",
"\t Val. Loss Sen: 69.310 | Val. PPL Sen: 1262313217421856109485195526144.000\n",
"\tTrain Loss Tok: 2.450 | Train PPL Tok: 11.587\n",
"\t Val. Loss Tok: 3.779 | Val. PPL Tok: 43.750\n",
"Epoch: 02 | Time: 0m 57s\n",
"\tTrain Loss Sen: 57.392 | Train PPL Sen: 8414430320705913562857472.000\n",
"\t Val. Loss Sen: 61.005 | Val. PPL Sen: 312055644358700137357246464.000\n",
"\tTrain Loss Tok: 1.975 | Train PPL Tok: 7.208\n",
"\t Val. Loss Tok: 3.323 | Val. PPL Tok: 27.755\n",
"Epoch: 03 | Time: 0m 55s\n",
"\tTrain Loss Sen: 47.220 | Train PPL Sen: 321768993183206539264.000\n",
"\t Val. Loss Sen: 53.018 | Val. PPL Sen: 105976412642502959431680.000\n",
"\tTrain Loss Tok: 1.637 | Train PPL Tok: 5.138\n",
"\t Val. Loss Tok: 2.846 | Val. PPL Tok: 17.214\n",
"Epoch: 04 | Time: 0m 55s\n",
"\tTrain Loss Sen: 39.678 | Train PPL Sen: 170582578988885888.000\n",
"\t Val. Loss Sen: 49.045 | Val. PPL Sen: 1995139708260794499072.000\n",
"\tTrain Loss Tok: 1.371 | Train PPL Tok: 3.941\n",
"\t Val. Loss Tok: 2.624 | Val. PPL Tok: 13.797\n",
"Epoch: 05 | Time: 0m 55s\n",
"\tTrain Loss Sen: 34.348 | Train PPL Sen: 826281990935799.375\n",
"\t Val. Loss Sen: 47.340 | Val. PPL Sen: 362760781805227409408.000\n",
"\tTrain Loss Tok: 1.195 | Train PPL Tok: 3.305\n",
"\t Val. Loss Tok: 2.519 | Val. PPL Tok: 12.422\n",
"Epoch: 06 | Time: 0m 57s\n",
"\tTrain Loss Sen: 30.415 | Train PPL Sen: 16177867529313.486\n",
"\t Val. Loss Sen: 47.717 | Val. PPL Sen: 528660401949538320384.000\n",
"\tTrain Loss Tok: 1.046 | Train PPL Tok: 2.848\n",
"\t Val. Loss Tok: 2.538 | Val. PPL Tok: 12.660\n",
"Epoch: 07 | Time: 0m 57s\n",
"\tTrain Loss Sen: 27.144 | Train PPL Sen: 614528985856.751\n",
"\t Val. Loss Sen: 46.614 | Val. PPL Sen: 175556969630909300736.000\n",
"\tTrain Loss Tok: 0.942 | Train PPL Tok: 2.565\n",
"\t Val. Loss Tok: 2.477 | Val. PPL Tok: 11.909\n",
"Epoch: 08 | Time: 0m 57s\n",
"\tTrain Loss Sen: 24.256 | Train PPL Sen: 34200793792.566\n",
"\t Val. Loss Sen: 48.250 | Val. PPL Sen: 900993147005708402688.000\n",
"\tTrain Loss Tok: 0.843 | Train PPL Tok: 2.324\n",
"\t Val. Loss Tok: 2.562 | Val. PPL Tok: 12.958\n",
"Epoch: 09 | Time: 0m 57s\n",
"\tTrain Loss Sen: 22.056 | Train PPL Sen: 3791124842.387\n",
"\t Val. Loss Sen: 48.090 | Val. PPL Sen: 767730435715759079424.000\n",
"\tTrain Loss Tok: 0.760 | Train PPL Tok: 2.139\n",
"\t Val. Loss Tok: 2.551 | Val. PPL Tok: 12.820\n",
"Epoch: 10 | Time: 0m 57s\n",
"\tTrain Loss Sen: 20.619 | Train PPL Sen: 900792843.540\n",
"\t Val. Loss Sen: 49.099 | Val. PPL Sen: 2105042898043442823168.000\n",
"\tTrain Loss Tok: 0.714 | Train PPL Tok: 2.042\n",
"\t Val. Loss Tok: 2.619 | Val. PPL Tok: 13.718\n"
]
}
],
"source": [
"N_EPOCHS = 10\n",
"CLIP = 1\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
" \n",
" start_time = time.time()\n",
" \n",
" train_loss_per_sent, train_loss_per_tok = train(model, train_iterator, optimizer, criterion, CLIP)\n",
" valid_loss_per_sent, valid_loss_per_tok = evaluate(model, valid_iterator, criterion)\n",
" \n",
" end_time = time.time()\n",
" \n",
" epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
" \n",
" if valid_loss_per_sent < best_valid_loss:\n",
" best_valid_loss = valid_loss_per_sent\n",
" torch.save(model.state_dict(), 'tut3-model.pt')\n",
" \n",
" print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n",
" print(f'\\tTrain Loss Sen: {train_loss_per_sent:.3f} | Train PPL Sen: {math.exp(train_loss_per_sent):7.3f}')\n",
" print(f'\\t Val. Loss Sen: {valid_loss_per_sent:.3f} | Val. PPL Sen: {math.exp(valid_loss_per_sent):7.3f}')\n",
" print(f'\\tTrain Loss Tok: {train_loss_per_tok:.3f} | Train PPL Tok: {math.exp(train_loss_per_tok):7.3f}')\n",
" print(f'\\t Val. Loss Tok: {valid_loss_per_tok:.3f} | Val. PPL Tok: {math.exp(valid_loss_per_tok):7.3f}')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| Test Loss: 46.942 | Test PPL: 243490446760497512448.000 |\n",
"| Test Loss: 2.529 | Test PPL: 12.541 |\n"
]
}
],
"source": [
"model.load_state_dict(torch.load('tut3-model.pt'))\n",
"\n",
"test_loss_per_sent, test_loss_per_tok = evaluate(model, test_iterator, criterion)\n",
"\n",
"print(f'| Test Loss: {test_loss_per_sent:.3f} | Test PPL: {math.exp(test_loss_per_sent):7.3f} |')\n",
"print(f'| Test Loss: {test_loss_per_tok:.3f} | Test PPL: {math.exp(test_loss_per_tok):7.3f} |')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}