Created
July 22, 2020 05:50
-
-
Save Dipeshpal/5d55e7016c5f4fcda4e5b5c459f8436d to your computer and use it in GitHub Desktop.
prediction (1).ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "prediction (1).ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"hide_input": false, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.7" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/Dipeshpal/5d55e7016c5f4fcda4e5b5c459f8436d/prediction-1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "weuU8H1Dy-y7", | |
"colab": {} | |
}, | |
"source": [ | |
"# from google.colab import drive\n", | |
"# drive.mount('/content/drive')" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "8gSy_4TBzJOT", | |
"colab": {} | |
}, | |
"source": [ | |
"import os\n", | |
"\n", | |
"# os.chdir(\"/content/drive/My Drive/Projects/lang_translator_pytorch_attention\")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "Eu-N7bKKy5pq", | |
"colab": {} | |
}, | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.optim as optim\n", | |
"\n", | |
"import torchtext\n", | |
"from torchtext import data\n", | |
"# from torchtext.datasets import Multi30k\n", | |
"from torchtext.data import Field, BucketIterator\n", | |
"\n", | |
"# import matplotlib.pyplot as plt\n", | |
"# import matplotlib.ticker as ticker\n", | |
"\n", | |
"import spacy\n", | |
"import numpy as np\n", | |
"\n", | |
"import random\n", | |
"import math\n", | |
"import time\n", | |
"import os" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "C-ejYsQay5pt", | |
"colab": {} | |
}, | |
"source": [ | |
"SEED = 1234\n", | |
"\n", | |
"random.seed(SEED)\n", | |
"np.random.seed(SEED)\n", | |
"torch.manual_seed(SEED)\n", | |
"torch.cuda.manual_seed(SEED)\n", | |
"torch.backends.cudnn.deterministic = True" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "TGMzD2lXEqGo", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Basic setup: data paths, language pair and hyperparameters" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "evhaQnH2y5pw", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 159 | |
}, | |
"outputId": "7199c0c8-4aaa-4db6-8cba-85f762f92c97" | |
}, | |
"source": [ | |
"BASE = \"dataset/\"\n", | |
"\n", | |
"folder_name = input(\"Enter folder name (Ex: en-es | None for en-de): \")\n", | |
"\n", | |
"MY_PATH = f\"{BASE}{folder_name}\"\n", | |
"\n", | |
"if MY_PATH == None:\n", | |
" lang1_name = \".en\"\n", | |
" lang2_name = \".de\"\n", | |
"else:\n", | |
" lang1_name = input(\"Enter Lang1 Name (Ex: en): \")\n", | |
" lang1_name = \".\"+lang1_name\n", | |
" lang2_name = input(\"Enter Lang2 Name (Ex: es): \")\n", | |
" lang2_name = \".\"+lang2_name\n", | |
" \n", | |
"BATCH_SIZE = int(input(\"Batch Size: (128 Recommended): \"))\n", | |
"N_EPOCHS = int(input(\"Number of Epochs (50 Recommended): \"))\n", | |
"\n", | |
"tensor_print = input(\"Do you want to print the tensor summary (Type yes or No): \")\n", | |
"summary = input(\"Do you want to print model summary (Type yes or No): \")\n", | |
"\n", | |
"model_name = f\"models/{lang1_name[1:]}-{lang2_name[1:]}/\"+lang1_name[1:]+'-'+lang2_name[1:]\n", | |
"print(\"Model Name: \", model_name)\n", | |
"# print(f\"Language 1: {lang1_name}, Language 2: {lang2_name}\")" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Enter folder name (Ex: en-es | None for en-de): en-fr\n", | |
"Enter Lang1 Name (Ex: en): en\n", | |
"Enter Lang2 Name (Ex: es): fr\n", | |
"Batch Size: (128 Recommended): 128\n", | |
"Number of Epochs (50 Recommended): 1\n", | |
"Do you want to print the tensor summary (Type yes or No): yes\n", | |
"Do you want to print model summary (Type yes or No): yes\n", | |
"Model Name: models/en-fr/en-fr\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "KulkumpFy5pz", | |
"colab": {} | |
}, | |
"source": [ | |
"def create_dirs(dir_path):\n", | |
" try:\n", | |
" os.mkdir(dir_path)\n", | |
" except:\n", | |
" pass\n", | |
"\n", | |
"create_dirs(\"models/\")\n", | |
"create_dirs(f\"models/{lang1_name[1:]}-{lang2_name[1:]}/\")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "o-Uh0sxSy5p1", | |
"colab": {} | |
}, | |
"source": [ | |
"if lang2_name == \".hi\":\n", | |
" print(\"Hindi Tokenizater Installing\")\n", | |
" os.system(\"pip install inltk\")\n", | |
" from inltk.inltk import setup\n", | |
" from inltk.inltk import tokenize\n", | |
" setup('hi')" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "SgynozXnzUIr", | |
"colab": {} | |
}, | |
"source": [ | |
"# !python -m spacy download en\n", | |
"# !python -m spacy download fr" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "e5qRUlEny5p3", | |
"colab": {} | |
}, | |
"source": [ | |
"# https://spacy.io/usage/models#languages\n", | |
"spacy_lang1 = spacy.load(lang1_name[1:])\n", | |
"\n", | |
"# if lang2_name != \".hi\":\n", | |
"# print(f\"Second Language is {lang2_name[1:]}\")\n", | |
"# # os.system(f\"python -m spacy download {lang2_name[1:]}\")\n", | |
"spacy_lang2 = spacy.load(lang2_name[1:])" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "-K5P4zvQy5p5", | |
"colab": {} | |
}, | |
"source": [ | |
"def tokenize_lang1(text):\n", | |
" return [tok.text for tok in spacy_lang1.tokenizer(text)]\n", | |
"\n", | |
"def tokenize_lang2(text):\n", | |
" if lang2_name == \".hi\":\n", | |
" return tokenize(text, \"hi\")\n", | |
" else: \n", | |
" return [tok.text for tok in spacy_lang2.tokenizer(text)]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "VVtY7iiNEqHG", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Sanity-check the tokenizers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "2i8GsA4Fy5p9", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 52 | |
}, | |
"outputId": "594d9cc4-3af6-4471-9466-c8894bce1d10" | |
}, | |
"source": [ | |
"if lang2_name == \".hi\":\n", | |
" print(\"Your language is Hindi, Just checking Tokenization\")\n", | |
" print(tokenize_lang2(\"प्राचीन काल में विक्रमादित्य नाम के एक आदर्श राजा हुआ करते थे।\"))\n", | |
"else:\n", | |
" print(\"Your primary language is English, Just checking Tokenization\")\n", | |
" print(tokenize_lang2(\"You language is English, Just checking Tokenization\"))" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Your primary language is English, Just checking Tokenization\n", | |
"['You', 'language', 'is', 'English', ',', 'Just', 'checking', 'Tokenization']\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "7H1WRW0Fy5qA", | |
"colab": {} | |
}, | |
"source": [ | |
"SRC = Field(tokenize = tokenize_lang1, \n", | |
" init_token = '<sos>', \n", | |
" eos_token = '<eos>', \n", | |
" lower = True, \n", | |
" batch_first = True)\n", | |
"\n", | |
"TRG = Field(tokenize = tokenize_lang2, \n", | |
" init_token = '<sos>', \n", | |
" eos_token = '<eos>', \n", | |
" lower = True, \n", | |
" batch_first = True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "qfZ-pcGAy5qD" | |
}, | |
"source": [ | |
"# Customized Multi30k dataset loader" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"code_folding": [ | |
7, | |
51 | |
], | |
"colab_type": "code", | |
"id": "xJDOXLyZy5qD", | |
"colab": {} | |
}, | |
"source": [ | |
"import os\n", | |
"import xml.etree.ElementTree as ET\n", | |
"import glob\n", | |
"import io\n", | |
"import codecs\n", | |
"\n", | |
"\n", | |
"class TranslationDataset(data.Dataset):\n", | |
" \n", | |
" @staticmethod\n", | |
" def sort_key(ex):\n", | |
" return data.interleave_keys(len(ex.src), len(ex.trg))\n", | |
"\n", | |
" def __init__(self, path, exts, fields, **kwargs):\n", | |
" \n", | |
" if not isinstance(fields[0], (tuple, list)):\n", | |
" fields = [('src', fields[0]), ('trg', fields[1])]\n", | |
"\n", | |
" src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)\n", | |
"\n", | |
" examples = []\n", | |
" with io.open(src_path, mode='r', encoding='utf-8') as src_file, \\\n", | |
" io.open(trg_path, mode='r', encoding='utf-8') as trg_file:\n", | |
" for src_line, trg_line in zip(src_file, trg_file):\n", | |
" src_line, trg_line = src_line.strip(), trg_line.strip()\n", | |
" if src_line != '' and trg_line != '':\n", | |
" examples.append(data.Example.fromlist(\n", | |
" [src_line, trg_line], fields))\n", | |
"\n", | |
" super(TranslationDataset, self).__init__(examples, fields, **kwargs)\n", | |
"\n", | |
"\n", | |
" @classmethod\n", | |
" def splits(cls, exts, fields, path=None, root='dataset',\n", | |
" train='train', validation='val', test='test', **kwargs):\n", | |
" \n", | |
" if path is None or path is \"en-de\":\n", | |
" path = cls.download(root)\n", | |
"\n", | |
" train_data = None if train is None else cls(\n", | |
" os.path.join(path, train), exts, fields, **kwargs)\n", | |
" val_data = None if validation is None else cls(\n", | |
" os.path.join(path, validation), exts, fields, **kwargs)\n", | |
" test_data = None if test is None else cls(\n", | |
" os.path.join(path, test), exts, fields, **kwargs)\n", | |
" return tuple(d for d in (train_data, val_data, test_data)\n", | |
" if d is not None)\n", | |
" \n", | |
" \n", | |
"\n", | |
"\n", | |
"class CustomMulti30(TranslationDataset):\n", | |
" \n", | |
" urls = ['http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',\n", | |
" 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',\n", | |
" 'http://www.quest.dcs.shef.ac.uk/'\n", | |
" 'wmt17_files_mmt/mmt_task1_test2016.tar.gz']\n", | |
" \n", | |
" \n", | |
" name = \"en-de\"\n", | |
" dirname = ''\n", | |
"\n", | |
" @classmethod\n", | |
" def splits(cls, exts, fields, root='dataset',\n", | |
" train='train', validation='val', test='test', **kwargs):\n", | |
"\n", | |
" if 'path' not in kwargs:\n", | |
" expected_folder = os.path.join(root, cls.name)\n", | |
" path = expected_folder if os.path.exists(expected_folder) else None\n", | |
" else:\n", | |
" path = kwargs['path']\n", | |
" del kwargs['path']\n", | |
" \n", | |
" if path == None:\n", | |
" train = 'train'\n", | |
" test = 'test2016'\n", | |
" val = 'val'\n", | |
" exts = ('.en', '.de')\n", | |
"\n", | |
" return super(CustomMulti30, cls).splits(\n", | |
" exts, fields, path, root, train, validation, test, **kwargs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "OmioHrSBy5qF", | |
"colab": {} | |
}, | |
"source": [ | |
"train_data, valid_data, test_data = CustomMulti30.splits(exts = (lang1_name, lang2_name), fields = (SRC, TRG), root='dataset', path=MY_PATH,\n", | |
" train='train', validation='val', test='test')" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "0CVjpeivy5qI", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "382f97c5-453b-40a9-8654-3b3a26a8bcdc" | |
}, | |
"source": [ | |
"len(train_data), len(valid_data), len(test_data)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(119144, 38297, 12766)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 14 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "r43EOIqCy5qL", | |
"colab": {} | |
}, | |
"source": [ | |
"SRC.build_vocab(train_data, min_freq = 2)\n", | |
"TRG.build_vocab(train_data, min_freq = 2)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "Dz97f-sfy5qO", | |
"colab": {} | |
}, | |
"source": [ | |
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "XJRv8lwHy5qQ", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "ce9b17c1-a2bd-4416-dbdd-95306ed8a4c2" | |
}, | |
"source": [ | |
"# Display the device that the iterators and the model are placed on.\n", | |
"device" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"device(type='cpu')" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 17 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "bMUcegoTy5qS", | |
"colab": {} | |
}, | |
"source": [ | |
"train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n", | |
" (train_data, valid_data, test_data), \n", | |
" batch_size = BATCH_SIZE,\n", | |
" device = device)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "fpgbgauey5qW" | |
}, | |
"source": [ | |
"# Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"code_folding": [ | |
0, | |
53, | |
95, | |
170, | |
194, | |
255, | |
309 | |
], | |
"colab_type": "code", | |
"id": "3ZNp_tqqy5qW", | |
"colab": {} | |
}, | |
"source": [ | |
"class Encoder(nn.Module):\n", | |
" def __init__(self, \n", | |
" input_dim, \n", | |
" hid_dim, \n", | |
" n_layers, \n", | |
" n_heads, \n", | |
" pf_dim,\n", | |
" dropout, \n", | |
" device,\n", | |
" max_length = 100):\n", | |
" super().__init__()\n", | |
"\n", | |
" self.device = device\n", | |
" \n", | |
" self.tok_embedding = nn.Embedding(input_dim, hid_dim)\n", | |
" self.pos_embedding = nn.Embedding(max_length, hid_dim)\n", | |
" \n", | |
" self.layers = nn.ModuleList([EncoderLayer(hid_dim, \n", | |
" n_heads, \n", | |
" pf_dim,\n", | |
" dropout, \n", | |
" device) \n", | |
" for _ in range(n_layers)])\n", | |
" \n", | |
" self.dropout = nn.Dropout(dropout)\n", | |
" \n", | |
" self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)\n", | |
" \n", | |
" def forward(self, src, src_mask):\n", | |
" \n", | |
" #src = [batch size, src len]\n", | |
" #src_mask = [batch size, src len]\n", | |
" \n", | |
" batch_size = src.shape[0]\n", | |
" src_len = src.shape[1]\n", | |
" \n", | |
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | |
"\n", | |
" pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)\n", | |
" \n", | |
" #pos = [batch size, src len]\n", | |
" \n", | |
" src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))\n", | |
" \n", | |
" #src = [batch size, src len, hid dim]\n", | |
" \n", | |
" for layer in self.layers:\n", | |
" src = layer(src, src_mask)\n", | |
" \n", | |
" #src = [batch size, src len, hid dim]\n", | |
" \n", | |
" return src\n", | |
" \n", | |
" \n", | |
" \n", | |
"class EncoderLayer(nn.Module):\n", | |
" def __init__(self, \n", | |
" hid_dim, \n", | |
" n_heads, \n", | |
" pf_dim, \n", | |
" dropout, \n", | |
" device):\n", | |
" super().__init__()\n", | |
" \n", | |
" self.self_attn_layer_norm = nn.LayerNorm(hid_dim)\n", | |
" self.ff_layer_norm = nn.LayerNorm(hid_dim)\n", | |
" self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)\n", | |
" self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, \n", | |
" pf_dim, \n", | |
" dropout)\n", | |
" self.dropout = nn.Dropout(dropout)\n", | |
" \n", | |
" def forward(self, src, src_mask):\n", | |
" \n", | |
" #src = [batch size, src len, hid dim]\n", | |
" #src_mask = [batch size, src len]\n", | |
" \n", | |
" #self attention\n", | |
" _src, _ = self.self_attention(src, src, src, src_mask)\n", | |
" \n", | |
" #dropout, residual connection and layer norm\n", | |
" src = self.self_attn_layer_norm(src + self.dropout(_src))\n", | |
" \n", | |
" #src = [batch size, src len, hid dim]\n", | |
" \n", | |
" #positionwise feedforward\n", | |
" _src = self.positionwise_feedforward(src)\n", | |
" \n", | |
" #dropout, residual and layer norm\n", | |
" src = self.ff_layer_norm(src + self.dropout(_src))\n", | |
" \n", | |
" #src = [batch size, src len, hid dim]\n", | |
" \n", | |
" return src\n", | |
" \n", | |
" \n", | |
" \n", | |
"class MultiHeadAttentionLayer(nn.Module):\n", | |
" def __init__(self, hid_dim, n_heads, dropout, device):\n", | |
" super().__init__()\n", | |
" \n", | |
" assert hid_dim % n_heads == 0\n", | |
" \n", | |
" self.hid_dim = hid_dim\n", | |
" self.n_heads = n_heads\n", | |
" self.head_dim = hid_dim // n_heads\n", | |
" \n", | |
" self.fc_q = nn.Linear(hid_dim, hid_dim)\n", | |
" self.fc_k = nn.Linear(hid_dim, hid_dim)\n", | |
" self.fc_v = nn.Linear(hid_dim, hid_dim)\n", | |
" \n", | |
" self.fc_o = nn.Linear(hid_dim, hid_dim)\n", | |
" \n", | |
" self.dropout = nn.Dropout(dropout)\n", | |
" \n", | |
" self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)\n", | |
" \n", | |
" def forward(self, query, key, value, mask = None):\n", | |
" \n", | |
" batch_size = query.shape[0]\n", | |
" \n", | |
" #query = [batch size, query len, hid dim]\n", | |
" #key = [batch size, key len, hid dim]\n", | |
" #value = [batch size, value len, hid dim]\n", | |
" \n", | |
" Q = self.fc_q(query)\n", | |
" K = self.fc_k(key)\n", | |
" V = self.fc_v(value)\n", | |
" \n", | |
" #Q = [batch size, query len, hid dim]\n", | |
" #K = [batch size, key len, hid dim]\n", | |
" #V = [batch size, value len, hid dim]\n", | |
" \n", | |
" Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n", | |
" K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n", | |
" V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n", | |
" \n", | |
" #Q = [batch size, n heads, query len, head dim]\n", | |
" #K = [batch size, n heads, key len, head dim]\n", | |
" #V = [batch size, n heads, value len, head dim]\n", | |
" \n", | |
" energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale\n", | |
" \n", | |
" #energy = [batch size, n heads, query len, key len]\n", | |
" \n", | |
" if mask is not None:\n", | |
" energy = energy.masked_fill(mask == 0, -1e10)\n", | |
" \n", | |
" attention = torch.softmax(energy, dim = -1)\n", | |
" \n", | |
" #attention = [batch size, n heads, query len, key len]\n", | |
" \n", | |
" x = torch.matmul(self.dropout(attention), V)\n", | |
" \n", | |
" #x = [batch size, n heads, query len, head dim]\n", | |
" \n", | |
" x = x.permute(0, 2, 1, 3).contiguous()\n", | |
" \n", | |
" #x = [batch size, query len, n heads, head dim]\n", | |
" \n", | |
" x = x.view(batch_size, -1, self.hid_dim)\n", | |
" \n", | |
" #x = [batch size, query len, hid dim]\n", | |
" \n", | |
" x = self.fc_o(x)\n", | |
" \n", | |
" #x = [batch size, query len, hid dim]\n", | |
" \n", | |
" return x, attention\n", | |
" \n", | |
" \n", | |
" \n", | |
"class PositionwiseFeedforwardLayer(nn.Module):\n", | |
" def __init__(self, hid_dim, pf_dim, dropout):\n", | |
" super().__init__()\n", | |
" \n", | |
" self.fc_1 = nn.Linear(hid_dim, pf_dim)\n", | |
" self.fc_2 = nn.Linear(pf_dim, hid_dim)\n", | |
" \n", | |
" self.dropout = nn.Dropout(dropout)\n", | |
" \n", | |
" def forward(self, x):\n", | |
" \n", | |
" #x = [batch size, seq len, hid dim]\n", | |
" \n", | |
" x = self.dropout(torch.relu(self.fc_1(x)))\n", | |
" \n", | |
" #x = [batch size, seq len, pf dim]\n", | |
" \n", | |
" x = self.fc_2(x)\n", | |
" \n", | |
" #x = [batch size, seq len, hid dim]\n", | |
" \n", | |
" return x\n", | |
" \n", | |
" \n", | |
"class Decoder(nn.Module):\n", | |
" def __init__(self, \n", | |
" output_dim, \n", | |
" hid_dim, \n", | |
" n_layers, \n", | |
" n_heads, \n", | |
" pf_dim, \n", | |
" dropout, \n", | |
" device,\n", | |
" max_length = 100):\n", | |
" super().__init__()\n", | |
" \n", | |
" self.device = device\n", | |
" \n", | |
" self.tok_embedding = nn.Embedding(output_dim, hid_dim)\n", | |
" self.pos_embedding = nn.Embedding(max_length, hid_dim)\n", | |
" \n", | |
" self.layers = nn.ModuleList([DecoderLayer(hid_dim, \n", | |
" n_heads, \n", | |
" pf_dim, \n", | |
" dropout, \n", | |
" device)\n", | |
" for _ in range(n_layers)])\n", | |
" \n", | |
" self.fc_out = nn.Linear(hid_dim, output_dim)\n", | |
" \n", | |
" self.dropout = nn.Dropout(dropout)\n", | |
" \n", | |
" self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)\n", | |
" \n", | |
" def forward(self, trg, enc_src, trg_mask, src_mask):\n", | |
" \n", | |
" #trg = [batch size, trg len]\n", | |
" #enc_src = [batch size, src len, hid dim]\n", | |
" #trg_mask = [batch size, trg len]\n", | |
" #src_mask = [batch size, src len]\n", | |
" \n", | |
" batch_size = trg.shape[0]\n", | |
" trg_len = trg.shape[1]\n", | |
" \n", | |
" pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)\n", | |
" \n", | |
" #pos = [batch size, trg len]\n", | |
" \n", | |
" trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))\n", | |
" \n", | |
" #trg = [batch size, trg len, hid dim]\n", | |
" \n", | |
" for layer in self.layers:\n", | |
" trg, attention = layer(trg, enc_src, trg_mask, src_mask)\n", | |
" \n", | |
" #trg = [batch size, trg len, hid dim]\n", | |
" #attention = [batch size, n heads, trg len, src len]\n", | |
" \n", | |
" output = self.fc_out(trg)\n", | |
" \n", | |
" #output = [batch size, trg len, output dim]\n", | |
" \n", | |
" return output, attention\n", | |
" \n", | |
" \n", | |
"class DecoderLayer(nn.Module):\n", | |
" def __init__(self, \n", | |
" hid_dim, \n", | |
" n_heads, \n", | |
" pf_dim, \n", | |
" dropout, \n", | |
" device):\n", | |
" super().__init__()\n", | |
" \n", | |
" self.self_attn_layer_norm = nn.LayerNorm(hid_dim)\n", | |
" self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)\n", | |
" self.ff_layer_norm = nn.LayerNorm(hid_dim)\n", | |
" self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)\n", | |
" self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)\n", | |
" self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, \n", | |
" pf_dim, \n", | |
" dropout)\n", | |
" self.dropout = nn.Dropout(dropout)\n", | |
" \n", | |
" def forward(self, trg, enc_src, trg_mask, src_mask):\n", | |
" \n", | |
" #trg = [batch size, trg len, hid dim]\n", | |
" #enc_src = [batch size, src len, hid dim]\n", | |
" #trg_mask = [batch size, trg len]\n", | |
" #src_mask = [batch size, src len]\n", | |
" \n", | |
" #self attention\n", | |
" _trg, _ = self.self_attention(trg, trg, trg, trg_mask)\n", | |
" \n", | |
" #dropout, residual connection and layer norm\n", | |
" trg = self.self_attn_layer_norm(trg + self.dropout(_trg))\n", | |
" \n", | |
" #trg = [batch size, trg len, hid dim]\n", | |
" \n", | |
" #encoder attention\n", | |
" _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)\n", | |
" \n", | |
" #dropout, residual connection and layer norm\n", | |
" trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))\n", | |
" \n", | |
" #trg = [batch size, trg len, hid dim]\n", | |
" \n", | |
" #positionwise feedforward\n", | |
" _trg = self.positionwise_feedforward(trg)\n", | |
" \n", | |
" #dropout, residual and layer norm\n", | |
" trg = self.ff_layer_norm(trg + self.dropout(_trg))\n", | |
" \n", | |
" #trg = [batch size, trg len, hid dim]\n", | |
" #attention = [batch size, n heads, trg len, src len]\n", | |
" \n", | |
" return trg, attention\n", | |
" \n", | |
" \n", | |
"class Seq2Seq(nn.Module):\n", | |
" def __init__(self, \n", | |
" encoder, \n", | |
" decoder, \n", | |
" src_pad_idx, \n", | |
" trg_pad_idx, \n", | |
" device):\n", | |
" super().__init__()\n", | |
" \n", | |
" self.encoder = encoder\n", | |
" self.decoder = decoder\n", | |
" self.src_pad_idx = src_pad_idx\n", | |
" self.trg_pad_idx = trg_pad_idx\n", | |
" self.device = device\n", | |
" \n", | |
" def make_src_mask(self, src):\n", | |
" \n", | |
" #src = [batch size, src len]\n", | |
" \n", | |
" src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)\n", | |
"\n", | |
" #src_mask = [batch size, 1, 1, src len]\n", | |
"\n", | |
" return src_mask\n", | |
" \n", | |
" def make_trg_mask(self, trg):\n", | |
" \n", | |
" #trg = [batch size, trg len]\n", | |
" \n", | |
" trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)\n", | |
" \n", | |
" #trg_pad_mask = [batch size, 1, 1, trg len]\n", | |
" \n", | |
" trg_len = trg.shape[1]\n", | |
" \n", | |
" trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()\n", | |
" \n", | |
" #trg_sub_mask = [trg len, trg len]\n", | |
" \n", | |
" trg_mask = trg_pad_mask & trg_sub_mask\n", | |
" \n", | |
" #trg_mask = [batch size, 1, trg len, trg len]\n", | |
" \n", | |
" return trg_mask\n", | |
"\n", | |
" def forward(self, src, trg):\n", | |
" \n", | |
" #src = [batch size, src len]\n", | |
" #trg = [batch size, trg len]\n", | |
" \n", | |
" src_mask = self.make_src_mask(src)\n", | |
" trg_mask = self.make_trg_mask(trg)\n", | |
" \n", | |
" #src_mask = [batch size, 1, 1, src len]\n", | |
" #trg_mask = [batch size, 1, trg len, trg len]\n", | |
" \n", | |
" enc_src = self.encoder(src, src_mask)\n", | |
" \n", | |
" #enc_src = [batch size, src len, hid dim]\n", | |
" \n", | |
" output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)\n", | |
" \n", | |
" #output = [batch size, trg len, output dim]\n", | |
" #attention = [batch size, n heads, trg len, src len]\n", | |
" \n", | |
" return output, attention" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"code_folding": [ | |
0 | |
], | |
"colab_type": "code", | |
"id": "ZpUuq8zxy5qY", | |
"colab": {} | |
}, | |
"source": [ | |
"INPUT_DIM = len(SRC.vocab)\n", | |
"OUTPUT_DIM = len(TRG.vocab)\n", | |
"HID_DIM = 256\n", | |
"ENC_LAYERS = 3\n", | |
"DEC_LAYERS = 3\n", | |
"ENC_HEADS = 8\n", | |
"DEC_HEADS = 8\n", | |
"ENC_PF_DIM = 512\n", | |
"DEC_PF_DIM = 512\n", | |
"ENC_DROPOUT = 0.1\n", | |
"DEC_DROPOUT = 0.1\n", | |
"\n", | |
"enc = Encoder(INPUT_DIM, \n", | |
" HID_DIM, \n", | |
" ENC_LAYERS, \n", | |
" ENC_HEADS, \n", | |
" ENC_PF_DIM, \n", | |
" ENC_DROPOUT, \n", | |
" device)\n", | |
"\n", | |
"dec = Decoder(OUTPUT_DIM, \n", | |
" HID_DIM, \n", | |
" DEC_LAYERS, \n", | |
" DEC_HEADS, \n", | |
" DEC_PF_DIM, \n", | |
" DEC_DROPOUT, \n", | |
" device)\n", | |
"\n", | |
"SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]\n", | |
"TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "ty8cwecPy5qb", | |
"colab": {} | |
}, | |
"source": [ | |
"model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "jLZkRVgwEqH3", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Model parameters" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "loGDyC-ElpZY", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "83f0b461-ade1-4aa1-90b7-e5a8456b3ff6" | |
}, | |
"source": [ | |
"def count_parameters(model):\n", | |
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", | |
"\n", | |
"print(f'The model has {count_parameters(model):,} trainable parameters')" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"The model has 12,506,137 trainable parameters\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "XcaxdCBJEqH7", | |
"colab_type": "code", | |
"colab": {}, | |
"outputId": "dda74e10-9654-46a5-a661-efdb38a99185" | |
}, | |
"source": [ | |
"# Display the module hierarchy (repr) of the assembled Seq2Seq model.\n", | |
"model" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"Seq2Seq(\n", | |
" (encoder): Encoder(\n", | |
" (tok_embedding): Embedding(8021, 256)\n", | |
" (pos_embedding): Embedding(100, 256)\n", | |
" (layers): ModuleList(\n", | |
" (0): EncoderLayer(\n", | |
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (self_attention): MultiHeadAttentionLayer(\n", | |
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n", | |
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n", | |
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (1): EncoderLayer(\n", | |
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (self_attention): MultiHeadAttentionLayer(\n", | |
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n", | |
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n", | |
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (2): EncoderLayer(\n", | |
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (self_attention): MultiHeadAttentionLayer(\n", | |
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n", | |
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n", | |
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (decoder): Decoder(\n", | |
" (tok_embedding): Embedding(12569, 256)\n", | |
" (pos_embedding): Embedding(100, 256)\n", | |
" (layers): ModuleList(\n", | |
" (0): DecoderLayer(\n", | |
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (enc_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (self_attention): MultiHeadAttentionLayer(\n", | |
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (encoder_attention): MultiHeadAttentionLayer(\n", | |
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n", | |
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n", | |
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (1): DecoderLayer(\n", | |
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (enc_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (self_attention): MultiHeadAttentionLayer(\n", | |
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (encoder_attention): MultiHeadAttentionLayer(\n", | |
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n", | |
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n", | |
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (2): DecoderLayer(\n", | |
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (enc_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", | |
" (self_attention): MultiHeadAttentionLayer(\n", | |
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (encoder_attention): MultiHeadAttentionLayer(\n", | |
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n", | |
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n", | |
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" )\n", | |
" (fc_out): Linear(in_features=256, out_features=12569, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
")" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 23 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"scrolled": false, | |
"id": "0EjtrhPpEqH_", | |
"colab_type": "code", | |
"colab": {}, | |
"outputId": "c60e928f-4de8-4351-faef-7cc2cb981c82" | |
}, | |
"source": [ | |
"from prettytable import PrettyTable\n", | |
"\n", | |
"def count_parameters(model):\n", | |
" table = PrettyTable([\"Modules\", \"Parameters\"])\n", | |
" total_params = 0\n", | |
" for name, parameter in model.named_parameters():\n", | |
" if not parameter.requires_grad: continue\n", | |
" param = parameter.numel()\n", | |
" table.add_row([name, param])\n", | |
" total_params+=param\n", | |
" print(table)\n", | |
" print(f\"Total Trainable Params: {total_params}\")\n", | |
" return total_params\n", | |
" \n", | |
"count_parameters(model)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"+-------------------------------------------------------+------------+\n", | |
"| Modules | Parameters |\n", | |
"+-------------------------------------------------------+------------+\n", | |
"| encoder.tok_embedding.weight | 2053376 |\n", | |
"| encoder.pos_embedding.weight | 25600 |\n", | |
"| encoder.layers.0.self_attn_layer_norm.weight | 256 |\n", | |
"| encoder.layers.0.self_attn_layer_norm.bias | 256 |\n", | |
"| encoder.layers.0.ff_layer_norm.weight | 256 |\n", | |
"| encoder.layers.0.ff_layer_norm.bias | 256 |\n", | |
"| encoder.layers.0.self_attention.fc_q.weight | 65536 |\n", | |
"| encoder.layers.0.self_attention.fc_q.bias | 256 |\n", | |
"| encoder.layers.0.self_attention.fc_k.weight | 65536 |\n", | |
"| encoder.layers.0.self_attention.fc_k.bias | 256 |\n", | |
"| encoder.layers.0.self_attention.fc_v.weight | 65536 |\n", | |
"| encoder.layers.0.self_attention.fc_v.bias | 256 |\n", | |
"| encoder.layers.0.self_attention.fc_o.weight | 65536 |\n", | |
"| encoder.layers.0.self_attention.fc_o.bias | 256 |\n", | |
"| encoder.layers.0.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| encoder.layers.0.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| encoder.layers.0.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| encoder.layers.0.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| encoder.layers.1.self_attn_layer_norm.weight | 256 |\n", | |
"| encoder.layers.1.self_attn_layer_norm.bias | 256 |\n", | |
"| encoder.layers.1.ff_layer_norm.weight | 256 |\n", | |
"| encoder.layers.1.ff_layer_norm.bias | 256 |\n", | |
"| encoder.layers.1.self_attention.fc_q.weight | 65536 |\n", | |
"| encoder.layers.1.self_attention.fc_q.bias | 256 |\n", | |
"| encoder.layers.1.self_attention.fc_k.weight | 65536 |\n", | |
"| encoder.layers.1.self_attention.fc_k.bias | 256 |\n", | |
"| encoder.layers.1.self_attention.fc_v.weight | 65536 |\n", | |
"| encoder.layers.1.self_attention.fc_v.bias | 256 |\n", | |
"| encoder.layers.1.self_attention.fc_o.weight | 65536 |\n", | |
"| encoder.layers.1.self_attention.fc_o.bias | 256 |\n", | |
"| encoder.layers.1.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| encoder.layers.1.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| encoder.layers.1.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| encoder.layers.1.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| encoder.layers.2.self_attn_layer_norm.weight | 256 |\n", | |
"| encoder.layers.2.self_attn_layer_norm.bias | 256 |\n", | |
"| encoder.layers.2.ff_layer_norm.weight | 256 |\n", | |
"| encoder.layers.2.ff_layer_norm.bias | 256 |\n", | |
"| encoder.layers.2.self_attention.fc_q.weight | 65536 |\n", | |
"| encoder.layers.2.self_attention.fc_q.bias | 256 |\n", | |
"| encoder.layers.2.self_attention.fc_k.weight | 65536 |\n", | |
"| encoder.layers.2.self_attention.fc_k.bias | 256 |\n", | |
"| encoder.layers.2.self_attention.fc_v.weight | 65536 |\n", | |
"| encoder.layers.2.self_attention.fc_v.bias | 256 |\n", | |
"| encoder.layers.2.self_attention.fc_o.weight | 65536 |\n", | |
"| encoder.layers.2.self_attention.fc_o.bias | 256 |\n", | |
"| encoder.layers.2.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| encoder.layers.2.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| encoder.layers.2.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| encoder.layers.2.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| decoder.tok_embedding.weight | 3217664 |\n", | |
"| decoder.pos_embedding.weight | 25600 |\n", | |
"| decoder.layers.0.self_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.0.self_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.0.enc_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.0.enc_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.0.ff_layer_norm.weight | 256 |\n", | |
"| decoder.layers.0.ff_layer_norm.bias | 256 |\n", | |
"| decoder.layers.0.self_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.0.self_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.0.self_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.0.self_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.0.self_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.0.self_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.0.self_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.0.self_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.0.encoder_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.0.encoder_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.0.encoder_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.0.encoder_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.0.encoder_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.0.encoder_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.0.encoder_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.0.encoder_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.0.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| decoder.layers.0.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| decoder.layers.0.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| decoder.layers.0.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| decoder.layers.1.self_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.1.self_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.1.enc_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.1.enc_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.1.ff_layer_norm.weight | 256 |\n", | |
"| decoder.layers.1.ff_layer_norm.bias | 256 |\n", | |
"| decoder.layers.1.self_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.1.self_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.1.self_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.1.self_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.1.self_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.1.self_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.1.self_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.1.self_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.1.encoder_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.1.encoder_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.1.encoder_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.1.encoder_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.1.encoder_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.1.encoder_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.1.encoder_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.1.encoder_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.1.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| decoder.layers.1.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| decoder.layers.1.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| decoder.layers.1.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| decoder.layers.2.self_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.2.self_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.2.enc_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.2.enc_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.2.ff_layer_norm.weight | 256 |\n", | |
"| decoder.layers.2.ff_layer_norm.bias | 256 |\n", | |
"| decoder.layers.2.self_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.2.self_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.2.self_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.2.self_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.2.self_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.2.self_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.2.self_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.2.self_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.2.encoder_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.2.encoder_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.2.encoder_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.2.encoder_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.2.encoder_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.2.encoder_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.2.encoder_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.2.encoder_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.2.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| decoder.layers.2.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| decoder.layers.2.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| decoder.layers.2.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| decoder.fc_out.weight | 3217664 |\n", | |
"| decoder.fc_out.bias | 12569 |\n", | |
"+-------------------------------------------------------+------------+\n", | |
"Total Trainable Params: 12506137\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"12506137" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 24 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "e7rHYuTKEqIC", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Method 1 (Load Weights)\n", | |
"\n", | |
"The above Total Trainable Params are 12506137 but in google colab (runtime=None, device='cpu') Total Trainable Params: 12490234\n", | |
"\n", | |
"So unable to load model weights" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "07fMQajdlpUq", | |
"scrolled": false, | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "ba892bac-1dc1-4fe2-955a-934ac5056385" | |
}, | |
"source": [ | |
"model.load_state_dict(torch.load(f\"{model_name}_2.pt\", map_location=device), strict=False)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "RuntimeError", | |
"evalue": "Error(s) in loading state_dict for Seq2Seq:\n\tsize mismatch for decoder.tok_embedding.weight: copying a param with shape torch.Size([12538, 256]) from checkpoint, the shape in current model is torch.Size([12569, 256]).\n\tsize mismatch for decoder.fc_out.weight: copying a param with shape torch.Size([12538, 256]) from checkpoint, the shape in current model is torch.Size([12569, 256]).\n\tsize mismatch for decoder.fc_out.bias: copying a param with shape torch.Size([12538]) from checkpoint, the shape in current model is torch.Size([12569]).", | |
"traceback": [ | |
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[1;32m<ipython-input-25-977503f88f12>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\"{model_name}_2.pt\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstrict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[1;32mc:\\anaconda\\envs\\lang_trans\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[1;34m(self, state_dict, strict)\u001b[0m\n\u001b[0;32m 845\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 846\u001b[0m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[1;32m--> 847\u001b[1;33m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[0;32m 848\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 849\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for Seq2Seq:\n\tsize mismatch for decoder.tok_embedding.weight: copying a param with shape torch.Size([12538, 256]) from checkpoint, the shape in current model is torch.Size([12569, 256]).\n\tsize mismatch for decoder.fc_out.weight: copying a param with shape torch.Size([12538, 256]) from checkpoint, the shape in current model is torch.Size([12569, 256]).\n\tsize mismatch for decoder.fc_out.bias: copying a param with shape torch.Size([12538]) from checkpoint, the shape in current model is torch.Size([12569])." | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "AkNGWg2QEqIG", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Method 2 (Load full model)\n", | |
"\n", | |
"It is loading here but giving error in next cell, see the params changes here again" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "rJvLhc3Zy5qd", | |
"colab": {}, | |
"outputId": "9406c490-5347-4600-bd19-ce7d323e869e" | |
}, | |
"source": [ | |
"model = torch.load(f\"{model_name}.pth\", map_location=torch.device('cpu'))\n", | |
"\n", | |
"count_parameters(model)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"+-------------------------------------------------------+------------+\n", | |
"| Modules | Parameters |\n", | |
"+-------------------------------------------------------+------------+\n", | |
"| encoder.tok_embedding.weight | 1552128 |\n", | |
"| encoder.pos_embedding.weight | 25600 |\n", | |
"| encoder.layers.0.self_attn_layer_norm.weight | 256 |\n", | |
"| encoder.layers.0.self_attn_layer_norm.bias | 256 |\n", | |
"| encoder.layers.0.ff_layer_norm.weight | 256 |\n", | |
"| encoder.layers.0.ff_layer_norm.bias | 256 |\n", | |
"| encoder.layers.0.self_attention.fc_q.weight | 65536 |\n", | |
"| encoder.layers.0.self_attention.fc_q.bias | 256 |\n", | |
"| encoder.layers.0.self_attention.fc_k.weight | 65536 |\n", | |
"| encoder.layers.0.self_attention.fc_k.bias | 256 |\n", | |
"| encoder.layers.0.self_attention.fc_v.weight | 65536 |\n", | |
"| encoder.layers.0.self_attention.fc_v.bias | 256 |\n", | |
"| encoder.layers.0.self_attention.fc_o.weight | 65536 |\n", | |
"| encoder.layers.0.self_attention.fc_o.bias | 256 |\n", | |
"| encoder.layers.0.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| encoder.layers.0.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| encoder.layers.0.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| encoder.layers.0.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| encoder.layers.1.self_attn_layer_norm.weight | 256 |\n", | |
"| encoder.layers.1.self_attn_layer_norm.bias | 256 |\n", | |
"| encoder.layers.1.ff_layer_norm.weight | 256 |\n", | |
"| encoder.layers.1.ff_layer_norm.bias | 256 |\n", | |
"| encoder.layers.1.self_attention.fc_q.weight | 65536 |\n", | |
"| encoder.layers.1.self_attention.fc_q.bias | 256 |\n", | |
"| encoder.layers.1.self_attention.fc_k.weight | 65536 |\n", | |
"| encoder.layers.1.self_attention.fc_k.bias | 256 |\n", | |
"| encoder.layers.1.self_attention.fc_v.weight | 65536 |\n", | |
"| encoder.layers.1.self_attention.fc_v.bias | 256 |\n", | |
"| encoder.layers.1.self_attention.fc_o.weight | 65536 |\n", | |
"| encoder.layers.1.self_attention.fc_o.bias | 256 |\n", | |
"| encoder.layers.1.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| encoder.layers.1.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| encoder.layers.1.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| encoder.layers.1.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| encoder.layers.2.self_attn_layer_norm.weight | 256 |\n", | |
"| encoder.layers.2.self_attn_layer_norm.bias | 256 |\n", | |
"| encoder.layers.2.ff_layer_norm.weight | 256 |\n", | |
"| encoder.layers.2.ff_layer_norm.bias | 256 |\n", | |
"| encoder.layers.2.self_attention.fc_q.weight | 65536 |\n", | |
"| encoder.layers.2.self_attention.fc_q.bias | 256 |\n", | |
"| encoder.layers.2.self_attention.fc_k.weight | 65536 |\n", | |
"| encoder.layers.2.self_attention.fc_k.bias | 256 |\n", | |
"| encoder.layers.2.self_attention.fc_v.weight | 65536 |\n", | |
"| encoder.layers.2.self_attention.fc_v.bias | 256 |\n", | |
"| encoder.layers.2.self_attention.fc_o.weight | 65536 |\n", | |
"| encoder.layers.2.self_attention.fc_o.bias | 256 |\n", | |
"| encoder.layers.2.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| encoder.layers.2.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| encoder.layers.2.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| encoder.layers.2.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| decoder.tok_embedding.weight | 2380288 |\n", | |
"| decoder.pos_embedding.weight | 25600 |\n", | |
"| decoder.layers.0.self_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.0.self_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.0.enc_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.0.enc_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.0.ff_layer_norm.weight | 256 |\n", | |
"| decoder.layers.0.ff_layer_norm.bias | 256 |\n", | |
"| decoder.layers.0.self_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.0.self_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.0.self_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.0.self_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.0.self_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.0.self_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.0.self_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.0.self_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.0.encoder_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.0.encoder_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.0.encoder_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.0.encoder_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.0.encoder_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.0.encoder_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.0.encoder_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.0.encoder_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.0.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| decoder.layers.0.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| decoder.layers.0.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| decoder.layers.0.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| decoder.layers.1.self_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.1.self_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.1.enc_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.1.enc_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.1.ff_layer_norm.weight | 256 |\n", | |
"| decoder.layers.1.ff_layer_norm.bias | 256 |\n", | |
"| decoder.layers.1.self_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.1.self_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.1.self_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.1.self_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.1.self_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.1.self_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.1.self_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.1.self_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.1.encoder_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.1.encoder_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.1.encoder_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.1.encoder_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.1.encoder_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.1.encoder_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.1.encoder_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.1.encoder_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.1.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| decoder.layers.1.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| decoder.layers.1.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| decoder.layers.1.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| decoder.layers.2.self_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.2.self_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.2.enc_attn_layer_norm.weight | 256 |\n", | |
"| decoder.layers.2.enc_attn_layer_norm.bias | 256 |\n", | |
"| decoder.layers.2.ff_layer_norm.weight | 256 |\n", | |
"| decoder.layers.2.ff_layer_norm.bias | 256 |\n", | |
"| decoder.layers.2.self_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.2.self_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.2.self_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.2.self_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.2.self_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.2.self_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.2.self_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.2.self_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.2.encoder_attention.fc_q.weight | 65536 |\n", | |
"| decoder.layers.2.encoder_attention.fc_q.bias | 256 |\n", | |
"| decoder.layers.2.encoder_attention.fc_k.weight | 65536 |\n", | |
"| decoder.layers.2.encoder_attention.fc_k.bias | 256 |\n", | |
"| decoder.layers.2.encoder_attention.fc_v.weight | 65536 |\n", | |
"| decoder.layers.2.encoder_attention.fc_v.bias | 256 |\n", | |
"| decoder.layers.2.encoder_attention.fc_o.weight | 65536 |\n", | |
"| decoder.layers.2.encoder_attention.fc_o.bias | 256 |\n", | |
"| decoder.layers.2.positionwise_feedforward.fc_1.weight | 131072 |\n", | |
"| decoder.layers.2.positionwise_feedforward.fc_1.bias | 512 |\n", | |
"| decoder.layers.2.positionwise_feedforward.fc_2.weight | 131072 |\n", | |
"| decoder.layers.2.positionwise_feedforward.fc_2.bias | 256 |\n", | |
"| decoder.fc_out.weight | 2380288 |\n", | |
"| decoder.fc_out.bias | 9298 |\n", | |
"+-------------------------------------------------------+------------+\n", | |
"Total Trainable Params: 10326866\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"10326866" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 29 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "QskLHNo5y5qh" | |
}, | |
"source": [ | |
"# Prediction" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "A_7Lj1iwy5qh", | |
"colab": {} | |
}, | |
"source": [ | |
"def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):\n", | |
" \n", | |
" model.eval()\n", | |
" \n", | |
" if isinstance(sentence, str):\n", | |
" nlp = spacy.load(lang2_name[1:])\n", | |
" tokens = [token.text.lower() for token in nlp(sentence)]\n", | |
" else:\n", | |
" tokens = [token.lower() for token in sentence]\n", | |
"\n", | |
" tokens = [src_field.init_token] + tokens + [src_field.eos_token]\n", | |
"\n", | |
" print(tokens)\n", | |
" \n", | |
" src_indexes = [src_field.vocab.stoi[token] for token in tokens]\n", | |
"\n", | |
" src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)\n", | |
" \n", | |
" src_mask = model.make_src_mask(src_tensor)\n", | |
" \n", | |
" with torch.no_grad():\n", | |
" enc_src = model.encoder(src_tensor, src_mask)\n", | |
"\n", | |
" trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]\n", | |
"\n", | |
" for i in range(max_len):\n", | |
"\n", | |
" trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)\n", | |
"\n", | |
" trg_mask = model.make_trg_mask(trg_tensor)\n", | |
" \n", | |
" with torch.no_grad():\n", | |
" output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)\n", | |
" \n", | |
" pred_token = output.argmax(2)[:,-1].item()\n", | |
" \n", | |
" trg_indexes.append(pred_token)\n", | |
"\n", | |
" if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:\n", | |
" break\n", | |
" \n", | |
" trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]\n", | |
" \n", | |
" return trg_tokens[1:], attention" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "MZDVMsewy5qj", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 52 | |
}, | |
"outputId": "e668cc00-850a-4461-ca88-4f3d04aa5ef3" | |
}, | |
"source": [ | |
"src = \"AlwaysOn is an app for every device with an AMOLED or OLED display\"\n", | |
"\n", | |
"src = src.split(\" \")\n", | |
"\n", | |
"translation, attention = translate_sentence(src, SRC, TRG, model, device)\n", | |
"\n", | |
"print(f'predicted trg = {translation}')" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"['<sos>', 'alwayson', 'is', 'an', 'app', 'for', 'every', 'device', 'with', 'an', 'amoled', 'or', 'oled', 'display', '<eos>']\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "error", | |
"ename": "AssertionError", | |
"evalue": "Torch not compiled with CUDA enabled", | |
"traceback": [ | |
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[1;31mAssertionError\u001b[0m Traceback (most recent call last)", | |
"\u001b[1;32m<ipython-input-31-ff8428ee5832>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0msrc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msrc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\" \"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[0mtranslation\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mattention\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtranslate_sentence\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mSRC\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mTRG\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 6\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf'predicted trg = {translation}'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m<ipython-input-30-5d12c49b1bf1>\u001b[0m in \u001b[0;36mtranslate_sentence\u001b[1;34m(sentence, src_field, trg_field, model, device, max_len)\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 22\u001b[1;33m \u001b[0menc_src\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msrc_tensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msrc_mask\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 23\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 24\u001b[0m \u001b[0mtrg_indexes\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mtrg_field\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvocab\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstoi\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mtrg_field\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_token\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32mc:\\anaconda\\envs\\lang_trans\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 548\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 549\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 550\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 551\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 552\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m<ipython-input-19-5ecaa062b686>\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, src, src_mask)\u001b[0m\n\u001b[0;32m 37\u001b[0m \u001b[0mdevice\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'cuda'\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_available\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;34m'cpu'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 38\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 39\u001b[1;33m \u001b[0mpos\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msrc_len\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrepeat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 40\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 41\u001b[0m \u001b[1;31m#pos = [batch size, src len]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32mc:\\anaconda\\envs\\lang_trans\\lib\\site-packages\\torch\\cuda\\__init__.py\u001b[0m in \u001b[0;36m_lazy_init\u001b[1;34m()\u001b[0m\n\u001b[0;32m 147\u001b[0m raise RuntimeError(\n\u001b[0;32m 148\u001b[0m \"Cannot re-initialize CUDA in forked subprocess. \" + msg)\n\u001b[1;32m--> 149\u001b[1;33m \u001b[0m_check_driver\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 150\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m_cudart\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 151\u001b[0m raise AssertionError(\n", | |
"\u001b[1;32mc:\\anaconda\\envs\\lang_trans\\lib\\site-packages\\torch\\cuda\\__init__.py\u001b[0m in \u001b[0;36m_check_driver\u001b[1;34m()\u001b[0m\n\u001b[0;32m 45\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_check_driver\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'_cuda_isDriverSufficient'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 47\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Torch not compiled with CUDA enabled\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 48\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cuda_isDriverSufficient\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 49\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cuda_getDriverVersion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;31mAssertionError\u001b[0m: Torch not compiled with CUDA enabled" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"colab_type": "text", | |
"id": "CtlSMBWi6FDS" | |
}, | |
"source": [ | |
"# Visualization" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "Tn5Eo_-98ccc", | |
"colab": {} | |
}, | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"import matplotlib.ticker as ticker" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "sZDvG431-6cQ", | |
"colab": {} | |
}, | |
"source": [ | |
"# Colormap names cycled through by display_attention(..., clr=True).\n", | |
"li = [\n", | |
"    'binary',\n", | |
"    'cool',\n", | |
"]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "tdiewTRz6i68", | |
"colab": {} | |
}, | |
"source": [ | |
"def display_attention(sentence, translation, attention, n_heads = 8, n_rows = 4, n_cols = 2, clr=False, cmaps=None):\n", | |
"    \"\"\"Plot one attention heat-map per head for a translated sentence.\n", | |
"\n", | |
"    Args:\n", | |
"        sentence: source tokens (list of str) that were passed to\n", | |
"            translate_sentence; each is lower-cased and the list is wrapped\n", | |
"            in '<sos>'/'<eos>' to label the x-axis.\n", | |
"        translation: predicted target tokens returned by translate_sentence,\n", | |
"            used to label the y-axis.\n", | |
"        attention: attention tensor returned by translate_sentence. After\n", | |
"            squeeze(0) it is indexed by head; presumably shaped\n", | |
"            [1, n_heads, len(translation), len(sentence) + 2] -- TODO confirm.\n", | |
"        n_heads: number of heads to plot.\n", | |
"        n_rows, n_cols: subplot grid; n_rows * n_cols must equal n_heads.\n", | |
"        clr: if True, draw the grid once per colormap in `cmaps` on a 15x25\n", | |
"            figure; otherwise draw it once with 'gist_heat' on a 10x20 figure.\n", | |
"        cmaps: colormap names used when clr is True. Defaults to the\n", | |
"            notebook-level list `li`, so existing calls behave as before.\n", | |
"\n", | |
"    Raises:\n", | |
"        AssertionError: if n_rows * n_cols != n_heads.\n", | |
"    \"\"\"\n", | |
"    # Validate once for both branches, so an empty colormap list cannot\n", | |
"    # silently skip the check.\n", | |
"    assert n_rows * n_cols == n_heads\n", | |
"\n", | |
"    x_labels = ['<sos>'] + [t.lower() for t in sentence] + ['<eos>']\n", | |
"\n", | |
"    if clr:\n", | |
"        if cmaps is None:\n", | |
"            cmaps = li\n", | |
"        plots = [(cmap, (15, 25)) for cmap in cmaps]\n", | |
"    else:\n", | |
"        plots = [('gist_heat', (10, 20))]\n", | |
"\n", | |
"    for cmap, figsize in plots:\n", | |
"        _plot_attention_grid(x_labels, translation, attention,\n", | |
"                             n_heads, n_rows, n_cols, cmap, figsize)\n", | |
"\n", | |
"\n", | |
"def _plot_attention_grid(x_labels, y_labels, attention, n_heads, n_rows, n_cols, cmap, figsize):\n", | |
"    \"\"\"Draw and show one figure with a matshow panel per attention head.\"\"\"\n", | |
"    fig = plt.figure(figsize=figsize)\n", | |
"    heads = attention.squeeze(0)\n", | |
"\n", | |
"    for i in range(n_heads):\n", | |
"        ax = fig.add_subplot(n_rows, n_cols, i + 1)\n", | |
"        ax.matshow(heads[i].cpu().detach().numpy(), cmap=cmap)\n", | |
"        ax.tick_params(labelsize=12)\n", | |
"\n", | |
"        # Pin one tick per matrix cell, then label exactly those ticks.\n", | |
"        # Labelling the automatic ticks and swapping in MultipleLocator(1)\n", | |
"        # afterwards needs a dummy leading '' label and makes newer\n", | |
"        # matplotlib warn about FixedFormatter without FixedLocator.\n", | |
"        ax.set_xticks(range(len(x_labels)))\n", | |
"        ax.set_xticklabels(x_labels, rotation=45)\n", | |
"        ax.set_yticks(range(len(y_labels)))\n", | |
"        ax.set_yticklabels(y_labels)\n", | |
"\n", | |
"    plt.show()\n", | |
"    plt.close(fig)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "0WzIjPxr7E3v", | |
"colab": {} | |
}, | |
"source": [ | |
"display_attention(sentence=src, translation=translation, attention=attention, clr=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "sep9LyC18XYF", | |
"colab": {} | |
}, | |
"source": [ | |
"example_idx = 6\n", | |
"example = vars(valid_data.examples[example_idx])\n", | |
"src, trg = example['src'], example['trg']\n", | |
"\n", | |
"print(f'src = {src}')\n", | |
"print(f'trg = {trg}')" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "THcyVb-f8yfS", | |
"colab": {} | |
}, | |
"source": [ | |
"translation, attention = translate_sentence(sentence=src, src_field=SRC, trg_field=TRG,\n", | |
"                                            model=model, device=device)\n", | |
"print(f'predicted trg = {translation}')" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "4msONCaR81on", | |
"colab": {} | |
}, | |
"source": [ | |
"display_attention(sentence=src, translation=translation, attention=attention, clr=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "YS5XjAiE84RM", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment