English Sentence Seq2Seq with Attention Reconstruction
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "English Setentence Seq2Seq with Attention Reconstruction",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/bentrevett/6986196132f253cb48de88e90605901f/english-setentence-seq2seq-with-attention-reconstruction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "HEyZEOBHDzDq",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 453
},
"outputId": "91338372-635c-4fb1-c95e-202abf56b68d"
},
"source": [
"!pip install torchtext --upgrade\n",
"!python -m spacy download en\n",
"!python -m spacy download de"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already up-to-date: torchtext in /usr/local/lib/python3.6/dist-packages (0.4.0)\n",
"Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.6/dist-packages (from torchtext) (2.21.0)\n",
"Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.12.0)\n",
"Requirement already satisfied, skipping upgrade: torch in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.3.1)\n",
"Requirement already satisfied, skipping upgrade: tqdm in /usr/local/lib/python3.6/dist-packages (from torchtext) (4.28.1)\n",
"Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.17.4)\n",
"Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (2019.11.28)\n",
"Requirement already satisfied, skipping upgrade: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (3.0.4)\n",
"Requirement already satisfied, skipping upgrade: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (1.24.3)\n",
"Requirement already satisfied, skipping upgrade: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (2.8)\n",
"Requirement already satisfied: en_core_web_sm==2.1.0 from https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz#egg=en_core_web_sm==2.1.0 in /usr/local/lib/python3.6/dist-packages (2.1.0)\n",
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
"You can now load the model via spacy.load('en_core_web_sm')\n",
"\u001b[38;5;2m✔ Linking successful\u001b[0m\n",
"/usr/local/lib/python3.6/dist-packages/en_core_web_sm -->\n",
"/usr/local/lib/python3.6/dist-packages/spacy/data/en\n",
"You can now load the model via spacy.load('en')\n",
"Requirement already satisfied: de_core_news_sm==2.1.0 from https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.1.0/de_core_news_sm-2.1.0.tar.gz#egg=de_core_news_sm==2.1.0 in /usr/local/lib/python3.6/dist-packages (2.1.0)\n",
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
"You can now load the model via spacy.load('de_core_news_sm')\n",
"\u001b[38;5;2m✔ Linking successful\u001b[0m\n",
"/usr/local/lib/python3.6/dist-packages/de_core_news_sm -->\n",
"/usr/local/lib/python3.6/dist-packages/spacy/data/de\n",
"You can now load the model via spacy.load('de')\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bdcqSLmkE29b",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"\n",
"import torchtext\n",
"from torchtext.datasets import TranslationDataset, Multi30k\n",
"from torchtext.data import Field, BucketIterator\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import spacy\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"import random\n",
"import math\n",
"import time"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Bcbni4e_DtZc",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"outputId": "2c52d3d6-7599-4609-f82a-e137bf9eb513"
},
"source": [
"print(torch.__version__)\n",
"print(torchtext.__version__)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"1.3.1\n",
"0.4.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "9B-IFbWkDw4z",
"colab_type": "code",
"colab": {}
},
"source": [
"SEED = 1234\n",
"\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"torch.cuda.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "57f0vB7SEEj5",
"colab_type": "code",
"colab": {}
},
"source": [
"spacy_de = spacy.load('de')\n",
"spacy_en = spacy.load('en')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LJO6Wh19EGiE",
"colab_type": "code",
"colab": {}
},
"source": [
"def tokenize_de(text):\n",
" return [tok.text for tok in spacy_de.tokenizer(text)]\n",
"\n",
"def tokenize_en(text):\n",
" return [tok.text for tok in spacy_en.tokenizer(text)]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8bJiommNEI_C",
"colab_type": "code",
"colab": {}
},
"source": [
"SRC = Field(tokenize = tokenize_de, \n",
" init_token = '<sos>', \n",
" eos_token = '<eos>', \n",
" lower = True)\n",
"\n",
"TRG = Field(tokenize = tokenize_en, \n",
" init_token = '<sos>', \n",
" eos_token = '<eos>', \n",
" lower = True)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "FEVOYxoZELcB",
"colab_type": "code",
"colab": {}
},
"source": [
"train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), \n",
" fields = (SRC, TRG))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "h_j8u_8uEUNX",
"colab_type": "code",
"colab": {}
},
"source": [
"SRC.build_vocab(train_data, min_freq = 2)\n",
"TRG.build_vocab(train_data, min_freq = 2)"
],
"execution_count": 0,
"outputs": []
},
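{
"cell_type": "code",
"metadata": {},
"source": [
"#optional sanity check (sketch): print the vocabulary sizes built above;\n",
"#these determine INPUT_DIM and OUTPUT_DIM further down\n",
"print(f'Unique tokens in source (de) vocabulary: {len(SRC.vocab)}')\n",
"print(f'Unique tokens in target (en) vocabulary: {len(TRG.vocab)}')"
],
"execution_count": 0,
"outputs": []
},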
{
"cell_type": "code",
"metadata": {
"id": "09z4CZM5FCgI",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "9e6e7bc4-ffa7-41c7-ce35-2ad9a5fa978f"
},
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"print(device)"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": [
"cuda\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "q0L-rxa8FYE1",
"colab_type": "code",
"colab": {}
},
"source": [
"BATCH_SIZE = 128\n",
"\n",
"train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n",
" (train_data, valid_data, test_data), \n",
" batch_size = BATCH_SIZE,\n",
" device = device)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "OjVNO6HsFF9Y",
"colab_type": "code",
"colab": {}
},
"source": [
"class Encoder(nn.Module):\n",
" def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):\n",
" super().__init__()\n",
" \n",
" self.embedding = nn.Embedding(input_dim, emb_dim)\n",
" \n",
" self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)\n",
" \n",
" self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, src):\n",
" \n",
" embedded = self.dropout(self.embedding(src))\n",
" \n",
" outputs, hidden = self.rnn(embedded)\n",
" \n",
" hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))\n",
" \n",
" return outputs, hidden"
],
"execution_count": 0,
"outputs": []
},
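{
"cell_type": "code",
"metadata": {},
"source": [
"#optional sanity check (minimal sketch with dummy sizes, not used by the training code below):\n",
"#the bidirectional encoder returns outputs of shape [src_len, batch_size, enc_hid_dim * 2]\n",
"#and an initial decoder hidden state of shape [batch_size, dec_hid_dim]\n",
"_enc = Encoder(input_dim = 10, emb_dim = 3, enc_hid_dim = 4, dec_hid_dim = 6, dropout = 0.0)\n",
"_out, _hid = _enc(torch.zeros(5, 2, dtype = torch.long))\n",
"print(_out.shape, _hid.shape)"
],
"execution_count": 0,
"outputs": []
},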
{
"cell_type": "code",
"metadata": {
"id": "JaavhOMGFcnp",
"colab_type": "code",
"colab": {}
},
"source": [
"class Attention(nn.Module):\n",
" def __init__(self, enc_hid_dim, dec_hid_dim):\n",
" super().__init__()\n",
" \n",
" self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)\n",
" self.v = nn.Parameter(torch.rand(dec_hid_dim))\n",
" \n",
" def forward(self, hidden, encoder_outputs):\n",
" \n",
" batch_size = encoder_outputs.shape[1]\n",
" src_len = encoder_outputs.shape[0]\n",
" hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)\n",
" \n",
" encoder_outputs = encoder_outputs.permute(1, 0, 2)\n",
" \n",
" energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) \n",
" \n",
" energy = energy.permute(0, 2, 1)\n",
" \n",
" v = self.v.repeat(batch_size, 1).unsqueeze(1)\n",
" \n",
" attention = torch.bmm(v, energy).squeeze(1)\n",
" \n",
" return F.softmax(attention, dim=1)"
],
"execution_count": 0,
"outputs": []
},
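{
"cell_type": "code",
"metadata": {},
"source": [
"#optional sanity check (minimal sketch with dummy sizes): the attention weights should have\n",
"#shape [batch_size, src_len] and sum to 1 over the source positions\n",
"_attn = Attention(enc_hid_dim = 4, dec_hid_dim = 6)\n",
"_a = _attn(torch.zeros(2, 6), torch.zeros(5, 2, 8))\n",
"print(_a.shape, _a.sum(dim = 1))"
],
"execution_count": 0,
"outputs": []
},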
{
"cell_type": "code",
"metadata": {
"id": "qvG_Niv9Fk0v",
"colab_type": "code",
"colab": {}
},
"source": [
"class Decoder(nn.Module):\n",
" def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):\n",
" super().__init__()\n",
"\n",
" self.output_dim = output_dim\n",
" self.attention = attention\n",
" \n",
" self.embedding = nn.Embedding(output_dim, emb_dim)\n",
" \n",
" self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)\n",
" \n",
" self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, input, hidden, encoder_outputs):\n",
" \n",
" input = input.unsqueeze(0)\n",
" \n",
" embedded = self.dropout(self.embedding(input))\n",
" \n",
" a = self.attention(hidden, encoder_outputs)\n",
" \n",
" a = a.unsqueeze(1)\n",
" \n",
" encoder_outputs = encoder_outputs.permute(1, 0, 2)\n",
" \n",
" weighted = torch.bmm(a, encoder_outputs)\n",
"\n",
" weighted = weighted.permute(1, 0, 2)\n",
"\n",
" rnn_input = torch.cat((embedded, weighted), dim = 2)\n",
"\n",
" output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))\n",
" \n",
" assert (output == hidden).all()\n",
" \n",
" embedded = embedded.squeeze(0)\n",
" output = output.squeeze(0)\n",
" weighted = weighted.squeeze(0)\n",
" \n",
" prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))\n",
"\n",
" return prediction, hidden.squeeze(0)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "12Ge__ynFrAK",
"colab_type": "code",
"colab": {}
},
"source": [
"class Seq2Seq(nn.Module):\n",
" def __init__(self, encoder, decoder, device):\n",
" super().__init__()\n",
" \n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" self.device = device\n",
" \n",
" def forward(self, src, trg, teacher_forcing_ratio = 0.5):\n",
" \n",
" batch_size = src.shape[1]\n",
" trg_len = trg.shape[0]\n",
" trg_vocab_size = self.decoder.output_dim\n",
" \n",
" outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)\n",
" \n",
" encoder_outputs, hidden = self.encoder(src)\n",
" \n",
" input = trg[0,:]\n",
" \n",
" for t in range(1, trg_len):\n",
" \n",
" output, hidden = self.decoder(input, hidden, encoder_outputs)\n",
" \n",
" outputs[t] = output\n",
" \n",
" teacher_force = random.random() < teacher_forcing_ratio\n",
" \n",
" top1 = output.argmax(1) \n",
"\n",
" input = trg[t] if teacher_force else top1\n",
"\n",
" return outputs"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BJ2Fbin6Fvkd",
"colab_type": "code",
"colab": {}
},
"source": [
"INPUT_DIM = len(TRG.vocab)\n",
"OUTPUT_DIM = len(TRG.vocab)\n",
"ENC_EMB_DIM = 256\n",
"DEC_EMB_DIM = 256\n",
"ENC_HID_DIM = 512\n",
"DEC_HID_DIM = 512\n",
"ENC_DROPOUT = 0.5\n",
"DEC_DROPOUT = 0.5\n",
"\n",
"attn = Attention(ENC_HID_DIM, DEC_HID_DIM)\n",
"enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)\n",
"dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)\n",
"\n",
"model = Seq2Seq(enc, dec, device).to(device)"
],
"execution_count": 0,
"outputs": []
},
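{
"cell_type": "code",
"metadata": {},
"source": [
"#optional sanity check (sketch): a forward pass on random token ids should return\n",
"#a tensor of shape [trg_len, batch_size, OUTPUT_DIM]; this only exercises the tensor plumbing\n",
"_src = torch.randint(0, INPUT_DIM, (7, 2)).to(device)\n",
"_trg = torch.randint(0, OUTPUT_DIM, (9, 2)).to(device)\n",
"with torch.no_grad():\n",
"    print(model(_src, _trg, 0).shape)"
],
"execution_count": 0,
"outputs": []
},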
{
"cell_type": "code",
"metadata": {
"id": "s_JqsrVcF0MX",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 312
},
"outputId": "cd35e228-7f11-48d8-a6ef-3a0534973a6f"
},
"source": [
"def init_weights(m):\n",
" for name, param in m.named_parameters():\n",
" if 'weight' in name:\n",
" nn.init.normal_(param.data, mean=0, std=0.01)\n",
" else:\n",
" nn.init.constant_(param.data, 0)\n",
" \n",
"model.apply(init_weights)"
],
"execution_count": 17,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Seq2Seq(\n",
" (encoder): Encoder(\n",
" (embedding): Embedding(5893, 256)\n",
" (rnn): GRU(256, 512, bidirectional=True)\n",
" (fc): Linear(in_features=1024, out_features=512, bias=True)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" )\n",
" (decoder): Decoder(\n",
" (attention): Attention(\n",
" (attn): Linear(in_features=1536, out_features=512, bias=True)\n",
" )\n",
" (embedding): Embedding(5893, 256)\n",
" (rnn): GRU(1280, 512)\n",
" (fc_out): Linear(in_features=1792, out_features=5893, bias=True)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" )\n",
")"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DtumAJbeF2fW",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "1a88d670-1ac9-47c8-db69-387a1fd5500a"
},
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
],
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"text": [
"The model has 20,016,645 trainable parameters\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "RJ7vE72GF6v-",
"colab_type": "code",
"colab": {}
},
"source": [
"optimizer = optim.Adam(model.parameters())"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qP7PwgF_F8uu",
"colab_type": "code",
"colab": {}
},
"source": [
"TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]\n",
"\n",
"criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1fYif7TLF-zq",
"colab_type": "code",
"colab": {}
},
"source": [
"def train(model, iterator, optimizer, criterion, clip):\n",
" \n",
" model.train()\n",
" \n",
" epoch_loss = 0\n",
" \n",
" for i, batch in enumerate(tqdm(iterator)):\n",
" \n",
" src = batch.trg\n",
" trg = batch.trg\n",
" \n",
" optimizer.zero_grad()\n",
" \n",
" output = model(src, trg)\n",
" \n",
" output_dim = output.shape[-1]\n",
" \n",
" output = output[1:].view(-1, output_dim)\n",
" trg = trg[1:].view(-1)\n",
" \n",
" loss = criterion(output, trg)\n",
" \n",
" loss.backward()\n",
" \n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
" \n",
" optimizer.step()\n",
" \n",
" epoch_loss += loss.item()\n",
" \n",
" return epoch_loss / len(iterator)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oCSh8Mp9GEhm",
"colab_type": "code",
"colab": {}
},
"source": [
"def evaluate(model, iterator, criterion):\n",
" \n",
" model.eval()\n",
" \n",
" epoch_loss = 0\n",
" \n",
" with torch.no_grad():\n",
" \n",
" for i, batch in enumerate(tqdm(iterator)):\n",
"\n",
" src = batch.trg\n",
" trg = batch.trg\n",
"\n",
" output = model(src, trg, 0)\n",
"\n",
" output_dim = output.shape[-1]\n",
" \n",
" output = output[1:].view(-1, output_dim)\n",
" trg = trg[1:].view(-1)\n",
"\n",
" loss = criterion(output, trg)\n",
"\n",
" epoch_loss += loss.item()\n",
" \n",
" return epoch_loss / len(iterator)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Hp1VOPuKGHmg",
"colab_type": "code",
"colab": {}
},
"source": [
"def epoch_time(start_time, end_time):\n",
" elapsed_time = end_time - start_time\n",
" elapsed_mins = int(elapsed_time / 60)\n",
" elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
" return elapsed_mins, elapsed_secs"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5j2XZb52GJxK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 537
},
"outputId": "9b22d661-c42a-4cd0-e62f-4be449a72ded"
},
"source": [
"N_EPOCHS = 5\n",
"CLIP = 1\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
" \n",
" start_time = time.time()\n",
" \n",
" train_loss = train(model, train_iterator, optimizer, criterion, CLIP)\n",
" valid_loss = evaluate(model, valid_iterator, criterion)\n",
" \n",
" end_time = time.time()\n",
" \n",
" epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
" \n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), 'reconstruction-model.pt')\n",
" \n",
" print()\n",
" print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n",
" print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')\n",
" print(f'\\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')"
],
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 227/227 [02:00<00:00, 2.00it/s]\n",
"100%|██████████| 8/8 [00:00<00:00, 6.78it/s]\n",
" 0%| | 0/227 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"Epoch: 01 | Time: 2m 0s\n",
"\tTrain Loss: 5.006 | Train PPL: 149.257\n",
"\t Val. Loss: 4.905 | Val. PPL: 134.938\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 227/227 [01:59<00:00, 1.89it/s]\n",
"100%|██████████| 8/8 [00:00<00:00, 6.74it/s]\n",
" 0%| | 0/227 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"Epoch: 02 | Time: 2m 0s\n",
"\tTrain Loss: 4.079 | Train PPL: 59.097\n",
"\t Val. Loss: 4.674 | Val. PPL: 107.150\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 227/227 [01:59<00:00, 1.93it/s]\n",
"100%|██████████| 8/8 [00:00<00:00, 6.76it/s]\n",
" 0%| | 0/227 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"Epoch: 03 | Time: 2m 0s\n",
"\tTrain Loss: 2.399 | Train PPL: 11.016\n",
"\t Val. Loss: 1.395 | Val. PPL: 4.035\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 227/227 [01:59<00:00, 1.89it/s]\n",
"100%|██████████| 8/8 [00:00<00:00, 6.72it/s]\n",
" 0%| | 0/227 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"Epoch: 04 | Time: 2m 0s\n",
"\tTrain Loss: 1.064 | Train PPL: 2.897\n",
"\t Val. Loss: 0.621 | Val. PPL: 1.860\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 227/227 [01:58<00:00, 1.76it/s]\n",
"100%|██████████| 8/8 [00:00<00:00, 6.70it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"Epoch: 05 | Time: 1m 59s\n",
"\tTrain Loss: 0.456 | Train PPL: 1.577\n",
"\t Val. Loss: 0.320 | Val. PPL: 1.377\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QVJieGnZGPeD",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"outputId": "d0f6a70b-0483-4a65-dc76-11c3011988aa"
},
"source": [
"model.load_state_dict(torch.load('reconstruction-model.pt'))\n",
"\n",
"test_loss = evaluate(model, test_iterator, criterion)\n",
"\n",
"print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')"
],
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 8/8 [00:00<00:00, 6.74it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"| Test Loss: 0.357 | Test PPL: 1.430 |\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZSsJlQ1WGeNm",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_embedding_values(model, iterator):\n",
" \n",
" model.eval()\n",
" \n",
" epoch_loss = 0\n",
" \n",
" for batch in iterator:\n",
" \n",
" src = batch.trg\n",
" trg = batch.trg\n",
" \n",
" outputs, hidden = model.encoder(src)\n",
"\n",
" return outputs"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "WUDJP8svNqrP",
"colab": {}
},
"source": [
"embeddings = get_embedding_values(model, test_iterator)"
],
"execution_count": 0,
"outputs": []
},
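{
"cell_type": "code",
"metadata": {},
"source": [
"#the returned tensor holds the encoder outputs for the first test batch,\n",
"#with shape [src_len, batch_size, enc_hid_dim * 2]\n",
"print(embeddings.shape)"
],
"execution_count": 0,
"outputs": []
},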
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "WkqNuDW-NqrY",
"colab": {}
},
"source": [
"embeddings = torch.flatten(embeddings).detach().cpu().numpy()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "wY-lsXc7Nqrd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 279
},
"outputId": "52e152b2-f647-4db1-cea2-2c3772c22fcd"
},
"source": [
"plt.hist(embeddings, bins = 50);\n",
"plt.xlabel('Embedding Values');\n",
"plt.ylabel('N');"
],
"execution_count": 36,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAEGCAYAAACkQqisAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAXiElEQVR4nO3df7hlVX3f8fdHCPir8ktKEdCBOGpQ\nnwjOgzS0xohFwNZBizpG46goNWJiYnwq1DQkGiKmTag2VqWAglpQMSmTgBlHYJrayI9BiAiIjKAy\nBGUiiBIjCH77x15XD3funZm75p575859v57nPGfvtdfe53v2ufd8z1p777VTVUiS1OMR8x2AJGnh\nMolIkrqZRCRJ3UwikqRuJhFJUred5zuAufb4xz++lixZMt9hSNKCcc011/xDVe091bJFl0SWLFnC\nunXr5jsMSVowknxzumV2Z0mSuplEJEndTCKSpG4mEUlSN5OIJKmbSUSS1M0kIknqZhKRJHUziUiS\nui26K9al+bbk5IunLP/G6S+a40ikbWdLRJLUzSQiSepmEpEkdTOJSJK6mUQkSd1MIpKkbiYRSVI3\nrxORtnNeV6LtmS0RSVI3WyLSdmK6Foe0PbMlIknqZhKRJHUziUiSuplEJEndTCKSpG6enSWNiWdb\naTGwJSJJ6mYSkSR1M4lIkrqZRCRJ3UwikqRuJhFJUjeTiCSpm0lEktTNJCJJ6mYSkSR1M4lIkrqN\ndeysJL8NvAEo4HrgdcC+wAXAXsA1wK9V1QNJdgXOA54NfBd4RVV9o23nFOAE4CHgN6tqdSs/Gngf\nsBNwVlWdPs73I21PvPe6tgdja4kk2Q/4TWBZVT2D4Yt+BfBe4IyqejJwD0NyoD3f08rPaPVIcnBb\n7+nA0cD/SLJTkp2ADwDHAAcDr2x1JUlzZNzdWTsDj0qyM/Bo4E7g+cCFbfm5wHFtenmbpy0/Mkla\n+QVVdX9V3QasBw5rj/VVdWtVPcDQulk+5vcjSRoxtiRSVXcA/xX4FkPyuJeh++p7VfVgq7YB2K9N\n7wfc3tZ9sNXfa7R80jrTlW8iyYlJ1iVZt3Hjxm1/c5IkYLzdWXswtAwOBJ4APIahO2rOVdWZVbWs\nqpbtvffe8xGCJO2Qxtmd9QLgtqraWFU/Bv4cOALYvXVvAewP3NGm7wAOAGjLd2M4wP7T8knrTFcu\nSZoj40wi3wIOT/LodmzjSOBG4HLg+FZnJXBRm17V5mnLL6uqauUrkuya5EBgKXAVcDWwNMmBSXZh\nOPi+aozvR5I0ydhO8a2qK5NcCHwJeBC4FjgTuBi4IMkftrKz2ypnAx9Lsh64myEpUFU3JPkUQwJ6\nEDipqh4CSPIWYDXDmV/nVNUN43o/kqRNjfU6kao6FTh1UvGtDGdWTa77I+Bl02znNOC0KcovAS7Z\n9kglST28Yl2S1M0kIknqNtbuLGkxmG74EWkxsCUiSepmEpEkdTOJSJK6mUQkSd08sC7tYLzPiOaS\nLRFJUjeTiCSpm0lEktTNJCJJ6mYSkSR1M4lIkrqZRCRJ3UwikqRuXmwobSVH65U2ZUtEktTNJCJJ\n6mYSkSR1M4lIkrqZRCRJ3UwikqRuJhFJUjeTiCSpm0lEktTNJCJJ6mYSkSR1M4lIkrqZRCRJ3Uwi\nkqRuJhFJUjfvJyJN4n1DpK1nEpEWiemS4zdOf9EcR6Idid1ZkqRuY00iSXZPcmGSrya5Kcm/TLJn\nkjVJbmnPe7S6SfL+JOuTfDnJoSPbWdnq35Jk5Uj5s5Nc39Z5f5KM8/1Ikh5u3C2R9wF/XVVPA34R\nuAk4Gbi0qpYCl7Z5gGOApe1xIvBBgCR7AqcCzwEOA06dSDytzhtH1jt6zO9HkjRibEkkyW7Ac4Gz\nAarqgar6HrAcOLdVOxc4rk0vB86rwRXA7kn2BV4IrKmqu6vqHmANcHRb9riquqKqCjhvZFuSpDkw\nzpbIgcBG4CNJrk1yVpLHAPtU1Z2tzreBfdr0fsDtI+tvaGWbK98wRfkmkpyYZF2SdRs3btzGtyVJ\nmjDOJLIzcCjwwao6BPhHftZ1BUBrQdQYY5h4nTOrallVLdt7773H/XKStGiMM4lsADZU1ZVt/kKG\npPKd1hVFe76rLb8DOGBk/f1b2ebK95+iXJI0R8aWRKrq28DtSZ7aio4EbgRWARNnWK0ELmrTq4DX\ntLO0Dgfubd1eq4GjkuzRDqgfBaxuy76f5PB2VtZrRrYlSZoD477Y8DeATyTZBbgVeB1D4vpUkhOA\nbwIvb3UvAY4F1gM/bHWpqruTvBu4utV7V1Xd3abfDHwUeBTw2faQJM2RsSaRqroOWDbFoiOnqFvA\nSdNs5xzgnCnK1wHP2MYwJUmdvGJdktTNJCJJ6mYSkSR1M4lIkrqZRCRJ3UwikqRuJhFJUjfvbKhF\nyVvgSrPDJCItct42V9vC7ixJUjeTiCSpm0lEktTNJCJJ6mYSkSR1M4lIkrp5iq+kKW3uWhpP/9UE\nWyKSpG6bbYkk+b3NLK6qevcsxyNJWkC21J31j1OUPRp4A7AXYBKRpEVss0mkqv5kYjrJPwPeCrwe\nuAD4k+nWkyQtDls8sJ5kT+BtwKuAc4FDq+qecQcmSdr+bemYyH8BXgqcCTyzqu6bk6gkSQvCls7O\n+h3gCcDvAn+f5Pvt8YMk3x9/eJKk7dmWjol4CrAkaVomCUlSN69Y1w7NOxhK42VLRJLUzSQiSepm\nEpEkdfOYiKQZm+5Yk6P7Lj62RCRJ3UwikqRuJhFJUjeTiCSpm0lEktRt7EkkyU5Jrk3yV23+wCRX\nJlmf5JNJdmnlu7b59W35kpFtnNLKb07ywpHyo1vZ+iQnj/u9SJIebi5O8X0rcBPwuDb/XuCMqrog\nyYeAE4APtud7qurJSVa0eq9IcjCwAng6w4jCn0/ylLatDwD/BtgAXJ1kVVXdOAfvSdsZhzeR5sdY\nWyJJ9gdeBJzV5gM8H7iwVTkXOK5NL2/ztOVHtvrLgQuq6v6qug1YDxzWHuur6taqeoDhbovLx/l+\nJEkPN+7urP8G/EfgJ21+L+B7VfVgm98A7Nem9wNuB2jL7231f1o+aZ3pyjeR5MQk65Ks27hx47a+\nJ0lSM7YkkuTfAndV1TXjeo2tVVVnVtWyqlq29957z3c4krTDGOcxkSOAFyc5FngkwzGR9wG7J9m5\ntTb2B+5o9e8ADgA2JNkZ2A347kj5hNF1piuXJM2BsSWRqjoFOAUgyfOAt1fVq5J8Gjie4RjGSuCi\ntsqqNv/Ftvyyqqokq4D/leRPGQ6sLwWuAgIsTXIgQ/JYAfzquN6PpC1zTK3FZz4GYHwHcEGSPwSu\nBc5u5WcDH0uyHribISlQVTck+RRwI/AgcFJVPQSQ5C3AamAn4JyqumFO34kkLXJzkkSqai2wtk3f\nynBm1eQ6PwJeNs36pwGnTVF+CXDJLIY
qSZoBr1iXJHXzfiKSxs5jJTsuWyKSpG4mEUlSN5OIJKmb\nx0S0oDjQorR9sSUiSepmS0TbJVsc0sJgS0SS1M0kIknqZhKRJHXzmIikeeOV7AufLRFJUjeTiCSp\nm91ZmleeyistbLZEJEndTCKSpG52Z0na7njW1sJhS0SS1M0kIknqZhKRJHUziUiSuplEJEndPDtL\nc8KLCqUdky0RSVI3WyKaVbY4NE5eP7L9sSUiSepmEpEkdTOJSJK6eUxE0oLnsZL5YxJRFw+gSwKT\niCTtUOa6VeYxEUlSN1siM7AY+13tttJCthj/Z+eaLRFJUrexJZEkByS5PMmNSW5I8tZWvmeSNUlu\nac97tPIkeX+S9Um+nOTQkW2tbPVvSbJypPzZSa5v67w/Scb1fiRJmxpnS+RB4Heq6mDgcOCkJAcD\nJwOXVtVS4NI2D3AMsLQ9TgQ+CEPSAU4FngMcBpw6kXhanTeOrHf0GN+PJGmSsR0Tqao7gTvb9A+S\n3ATsBywHnteqnQusBd7Rys+rqgKuSLJ7kn1b3TVVdTdAkjXA0UnWAo+rqita+XnAccBnx/WeJO0Y\nPFYye+bkwHqSJcAhwJXAPi3BAHwb2KdN7wfcPrLahla2ufINU5RP9fonMrRueOITn9j/RnZgHkCX\n1GPsSSTJY4HPAL9VVd8fPWxRVZWkxh1DVZ0JnAmwbNmysb+epIXJFsrMjTWJJPk5hgTyiar681b8\nnST7VtWdrbvqrlZ+B3DAyOr7t7I7+Fn310T52la+/xT1NQ1bG5Jm2zjPzgpwNnBTVf3pyKJVwMQZ\nViuBi0bKX9PO0jocuLd1e60GjkqyRzugfhSwui37fpLD22u9ZmRbkqQ5MM6WyBHArwHXJ7mulf0n\n4HTgU0lOAL4JvLwtuwQ4FlgP/BB4HUBV3Z3k3cDVrd67Jg6yA28GPgo8iuGAugfVJWkOjfPsrC8A\n0123ceQU9Qs4aZptnQOcM0X5OuAZ2xDmDsluK2l2be5/arEfL3HYkwXMZCHNv8V+MN5hTyRJ3WyJ\njNFs/UKxxSEtPDP9v12o3wsmkVmwvXyYkhauhdotZhKZByYdSVtre/++8JiIJKmbSUSS1M0kIknq\nZhKRJHUziUiSuplEJEndTCKSpG4mEUlSN5OIJKmbSUSS1M0kIknqZhKRJHUziUiSuplEJEndTCKS\npG4mEUlSN5OIJKmbSUSS1M0kIknqZhKRJHUziUiSuplEJEndTCKSpG4mEUlSN5OIJKmbSUSS1M0k\nIknqZhKRJHUziUiSuplEJEndFnwSSXJ0kpuTrE9y8nzHI0mLyYJOIkl2Aj4AHAMcDLwyycHzG5Uk\nLR4LOokAhwHrq+rWqnoAuABYPs8xSdKisfN8B7CN9gNuH5nfADxncqUkJwInttn7ktzc+XqPB/6h\nc91xMq6ZMa6ZMa6Z2S7jynu3Ka4nTbdgoSeRrVJVZwJnbut2kqyrqmWzENKsMq6ZMa6ZMa6ZWWxx\nLfTurDuAA0bm929lkqQ5sNCTyNXA0iQHJtkFWAGsmueYJGnRWNDdWVX1YJK3AKuBnYBzquqGMb7k\nNneJjYlxzYxxzYxxzcyiiitVNY7tSpIWgYXenSVJmkcmEUlSN5PIJEleluSGJD9JMu3pcNMNt9IO\n8l/Zyj/ZDvjPRlx7JlmT5Jb2vMcUdX4lyXUjjx8lOa4t+2iS20aWPWuu4mr1Hhp57VUj5fO5v56V\n5Ivt8/5ykleMLJvV/bWl4XmS7Nre//q2P5aMLDulld+c5IXbEkdHXG9LcmPbP5cmedLIsik/0zmK\n67VJNo68/htGlq1sn/stSVbOcVxnjMT0tSTfG1k2lv2V5JwkdyX5yjTLk+T9LeYvJzl0ZNm276uq\n8jHyAH4BeCqwFlg2TZ2dgK8DBwG7AH8HHNyWfQpY0aY/BPz6LMX1x8DJbfpk4L1bqL8ncDfw6Db/\nUeD4MeyvrYoLuG+a8nnbX8BTgKVt+gnAncDus72/Nvf3MlLnzcCH2vQK4JNt+uBWf1fgwLadneYw\nrl8Z+Rv69Ym4NveZzlFcrwX+bIp19wRubc97tOk95iquSfV/g+Fkn3Hvr+cChwJfmWb5scBngQCH\nA1fO5r6yJTJJVd1UVVu6on3K4VaSBHg+cGGrdy5w3CyFtrxtb2u3ezzw2ar64Sy9/nRmGtdPzff+\nqqqvVdUtbfrvgbuAvWfp9UdtzfA8o/FeCBzZ9s9y4IKqur+qbgPWt+3NSVxVdfnI39AVDNdijdu2\nDGf0QmBNVd1dVfcAa4Cj5ymuVwLnz9JrT6uq/obhB+N0lgPn1eAKYPck+zJL+8ok0meq4Vb2A/YC\nvldVD04qnw37VNWdbfrbwD5bqL+CTf+AT2vN2TOS7DrHcT0yybokV0x0sbEd7a8khzH8uvz6SPFs\n7a/p/l6mrNP2x70M+2dr1h1nXKNOYPhFO2Gqz3Qu4/r37fO5MMnERcfbxf5q3X4HApeNFI9rf23J\ndHHPyr5a0NeJ9EryeeBfTLHonVV10VzHM2FzcY3OVFUlmfbc7PYr45kM189MOIXhy3QXhvPF3wG8\naw7jelJV3ZHkIOCyJNczfFF2m+X99TFgZVX9pBV3768dUZJXA8uAXx4p3uQzraqvT72FWfeXwPlV\ndX+S/8DQinv+HL321lgBXFhVD42Uzef+GptFmUSq6gXbuInphlv5LkNTcef2a3JGw7BsLq4k30my\nb1Xd2b707trMpl4O/EVV/Xhk2xO/yu9P8hHg7XMZV1Xd0Z5vTbIWOAT4DPO8v5I8DriY4QfEFSPb\n7t5fU9ia4Xkm6mxIsjOwG8Pf0ziH9tmqbSd5AUNi/uWqun+ifJrPdDa+FLcYV1V9d2T2LIZjYBPr\nPm/SumtnIaatimvECuCk0YIx7q8tmS7uWdlXdmf1mXK4lRqOVl3OcDwCYCUwWy2bVW17W7PdTfpi\n2xfpxHGI44Apz+QYR1xJ9pjoDkryeOAI4Mb53l/ts/sLhv7iCyctm839tTXD84zGezxwWds/q4AV\nGc7eOhBYCly1DbHMKK4khwAfBl5cVXeNlE/5mc5hXPuOzL4YuKlNrwaOavHtARzFw1vkY42rxfY0\nhgPVXxwpG+f+2pJVwGvaWVqHA/e2H0mzs6/GcbbAQn4AL2HoG7wf+A6wupU/AbhkpN6xwNcYfkm8\nc6T8IIZ/8vXAp4FdZymuvYBLgVuAzwN7tvJlwFkj9ZYw/MJ4xKT1LwOuZ/gy/Djw2LmKC/il9tp/\n155P2B72F/Bq4MfAdSOPZ41jf03198LQPfbiNv3I9v7Xt/1x0Mi672zr3QwcM8t/71uK6/Pt/2Bi\n/6za0mc6R3G9B7ihvf7lwNNG1n1924/rgdfNZVxt/veB0yetN7b9xfCD8c72t7yB4djVm4A3teVh\nuHnf19trLxtZd5v3lcOeSJK62Z0lSepmEpEkdTOJSJK6mUQkSd1MIpKkbiYR7XAmjZZ6XaYYbXUz
\n6z4vyV9tw2tPu36Sb7RrBEjyt72vMbK9Ryf5brtgcrT8f2dkROKZxCjN1KK8Yl07vH+qqlkZ6n5c\nquqXZmEbP0yymuHapnMBkuwG/CvgV7d1+9LWsCWiRaO1BN7TWifrkhyaZHWSryd500jVxyW5OMN9\nIz6U5BFt/aMy3H/kS0k+neSxrfzoJF9N8iXgpSOvt1eSz2W4X8lZDBd9TSy7rz0/L8naDIMIfjXJ\nJ9pV8iQ5tpVdk+F+EFO1Hs5nuHJ6wksYLpD9YZLDWrzXJvnbJE+dYp/8fpK3j8x/Je1eJkleneSq\ntr8+nGSn9vhoq3d9kt+e4cegHYxJRDuiR03qzhrt2vlWa6X8X9o9QxjusfAHI3UOY7gXxMHAzwMv\nbd1Qvwu8oKoOBdYBb0vySOB/Av8OeDYPHxDyVOALVfV0huFVnjhNvIcAv9Ve7yDgiLbdDzNcof5s\nph+ifjVwaJK92vzo6M1fBf51VR0C/B7wR9NsYxNJfgF4BXBE218PAa8CngXsV1XPqKpnAh/Z2m1q\nx2R3lnZEm+vOmhjr6HqGoUx+APwgyf1Jdm/LrqqqWwGSnM/QPfQjhi/5/9caCrswjI30NOC2avcl\nSfJx4MS2nefSWiZVdXGSe6aJ6aqq2tDWv45h6Jr7gFtruIcIDInhxMkrVtUDGe6Sd3ySzzAkpInx\nj3YDzk2yFCjg56Z5/akcyZAUr27v91EMg1j+JXBQkv/OMHDl52awTe2ATCJabCZGof3JyPTE/MT/\nw+SxgIqhK2pNVb1ydEFm5zbDo3E8xMz/L88H/jNDjBfVz0ZvfjdweVW9pHVRrZ1i3Qd5eI/EI9tz\ngHOr6pTJKyT5RYYbGr2JYcTo188wXu1A7M6SNnVYG6n1EQxdOl9guKvfEUmeDJDkMUmewtBltCTJ\nz7d1R5PM39AOcCc5hmFk1611M8Mv/iVtftqzrRiSw1KGocdHR2/ejZ8NVf7aadb9BsOtVclw7+0D\nW/mlDK2bf96W7ZnkSa1b7xFV9RmG7r1DN92kFhOTiHZEk4+JnD7D9a8G/oxhePHbGO7NspHhi/j8\nJF+mdWVV1Y8YupkubgfWR+9b8gfAc5PcwNCt9a2tDaCq/onhvut/neQa4AdMcxOvGm6kdSHDyMX/\nZ2TRHwPvSXIt07duPgPs2WJ8C8MItVTVjQxJ4nPt/a4B9mW4893a1u32cYabd2kRcxRfaTuV5LFV\ndV87W+sDwC1VdcZ8xyWNsiUibb/e2H7x38DQNfXheY5H2oQtEUlSN1sikqRuJhFJUjeTiCSpm0lE\nktTNJCJJ6vb/ATxo7o4vLAGqAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
}
]
}