prediction (1).ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "prediction (1).ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Dipeshpal/5d55e7016c5f4fcda4e5b5c459f8436d/prediction-1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "weuU8H1Dy-y7",
"colab": {}
},
"source": [
"# from google.colab import drive\n",
"# drive.mount('/content/drive')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "8gSy_4TBzJOT",
"colab": {}
},
"source": [
"import os\n",
"\n",
"# os.chdir(\"/content/drive/My Drive/Projects/lang_translator_pytorch_attention\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "Eu-N7bKKy5pq",
"colab": {}
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"import torchtext\n",
"from torchtext import data\n",
"# from torchtext.datasets import Multi30k\n",
"from torchtext.data import Field, BucketIterator\n",
"\n",
"# import matplotlib.pyplot as plt\n",
"# import matplotlib.ticker as ticker\n",
"\n",
"import spacy\n",
"import numpy as np\n",
"\n",
"import random\n",
"import math\n",
"import time\n",
"import os"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "C-ejYsQay5pt",
"colab": {}
},
"source": [
"SEED = 1234\n",
"\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"torch.cuda.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "TGMzD2lXEqGo",
"colab_type": "text"
},
"source": [
"# Basic setting up things"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "evhaQnH2y5pw",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 159
},
"outputId": "7199c0c8-4aaa-4db6-8cba-85f762f92c97"
},
"source": [
"BASE = \"dataset/\"\n",
"\n",
"folder_name = input(\"Enter folder name (Ex: en-es | None for en-de): \")\n",
"\n",
"MY_PATH = f\"{BASE}{folder_name}\"\n",
"\n",
"if MY_PATH == None:\n",
" lang1_name = \".en\"\n",
" lang2_name = \".de\"\n",
"else:\n",
" lang1_name = input(\"Enter Lang1 Name (Ex: en): \")\n",
" lang1_name = \".\"+lang1_name\n",
" lang2_name = input(\"Enter Lang2 Name (Ex: es): \")\n",
" lang2_name = \".\"+lang2_name\n",
" \n",
"BATCH_SIZE = int(input(\"Batch Size: (128 Recommended): \"))\n",
"N_EPOCHS = int(input(\"Number of Epochs (50 Recommended): \"))\n",
"\n",
"tensor_print = input(\"Do you want to print the tensor summary (Type yes or No): \")\n",
"summary = input(\"Do you want to print model summary (Type yes or No): \")\n",
"\n",
"model_name = f\"models/{lang1_name[1:]}-{lang2_name[1:]}/\"+lang1_name[1:]+'-'+lang2_name[1:]\n",
"print(\"Model Name: \", model_name)\n",
"# print(f\"Language 1: {lang1_name}, Language 2: {lang2_name}\")"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Enter folder name (Ex: en-es | None for en-de): en-fr\n",
"Enter Lang1 Name (Ex: en): en\n",
"Enter Lang2 Name (Ex: es): fr\n",
"Batch Size: (128 Recommended): 128\n",
"Number of Epochs (50 Recommended): 1\n",
"Do you want to print the tensor summary (Type yes or No): yes\n",
"Do you want to print model summary (Type yes or No): yes\n",
"Model Name: models/en-fr/en-fr\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "KulkumpFy5pz",
"colab": {}
},
"source": [
"def create_dirs(dir_path):\n",
" try:\n",
" os.mkdir(dir_path)\n",
" except:\n",
" pass\n",
"\n",
"create_dirs(\"models/\")\n",
"create_dirs(f\"models/{lang1_name[1:]}-{lang2_name[1:]}/\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "o-Uh0sxSy5p1",
"colab": {}
},
"source": [
"if lang2_name == \".hi\":\n",
" print(\"Hindi Tokenizater Installing\")\n",
" os.system(\"pip install inltk\")\n",
" from inltk.inltk import setup\n",
" from inltk.inltk import tokenize\n",
" setup('hi')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "SgynozXnzUIr",
"colab": {}
},
"source": [
"# !python -m spacy download en\n",
"# !python -m spacy download fr"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "e5qRUlEny5p3",
"colab": {}
},
"source": [
"# https://spacy.io/usage/models#languages\n",
"spacy_lang1 = spacy.load(lang1_name[1:])\n",
"\n",
"# if lang2_name != \".hi\":\n",
"# print(f\"Second Language is {lang2_name[1:]}\")\n",
"# # os.system(f\"python -m spacy download {lang2_name[1:]}\")\n",
"spacy_lang2 = spacy.load(lang2_name[1:])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "-K5P4zvQy5p5",
"colab": {}
},
"source": [
"def tokenize_lang1(text):\n",
" return [tok.text for tok in spacy_lang1.tokenizer(text)]\n",
"\n",
"def tokenize_lang2(text):\n",
" if lang2_name == \".hi\":\n",
" return tokenize(text, \"hi\")\n",
" else: \n",
" return [tok.text for tok in spacy_lang2.tokenizer(text)]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "VVtY7iiNEqHG",
"colab_type": "text"
},
"source": [
"# Just checking"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "2i8GsA4Fy5p9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"outputId": "594d9cc4-3af6-4471-9466-c8894bce1d10"
},
"source": [
"if lang2_name == \".hi\":\n",
" print(\"Your language is Hindi, Just checking Tokenization\")\n",
" print(tokenize_lang2(\"प्राचीन काल में विक्रमादित्य नाम के एक आदर्श राजा हुआ करते थे।\"))\n",
"else:\n",
" print(\"Your primary language is English, Just checking Tokenization\")\n",
" print(tokenize_lang2(\"You language is English, Just checking Tokenization\"))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Your primary language is English, Just checking Tokenization\n",
"['You', 'language', 'is', 'English', ',', 'Just', 'checking', 'Tokenization']\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "7H1WRW0Fy5qA",
"colab": {}
},
"source": [
"SRC = Field(tokenize = tokenize_lang1, \n",
" init_token = '<sos>', \n",
" eos_token = '<eos>', \n",
" lower = True, \n",
" batch_first = True)\n",
"\n",
"TRG = Field(tokenize = tokenize_lang2, \n",
" init_token = '<sos>', \n",
" eos_token = '<eos>', \n",
" lower = True, \n",
" batch_first = True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "qfZ-pcGAy5qD"
},
"source": [
"# Multi30 class with some customization"
]
},
{
"cell_type": "code",
"metadata": {
"code_folding": [
7,
51
],
"colab_type": "code",
"id": "xJDOXLyZy5qD",
"colab": {}
},
"source": [
"import os\n",
"import xml.etree.ElementTree as ET\n",
"import glob\n",
"import io\n",
"import codecs\n",
"\n",
"\n",
"class TranslationDataset(data.Dataset):\n",
" \n",
" @staticmethod\n",
" def sort_key(ex):\n",
" return data.interleave_keys(len(ex.src), len(ex.trg))\n",
"\n",
" def __init__(self, path, exts, fields, **kwargs):\n",
" \n",
" if not isinstance(fields[0], (tuple, list)):\n",
" fields = [('src', fields[0]), ('trg', fields[1])]\n",
"\n",
" src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)\n",
"\n",
" examples = []\n",
" with io.open(src_path, mode='r', encoding='utf-8') as src_file, \\\n",
" io.open(trg_path, mode='r', encoding='utf-8') as trg_file:\n",
" for src_line, trg_line in zip(src_file, trg_file):\n",
" src_line, trg_line = src_line.strip(), trg_line.strip()\n",
" if src_line != '' and trg_line != '':\n",
" examples.append(data.Example.fromlist(\n",
" [src_line, trg_line], fields))\n",
"\n",
" super(TranslationDataset, self).__init__(examples, fields, **kwargs)\n",
"\n",
"\n",
" @classmethod\n",
" def splits(cls, exts, fields, path=None, root='dataset',\n",
" train='train', validation='val', test='test', **kwargs):\n",
" \n",
" if path is None or path is \"en-de\":\n",
" path = cls.download(root)\n",
"\n",
" train_data = None if train is None else cls(\n",
" os.path.join(path, train), exts, fields, **kwargs)\n",
" val_data = None if validation is None else cls(\n",
" os.path.join(path, validation), exts, fields, **kwargs)\n",
" test_data = None if test is None else cls(\n",
" os.path.join(path, test), exts, fields, **kwargs)\n",
" return tuple(d for d in (train_data, val_data, test_data)\n",
" if d is not None)\n",
" \n",
" \n",
"\n",
"\n",
"class CustomMulti30(TranslationDataset):\n",
" \n",
" urls = ['http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',\n",
" 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',\n",
" 'http://www.quest.dcs.shef.ac.uk/'\n",
" 'wmt17_files_mmt/mmt_task1_test2016.tar.gz']\n",
" \n",
" \n",
" name = \"en-de\"\n",
" dirname = ''\n",
"\n",
" @classmethod\n",
" def splits(cls, exts, fields, root='dataset',\n",
" train='train', validation='val', test='test', **kwargs):\n",
"\n",
" if 'path' not in kwargs:\n",
" expected_folder = os.path.join(root, cls.name)\n",
" path = expected_folder if os.path.exists(expected_folder) else None\n",
" else:\n",
" path = kwargs['path']\n",
" del kwargs['path']\n",
" \n",
" if path == None:\n",
" train = 'train'\n",
" test = 'test2016'\n",
" val = 'val'\n",
" exts = ('.en', '.de')\n",
"\n",
" return super(CustomMulti30, cls).splits(\n",
" exts, fields, path, root, train, validation, test, **kwargs)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "OmioHrSBy5qF",
"colab": {}
},
"source": [
"train_data, valid_data, test_data = CustomMulti30.splits(exts = (lang1_name, lang2_name), fields = (SRC, TRG), root='dataset', path=MY_PATH,\n",
" train='train', validation='val', test='test')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "0CVjpeivy5qI",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "382f97c5-453b-40a9-8654-3b3a26a8bcdc"
},
"source": [
"len(train_data), len(valid_data), len(test_data)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(119144, 38297, 12766)"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "r43EOIqCy5qL",
"colab": {}
},
"source": [
"SRC.build_vocab(train_data, min_freq = 2)\n",
"TRG.build_vocab(train_data, min_freq = 2)"
],
"execution_count": null,
"outputs": []
},
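{
"cell_type": "code",
"metadata": {},
"source": [
"# A minimal sketch (not in the original notebook): the vocab built above depends\n",
"# on the dataset files and tokenizer version, which is why the parameter counts\n",
"# differ between environments (see \"Method 1\" below). Pickling the vocab once and\n",
"# reloading it at prediction time keeps the embedding sizes consistent with the\n",
"# checkpoint. The file name is hypothetical.\n",
"import pickle\n",
"\n",
"vocab_path = f\"{model_name}_vocab.pkl\"\n",
"if os.path.exists(vocab_path):\n",
"    with open(vocab_path, \"rb\") as f:\n",
"        SRC.vocab, TRG.vocab = pickle.load(f)\n",
"else:\n",
"    with open(vocab_path, \"wb\") as f:\n",
"        pickle.dump((SRC.vocab, TRG.vocab), f)"
],
"execution_count": null,
"outputs": []
},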
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "Dz97f-sfy5qO",
"colab": {}
},
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "XJRv8lwHy5qQ",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "ce9b17c1-a2bd-4416-dbdd-95306ed8a4c2"
},
"source": [
"device"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"device(type='cpu')"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "bMUcegoTy5qS",
"colab": {}
},
"source": [
"train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n",
" (train_data, valid_data, test_data), \n",
" batch_size = BATCH_SIZE,\n",
" device = device)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "fpgbgauey5qW"
},
"source": [
"# Model"
]
},
{
"cell_type": "code",
"metadata": {
"code_folding": [
0,
53,
95,
170,
194,
255,
309
],
"colab_type": "code",
"id": "3ZNp_tqqy5qW",
"colab": {}
},
"source": [
"class Encoder(nn.Module):\n",
" def __init__(self, \n",
" input_dim, \n",
" hid_dim, \n",
" n_layers, \n",
" n_heads, \n",
" pf_dim,\n",
" dropout, \n",
" device,\n",
" max_length = 100):\n",
" super().__init__()\n",
"\n",
" self.device = device\n",
" \n",
" self.tok_embedding = nn.Embedding(input_dim, hid_dim)\n",
" self.pos_embedding = nn.Embedding(max_length, hid_dim)\n",
" \n",
" self.layers = nn.ModuleList([EncoderLayer(hid_dim, \n",
" n_heads, \n",
" pf_dim,\n",
" dropout, \n",
" device) \n",
" for _ in range(n_layers)])\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)\n",
" \n",
" def forward(self, src, src_mask):\n",
" \n",
" #src = [batch size, src len]\n",
" #src_mask = [batch size, src len]\n",
" \n",
" batch_size = src.shape[0]\n",
" src_len = src.shape[1]\n",
" \n",
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
" pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)\n",
" \n",
" #pos = [batch size, src len]\n",
" \n",
" src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))\n",
" \n",
" #src = [batch size, src len, hid dim]\n",
" \n",
" for layer in self.layers:\n",
" src = layer(src, src_mask)\n",
" \n",
" #src = [batch size, src len, hid dim]\n",
" \n",
" return src\n",
" \n",
" \n",
" \n",
"class EncoderLayer(nn.Module):\n",
" def __init__(self, \n",
" hid_dim, \n",
" n_heads, \n",
" pf_dim, \n",
" dropout, \n",
" device):\n",
" super().__init__()\n",
" \n",
" self.self_attn_layer_norm = nn.LayerNorm(hid_dim)\n",
" self.ff_layer_norm = nn.LayerNorm(hid_dim)\n",
" self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)\n",
" self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, \n",
" pf_dim, \n",
" dropout)\n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, src, src_mask):\n",
" \n",
" #src = [batch size, src len, hid dim]\n",
" #src_mask = [batch size, src len]\n",
" \n",
" #self attention\n",
" _src, _ = self.self_attention(src, src, src, src_mask)\n",
" \n",
" #dropout, residual connection and layer norm\n",
" src = self.self_attn_layer_norm(src + self.dropout(_src))\n",
" \n",
" #src = [batch size, src len, hid dim]\n",
" \n",
" #positionwise feedforward\n",
" _src = self.positionwise_feedforward(src)\n",
" \n",
" #dropout, residual and layer norm\n",
" src = self.ff_layer_norm(src + self.dropout(_src))\n",
" \n",
" #src = [batch size, src len, hid dim]\n",
" \n",
" return src\n",
" \n",
" \n",
" \n",
"class MultiHeadAttentionLayer(nn.Module):\n",
" def __init__(self, hid_dim, n_heads, dropout, device):\n",
" super().__init__()\n",
" \n",
" assert hid_dim % n_heads == 0\n",
" \n",
" self.hid_dim = hid_dim\n",
" self.n_heads = n_heads\n",
" self.head_dim = hid_dim // n_heads\n",
" \n",
" self.fc_q = nn.Linear(hid_dim, hid_dim)\n",
" self.fc_k = nn.Linear(hid_dim, hid_dim)\n",
" self.fc_v = nn.Linear(hid_dim, hid_dim)\n",
" \n",
" self.fc_o = nn.Linear(hid_dim, hid_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)\n",
" \n",
" def forward(self, query, key, value, mask = None):\n",
" \n",
" batch_size = query.shape[0]\n",
" \n",
" #query = [batch size, query len, hid dim]\n",
" #key = [batch size, key len, hid dim]\n",
" #value = [batch size, value len, hid dim]\n",
" \n",
" Q = self.fc_q(query)\n",
" K = self.fc_k(key)\n",
" V = self.fc_v(value)\n",
" \n",
" #Q = [batch size, query len, hid dim]\n",
" #K = [batch size, key len, hid dim]\n",
" #V = [batch size, value len, hid dim]\n",
" \n",
" Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n",
" K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n",
" V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n",
" \n",
" #Q = [batch size, n heads, query len, head dim]\n",
" #K = [batch size, n heads, key len, head dim]\n",
" #V = [batch size, n heads, value len, head dim]\n",
" \n",
" energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale\n",
" \n",
" #energy = [batch size, n heads, query len, key len]\n",
" \n",
" if mask is not None:\n",
" energy = energy.masked_fill(mask == 0, -1e10)\n",
" \n",
" attention = torch.softmax(energy, dim = -1)\n",
" \n",
" #attention = [batch size, n heads, query len, key len]\n",
" \n",
" x = torch.matmul(self.dropout(attention), V)\n",
" \n",
" #x = [batch size, n heads, query len, head dim]\n",
" \n",
" x = x.permute(0, 2, 1, 3).contiguous()\n",
" \n",
" #x = [batch size, query len, n heads, head dim]\n",
" \n",
" x = x.view(batch_size, -1, self.hid_dim)\n",
" \n",
" #x = [batch size, query len, hid dim]\n",
" \n",
" x = self.fc_o(x)\n",
" \n",
" #x = [batch size, query len, hid dim]\n",
" \n",
" return x, attention\n",
" \n",
" \n",
" \n",
"class PositionwiseFeedforwardLayer(nn.Module):\n",
" def __init__(self, hid_dim, pf_dim, dropout):\n",
" super().__init__()\n",
" \n",
" self.fc_1 = nn.Linear(hid_dim, pf_dim)\n",
" self.fc_2 = nn.Linear(pf_dim, hid_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, x):\n",
" \n",
" #x = [batch size, seq len, hid dim]\n",
" \n",
" x = self.dropout(torch.relu(self.fc_1(x)))\n",
" \n",
" #x = [batch size, seq len, pf dim]\n",
" \n",
" x = self.fc_2(x)\n",
" \n",
" #x = [batch size, seq len, hid dim]\n",
" \n",
" return x\n",
" \n",
" \n",
"class Decoder(nn.Module):\n",
" def __init__(self, \n",
" output_dim, \n",
" hid_dim, \n",
" n_layers, \n",
" n_heads, \n",
" pf_dim, \n",
" dropout, \n",
" device,\n",
" max_length = 100):\n",
" super().__init__()\n",
" \n",
" self.device = device\n",
" \n",
" self.tok_embedding = nn.Embedding(output_dim, hid_dim)\n",
" self.pos_embedding = nn.Embedding(max_length, hid_dim)\n",
" \n",
" self.layers = nn.ModuleList([DecoderLayer(hid_dim, \n",
" n_heads, \n",
" pf_dim, \n",
" dropout, \n",
" device)\n",
" for _ in range(n_layers)])\n",
" \n",
" self.fc_out = nn.Linear(hid_dim, output_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)\n",
" \n",
" def forward(self, trg, enc_src, trg_mask, src_mask):\n",
" \n",
" #trg = [batch size, trg len]\n",
" #enc_src = [batch size, src len, hid dim]\n",
" #trg_mask = [batch size, trg len]\n",
" #src_mask = [batch size, src len]\n",
" \n",
" batch_size = trg.shape[0]\n",
" trg_len = trg.shape[1]\n",
" \n",
" pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)\n",
" \n",
" #pos = [batch size, trg len]\n",
" \n",
" trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))\n",
" \n",
" #trg = [batch size, trg len, hid dim]\n",
" \n",
" for layer in self.layers:\n",
" trg, attention = layer(trg, enc_src, trg_mask, src_mask)\n",
" \n",
" #trg = [batch size, trg len, hid dim]\n",
" #attention = [batch size, n heads, trg len, src len]\n",
" \n",
" output = self.fc_out(trg)\n",
" \n",
" #output = [batch size, trg len, output dim]\n",
" \n",
" return output, attention\n",
" \n",
" \n",
"class DecoderLayer(nn.Module):\n",
" def __init__(self, \n",
" hid_dim, \n",
" n_heads, \n",
" pf_dim, \n",
" dropout, \n",
" device):\n",
" super().__init__()\n",
" \n",
" self.self_attn_layer_norm = nn.LayerNorm(hid_dim)\n",
" self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)\n",
" self.ff_layer_norm = nn.LayerNorm(hid_dim)\n",
" self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)\n",
" self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)\n",
" self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, \n",
" pf_dim, \n",
" dropout)\n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, trg, enc_src, trg_mask, src_mask):\n",
" \n",
" #trg = [batch size, trg len, hid dim]\n",
" #enc_src = [batch size, src len, hid dim]\n",
" #trg_mask = [batch size, trg len]\n",
" #src_mask = [batch size, src len]\n",
" \n",
" #self attention\n",
" _trg, _ = self.self_attention(trg, trg, trg, trg_mask)\n",
" \n",
" #dropout, residual connection and layer norm\n",
" trg = self.self_attn_layer_norm(trg + self.dropout(_trg))\n",
" \n",
" #trg = [batch size, trg len, hid dim]\n",
" \n",
" #encoder attention\n",
" _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)\n",
" \n",
" #dropout, residual connection and layer norm\n",
" trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))\n",
" \n",
" #trg = [batch size, trg len, hid dim]\n",
" \n",
" #positionwise feedforward\n",
" _trg = self.positionwise_feedforward(trg)\n",
" \n",
" #dropout, residual and layer norm\n",
" trg = self.ff_layer_norm(trg + self.dropout(_trg))\n",
" \n",
" #trg = [batch size, trg len, hid dim]\n",
" #attention = [batch size, n heads, trg len, src len]\n",
" \n",
" return trg, attention\n",
" \n",
" \n",
"class Seq2Seq(nn.Module):\n",
" def __init__(self, \n",
" encoder, \n",
" decoder, \n",
" src_pad_idx, \n",
" trg_pad_idx, \n",
" device):\n",
" super().__init__()\n",
" \n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" self.src_pad_idx = src_pad_idx\n",
" self.trg_pad_idx = trg_pad_idx\n",
" self.device = device\n",
" \n",
" def make_src_mask(self, src):\n",
" \n",
" #src = [batch size, src len]\n",
" \n",
" src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)\n",
"\n",
" #src_mask = [batch size, 1, 1, src len]\n",
"\n",
" return src_mask\n",
" \n",
" def make_trg_mask(self, trg):\n",
" \n",
" #trg = [batch size, trg len]\n",
" \n",
" trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)\n",
" \n",
" #trg_pad_mask = [batch size, 1, 1, trg len]\n",
" \n",
" trg_len = trg.shape[1]\n",
" \n",
" trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()\n",
" \n",
" #trg_sub_mask = [trg len, trg len]\n",
" \n",
" trg_mask = trg_pad_mask & trg_sub_mask\n",
" \n",
" #trg_mask = [batch size, 1, trg len, trg len]\n",
" \n",
" return trg_mask\n",
"\n",
" def forward(self, src, trg):\n",
" \n",
" #src = [batch size, src len]\n",
" #trg = [batch size, trg len]\n",
" \n",
" src_mask = self.make_src_mask(src)\n",
" trg_mask = self.make_trg_mask(trg)\n",
" \n",
" #src_mask = [batch size, 1, 1, src len]\n",
" #trg_mask = [batch size, 1, trg len, trg len]\n",
" \n",
" enc_src = self.encoder(src, src_mask)\n",
" \n",
" #enc_src = [batch size, src len, hid dim]\n",
" \n",
" output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)\n",
" \n",
" #output = [batch size, trg len, output dim]\n",
" #attention = [batch size, n heads, trg len, src len]\n",
" \n",
" return output, attention"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"code_folding": [
0
],
"colab_type": "code",
"id": "ZpUuq8zxy5qY",
"colab": {}
},
"source": [
"INPUT_DIM = len(SRC.vocab)\n",
"OUTPUT_DIM = len(TRG.vocab)\n",
"HID_DIM = 256\n",
"ENC_LAYERS = 3\n",
"DEC_LAYERS = 3\n",
"ENC_HEADS = 8\n",
"DEC_HEADS = 8\n",
"ENC_PF_DIM = 512\n",
"DEC_PF_DIM = 512\n",
"ENC_DROPOUT = 0.1\n",
"DEC_DROPOUT = 0.1\n",
"\n",
"enc = Encoder(INPUT_DIM, \n",
" HID_DIM, \n",
" ENC_LAYERS, \n",
" ENC_HEADS, \n",
" ENC_PF_DIM, \n",
" ENC_DROPOUT, \n",
" device)\n",
"\n",
"dec = Decoder(OUTPUT_DIM, \n",
" HID_DIM, \n",
" DEC_LAYERS, \n",
" DEC_HEADS, \n",
" DEC_PF_DIM, \n",
" DEC_DROPOUT, \n",
" device)\n",
"\n",
"SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]\n",
"TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "ty8cwecPy5qb",
"colab": {}
},
"source": [
"model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)"
],
"execution_count": null,
"outputs": []
},
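{
"cell_type": "code",
"metadata": {},
"source": [
"# A minimal smoke test (not part of the original run): feed a tiny random batch\n",
"# through the untrained model to confirm the shapes documented in the comments\n",
"# above. All variable names here are hypothetical.\n",
"src_toks = torch.randint(0, INPUT_DIM, (2, 7)).to(device)   # [batch size, src len]\n",
"trg_toks = torch.randint(0, OUTPUT_DIM, (2, 5)).to(device)  # [batch size, trg len]\n",
"with torch.no_grad():\n",
"    out, attn = model(src_toks, trg_toks)\n",
"print(out.shape)   # [2, 5, OUTPUT_DIM] = [batch size, trg len, output dim]\n",
"print(attn.shape)  # [2, 8, 5, 7] = [batch size, n heads, trg len, src len]"
],
"execution_count": null,
"outputs": []
},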
{
"cell_type": "markdown",
"metadata": {
"id": "jLZkRVgwEqH3",
"colab_type": "text"
},
"source": [
"# Model Params"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "loGDyC-ElpZY",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "83f0b461-ade1-4aa1-90b7-e5a8456b3ff6"
},
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"The model has 12,506,137 trainable parameters\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "XcaxdCBJEqH7",
"colab_type": "code",
"colab": {},
"outputId": "dda74e10-9654-46a5-a661-efdb38a99185"
},
"source": [
"model"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Seq2Seq(\n",
" (encoder): Encoder(\n",
" (tok_embedding): Embedding(8021, 256)\n",
" (pos_embedding): Embedding(100, 256)\n",
" (layers): ModuleList(\n",
" (0): EncoderLayer(\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (self_attention): MultiHeadAttentionLayer(\n",
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n",
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): EncoderLayer(\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (self_attention): MultiHeadAttentionLayer(\n",
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n",
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): EncoderLayer(\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (self_attention): MultiHeadAttentionLayer(\n",
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n",
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (decoder): Decoder(\n",
" (tok_embedding): Embedding(12569, 256)\n",
" (pos_embedding): Embedding(100, 256)\n",
" (layers): ModuleList(\n",
" (0): DecoderLayer(\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (enc_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (self_attention): MultiHeadAttentionLayer(\n",
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder_attention): MultiHeadAttentionLayer(\n",
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n",
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): DecoderLayer(\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (enc_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (self_attention): MultiHeadAttentionLayer(\n",
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder_attention): MultiHeadAttentionLayer(\n",
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n",
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): DecoderLayer(\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (enc_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (self_attention): MultiHeadAttentionLayer(\n",
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder_attention): MultiHeadAttentionLayer(\n",
" (fc_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (fc_o): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (positionwise_feedforward): PositionwiseFeedforwardLayer(\n",
" (fc_1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc_2): Linear(in_features=512, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (fc_out): Linear(in_features=256, out_features=12569, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
")"
]
},
"metadata": {
"tags": []
},
"execution_count": 23
}
]
},
{
"cell_type": "code",
"metadata": {
"scrolled": false,
"id": "0EjtrhPpEqH_",
"colab_type": "code",
"colab": {},
"outputId": "c60e928f-4de8-4351-faef-7cc2cb981c82"
},
"source": [
"from prettytable import PrettyTable\n",
"\n",
"def count_parameters(model):\n",
" table = PrettyTable([\"Modules\", \"Parameters\"])\n",
" total_params = 0\n",
" for name, parameter in model.named_parameters():\n",
" if not parameter.requires_grad: continue\n",
" param = parameter.numel()\n",
" table.add_row([name, param])\n",
" total_params+=param\n",
" print(table)\n",
" print(f\"Total Trainable Params: {total_params}\")\n",
" return total_params\n",
" \n",
"count_parameters(model)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"+-------------------------------------------------------+------------+\n",
"| Modules | Parameters |\n",
"+-------------------------------------------------------+------------+\n",
"| encoder.tok_embedding.weight | 2053376 |\n",
"| encoder.pos_embedding.weight | 25600 |\n",
"| encoder.layers.0.self_attn_layer_norm.weight | 256 |\n",
"| encoder.layers.0.self_attn_layer_norm.bias | 256 |\n",
"| encoder.layers.0.ff_layer_norm.weight | 256 |\n",
"| encoder.layers.0.ff_layer_norm.bias | 256 |\n",
"| encoder.layers.0.self_attention.fc_q.weight | 65536 |\n",
"| encoder.layers.0.self_attention.fc_q.bias | 256 |\n",
"| encoder.layers.0.self_attention.fc_k.weight | 65536 |\n",
"| encoder.layers.0.self_attention.fc_k.bias | 256 |\n",
"| encoder.layers.0.self_attention.fc_v.weight | 65536 |\n",
"| encoder.layers.0.self_attention.fc_v.bias | 256 |\n",
"| encoder.layers.0.self_attention.fc_o.weight | 65536 |\n",
"| encoder.layers.0.self_attention.fc_o.bias | 256 |\n",
"| encoder.layers.0.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| encoder.layers.0.positionwise_feedforward.fc_1.bias | 512 |\n",
"| encoder.layers.0.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| encoder.layers.0.positionwise_feedforward.fc_2.bias | 256 |\n",
"| encoder.layers.1.self_attn_layer_norm.weight | 256 |\n",
"| encoder.layers.1.self_attn_layer_norm.bias | 256 |\n",
"| encoder.layers.1.ff_layer_norm.weight | 256 |\n",
"| encoder.layers.1.ff_layer_norm.bias | 256 |\n",
"| encoder.layers.1.self_attention.fc_q.weight | 65536 |\n",
"| encoder.layers.1.self_attention.fc_q.bias | 256 |\n",
"| encoder.layers.1.self_attention.fc_k.weight | 65536 |\n",
"| encoder.layers.1.self_attention.fc_k.bias | 256 |\n",
"| encoder.layers.1.self_attention.fc_v.weight | 65536 |\n",
"| encoder.layers.1.self_attention.fc_v.bias | 256 |\n",
"| encoder.layers.1.self_attention.fc_o.weight | 65536 |\n",
"| encoder.layers.1.self_attention.fc_o.bias | 256 |\n",
"| encoder.layers.1.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| encoder.layers.1.positionwise_feedforward.fc_1.bias | 512 |\n",
"| encoder.layers.1.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| encoder.layers.1.positionwise_feedforward.fc_2.bias | 256 |\n",
"| encoder.layers.2.self_attn_layer_norm.weight | 256 |\n",
"| encoder.layers.2.self_attn_layer_norm.bias | 256 |\n",
"| encoder.layers.2.ff_layer_norm.weight | 256 |\n",
"| encoder.layers.2.ff_layer_norm.bias | 256 |\n",
"| encoder.layers.2.self_attention.fc_q.weight | 65536 |\n",
"| encoder.layers.2.self_attention.fc_q.bias | 256 |\n",
"| encoder.layers.2.self_attention.fc_k.weight | 65536 |\n",
"| encoder.layers.2.self_attention.fc_k.bias | 256 |\n",
"| encoder.layers.2.self_attention.fc_v.weight | 65536 |\n",
"| encoder.layers.2.self_attention.fc_v.bias | 256 |\n",
"| encoder.layers.2.self_attention.fc_o.weight | 65536 |\n",
"| encoder.layers.2.self_attention.fc_o.bias | 256 |\n",
"| encoder.layers.2.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| encoder.layers.2.positionwise_feedforward.fc_1.bias | 512 |\n",
"| encoder.layers.2.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| encoder.layers.2.positionwise_feedforward.fc_2.bias | 256 |\n",
"| decoder.tok_embedding.weight | 3217664 |\n",
"| decoder.pos_embedding.weight | 25600 |\n",
"| decoder.layers.0.self_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.0.self_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.0.enc_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.0.enc_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.0.ff_layer_norm.weight | 256 |\n",
"| decoder.layers.0.ff_layer_norm.bias | 256 |\n",
"| decoder.layers.0.self_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.0.self_attention.fc_q.bias | 256 |\n",
"| decoder.layers.0.self_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.0.self_attention.fc_k.bias | 256 |\n",
"| decoder.layers.0.self_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.0.self_attention.fc_v.bias | 256 |\n",
"| decoder.layers.0.self_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.0.self_attention.fc_o.bias | 256 |\n",
"| decoder.layers.0.encoder_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.0.encoder_attention.fc_q.bias | 256 |\n",
"| decoder.layers.0.encoder_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.0.encoder_attention.fc_k.bias | 256 |\n",
"| decoder.layers.0.encoder_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.0.encoder_attention.fc_v.bias | 256 |\n",
"| decoder.layers.0.encoder_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.0.encoder_attention.fc_o.bias | 256 |\n",
"| decoder.layers.0.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| decoder.layers.0.positionwise_feedforward.fc_1.bias | 512 |\n",
"| decoder.layers.0.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| decoder.layers.0.positionwise_feedforward.fc_2.bias | 256 |\n",
"| decoder.layers.1.self_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.1.self_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.1.enc_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.1.enc_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.1.ff_layer_norm.weight | 256 |\n",
"| decoder.layers.1.ff_layer_norm.bias | 256 |\n",
"| decoder.layers.1.self_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.1.self_attention.fc_q.bias | 256 |\n",
"| decoder.layers.1.self_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.1.self_attention.fc_k.bias | 256 |\n",
"| decoder.layers.1.self_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.1.self_attention.fc_v.bias | 256 |\n",
"| decoder.layers.1.self_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.1.self_attention.fc_o.bias | 256 |\n",
"| decoder.layers.1.encoder_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.1.encoder_attention.fc_q.bias | 256 |\n",
"| decoder.layers.1.encoder_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.1.encoder_attention.fc_k.bias | 256 |\n",
"| decoder.layers.1.encoder_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.1.encoder_attention.fc_v.bias | 256 |\n",
"| decoder.layers.1.encoder_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.1.encoder_attention.fc_o.bias | 256 |\n",
"| decoder.layers.1.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| decoder.layers.1.positionwise_feedforward.fc_1.bias | 512 |\n",
"| decoder.layers.1.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| decoder.layers.1.positionwise_feedforward.fc_2.bias | 256 |\n",
"| decoder.layers.2.self_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.2.self_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.2.enc_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.2.enc_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.2.ff_layer_norm.weight | 256 |\n",
"| decoder.layers.2.ff_layer_norm.bias | 256 |\n",
"| decoder.layers.2.self_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.2.self_attention.fc_q.bias | 256 |\n",
"| decoder.layers.2.self_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.2.self_attention.fc_k.bias | 256 |\n",
"| decoder.layers.2.self_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.2.self_attention.fc_v.bias | 256 |\n",
"| decoder.layers.2.self_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.2.self_attention.fc_o.bias | 256 |\n",
"| decoder.layers.2.encoder_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.2.encoder_attention.fc_q.bias | 256 |\n",
"| decoder.layers.2.encoder_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.2.encoder_attention.fc_k.bias | 256 |\n",
"| decoder.layers.2.encoder_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.2.encoder_attention.fc_v.bias | 256 |\n",
"| decoder.layers.2.encoder_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.2.encoder_attention.fc_o.bias | 256 |\n",
"| decoder.layers.2.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| decoder.layers.2.positionwise_feedforward.fc_1.bias | 512 |\n",
"| decoder.layers.2.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| decoder.layers.2.positionwise_feedforward.fc_2.bias | 256 |\n",
"| decoder.fc_out.weight | 3217664 |\n",
"| decoder.fc_out.bias | 12569 |\n",
"+-------------------------------------------------------+------------+\n",
"Total Trainable Params: 12506137\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"12506137"
]
},
"metadata": {
"tags": []
},
"execution_count": 24
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e7rHYuTKEqIC",
"colab_type": "text"
},
"source": [
"# Method 1 (Load Weights)\n",
"\n",
"The above Total Trainable Params are 12506137 but in google colab (runtime=None, device='cpu') Total Trainable Params: 12490234\n",
"\n",
"So unable to load model weights"
]
},
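{
"cell_type": "code",
"metadata": {},
"source": [
"# A minimal diagnostic sketch (assumption: the checkpoint was saved with\n",
"# torch.save(model.state_dict(), ...)). Compare every tensor shape in the\n",
"# checkpoint against the current model to see exactly which layers mismatch.\n",
"state_dict = torch.load(f\"{model_name}_2.pt\", map_location=device)\n",
"current = dict(model.named_parameters())\n",
"for name, tensor in state_dict.items():\n",
"    if name in current and current[name].shape != tensor.shape:\n",
"        print(name, tuple(tensor.shape), \"!=\", tuple(current[name].shape))"
],
"execution_count": null,
"outputs": []
},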
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "07fMQajdlpUq",
"scrolled": false,
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "ba892bac-1dc1-4fe2-955a-934ac5056385"
},
"source": [
"model.load_state_dict(torch.load(f\"{model_name}_2.pt\", map_location=device), strict=False)"
],
"execution_count": null,
"outputs": [
{
"output_type": "error",
"ename": "RuntimeError",
"evalue": "Error(s) in loading state_dict for Seq2Seq:\n\tsize mismatch for decoder.tok_embedding.weight: copying a param with shape torch.Size([12538, 256]) from checkpoint, the shape in current model is torch.Size([12569, 256]).\n\tsize mismatch for decoder.fc_out.weight: copying a param with shape torch.Size([12538, 256]) from checkpoint, the shape in current model is torch.Size([12569, 256]).\n\tsize mismatch for decoder.fc_out.bias: copying a param with shape torch.Size([12538]) from checkpoint, the shape in current model is torch.Size([12569]).",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-25-977503f88f12>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\"{model_name}_2.pt\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstrict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;32mc:\\anaconda\\envs\\lang_trans\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[1;34m(self, state_dict, strict)\u001b[0m\n\u001b[0;32m 845\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 846\u001b[0m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[1;32m--> 847\u001b[1;33m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[0;32m 848\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 849\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for Seq2Seq:\n\tsize mismatch for decoder.tok_embedding.weight: copying a param with shape torch.Size([12538, 256]) from checkpoint, the shape in current model is torch.Size([12569, 256]).\n\tsize mismatch for decoder.fc_out.weight: copying a param with shape torch.Size([12538, 256]) from checkpoint, the shape in current model is torch.Size([12569, 256]).\n\tsize mismatch for decoder.fc_out.bias: copying a param with shape torch.Size([12538]) from checkpoint, the shape in current model is torch.Size([12569])."
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AkNGWg2QEqIG",
"colab_type": "text"
},
"source": [
"# Method 2 (Load full model)\n",
"\n",
"It is loading here but giving error in next cell, see the params changes here again"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "rJvLhc3Zy5qd",
"colab": {},
"outputId": "9406c490-5347-4600-bd19-ce7d323e869e"
},
"source": [
"model = torch.load(f\"{model_name}.pth\", map_location=torch.device('cpu'))\n",
"\n",
"count_parameters(model)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"+-------------------------------------------------------+------------+\n",
"| Modules | Parameters |\n",
"+-------------------------------------------------------+------------+\n",
"| encoder.tok_embedding.weight | 1552128 |\n",
"| encoder.pos_embedding.weight | 25600 |\n",
"| encoder.layers.0.self_attn_layer_norm.weight | 256 |\n",
"| encoder.layers.0.self_attn_layer_norm.bias | 256 |\n",
"| encoder.layers.0.ff_layer_norm.weight | 256 |\n",
"| encoder.layers.0.ff_layer_norm.bias | 256 |\n",
"| encoder.layers.0.self_attention.fc_q.weight | 65536 |\n",
"| encoder.layers.0.self_attention.fc_q.bias | 256 |\n",
"| encoder.layers.0.self_attention.fc_k.weight | 65536 |\n",
"| encoder.layers.0.self_attention.fc_k.bias | 256 |\n",
"| encoder.layers.0.self_attention.fc_v.weight | 65536 |\n",
"| encoder.layers.0.self_attention.fc_v.bias | 256 |\n",
"| encoder.layers.0.self_attention.fc_o.weight | 65536 |\n",
"| encoder.layers.0.self_attention.fc_o.bias | 256 |\n",
"| encoder.layers.0.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| encoder.layers.0.positionwise_feedforward.fc_1.bias | 512 |\n",
"| encoder.layers.0.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| encoder.layers.0.positionwise_feedforward.fc_2.bias | 256 |\n",
"| encoder.layers.1.self_attn_layer_norm.weight | 256 |\n",
"| encoder.layers.1.self_attn_layer_norm.bias | 256 |\n",
"| encoder.layers.1.ff_layer_norm.weight | 256 |\n",
"| encoder.layers.1.ff_layer_norm.bias | 256 |\n",
"| encoder.layers.1.self_attention.fc_q.weight | 65536 |\n",
"| encoder.layers.1.self_attention.fc_q.bias | 256 |\n",
"| encoder.layers.1.self_attention.fc_k.weight | 65536 |\n",
"| encoder.layers.1.self_attention.fc_k.bias | 256 |\n",
"| encoder.layers.1.self_attention.fc_v.weight | 65536 |\n",
"| encoder.layers.1.self_attention.fc_v.bias | 256 |\n",
"| encoder.layers.1.self_attention.fc_o.weight | 65536 |\n",
"| encoder.layers.1.self_attention.fc_o.bias | 256 |\n",
"| encoder.layers.1.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| encoder.layers.1.positionwise_feedforward.fc_1.bias | 512 |\n",
"| encoder.layers.1.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| encoder.layers.1.positionwise_feedforward.fc_2.bias | 256 |\n",
"| encoder.layers.2.self_attn_layer_norm.weight | 256 |\n",
"| encoder.layers.2.self_attn_layer_norm.bias | 256 |\n",
"| encoder.layers.2.ff_layer_norm.weight | 256 |\n",
"| encoder.layers.2.ff_layer_norm.bias | 256 |\n",
"| encoder.layers.2.self_attention.fc_q.weight | 65536 |\n",
"| encoder.layers.2.self_attention.fc_q.bias | 256 |\n",
"| encoder.layers.2.self_attention.fc_k.weight | 65536 |\n",
"| encoder.layers.2.self_attention.fc_k.bias | 256 |\n",
"| encoder.layers.2.self_attention.fc_v.weight | 65536 |\n",
"| encoder.layers.2.self_attention.fc_v.bias | 256 |\n",
"| encoder.layers.2.self_attention.fc_o.weight | 65536 |\n",
"| encoder.layers.2.self_attention.fc_o.bias | 256 |\n",
"| encoder.layers.2.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| encoder.layers.2.positionwise_feedforward.fc_1.bias | 512 |\n",
"| encoder.layers.2.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| encoder.layers.2.positionwise_feedforward.fc_2.bias | 256 |\n",
"| decoder.tok_embedding.weight | 2380288 |\n",
"| decoder.pos_embedding.weight | 25600 |\n",
"| decoder.layers.0.self_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.0.self_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.0.enc_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.0.enc_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.0.ff_layer_norm.weight | 256 |\n",
"| decoder.layers.0.ff_layer_norm.bias | 256 |\n",
"| decoder.layers.0.self_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.0.self_attention.fc_q.bias | 256 |\n",
"| decoder.layers.0.self_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.0.self_attention.fc_k.bias | 256 |\n",
"| decoder.layers.0.self_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.0.self_attention.fc_v.bias | 256 |\n",
"| decoder.layers.0.self_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.0.self_attention.fc_o.bias | 256 |\n",
"| decoder.layers.0.encoder_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.0.encoder_attention.fc_q.bias | 256 |\n",
"| decoder.layers.0.encoder_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.0.encoder_attention.fc_k.bias | 256 |\n",
"| decoder.layers.0.encoder_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.0.encoder_attention.fc_v.bias | 256 |\n",
"| decoder.layers.0.encoder_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.0.encoder_attention.fc_o.bias | 256 |\n",
"| decoder.layers.0.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| decoder.layers.0.positionwise_feedforward.fc_1.bias | 512 |\n",
"| decoder.layers.0.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| decoder.layers.0.positionwise_feedforward.fc_2.bias | 256 |\n",
"| decoder.layers.1.self_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.1.self_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.1.enc_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.1.enc_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.1.ff_layer_norm.weight | 256 |\n",
"| decoder.layers.1.ff_layer_norm.bias | 256 |\n",
"| decoder.layers.1.self_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.1.self_attention.fc_q.bias | 256 |\n",
"| decoder.layers.1.self_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.1.self_attention.fc_k.bias | 256 |\n",
"| decoder.layers.1.self_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.1.self_attention.fc_v.bias | 256 |\n",
"| decoder.layers.1.self_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.1.self_attention.fc_o.bias | 256 |\n",
"| decoder.layers.1.encoder_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.1.encoder_attention.fc_q.bias | 256 |\n",
"| decoder.layers.1.encoder_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.1.encoder_attention.fc_k.bias | 256 |\n",
"| decoder.layers.1.encoder_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.1.encoder_attention.fc_v.bias | 256 |\n",
"| decoder.layers.1.encoder_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.1.encoder_attention.fc_o.bias | 256 |\n",
"| decoder.layers.1.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| decoder.layers.1.positionwise_feedforward.fc_1.bias | 512 |\n",
"| decoder.layers.1.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| decoder.layers.1.positionwise_feedforward.fc_2.bias | 256 |\n",
"| decoder.layers.2.self_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.2.self_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.2.enc_attn_layer_norm.weight | 256 |\n",
"| decoder.layers.2.enc_attn_layer_norm.bias | 256 |\n",
"| decoder.layers.2.ff_layer_norm.weight | 256 |\n",
"| decoder.layers.2.ff_layer_norm.bias | 256 |\n",
"| decoder.layers.2.self_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.2.self_attention.fc_q.bias | 256 |\n",
"| decoder.layers.2.self_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.2.self_attention.fc_k.bias | 256 |\n",
"| decoder.layers.2.self_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.2.self_attention.fc_v.bias | 256 |\n",
"| decoder.layers.2.self_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.2.self_attention.fc_o.bias | 256 |\n",
"| decoder.layers.2.encoder_attention.fc_q.weight | 65536 |\n",
"| decoder.layers.2.encoder_attention.fc_q.bias | 256 |\n",
"| decoder.layers.2.encoder_attention.fc_k.weight | 65536 |\n",
"| decoder.layers.2.encoder_attention.fc_k.bias | 256 |\n",
"| decoder.layers.2.encoder_attention.fc_v.weight | 65536 |\n",
"| decoder.layers.2.encoder_attention.fc_v.bias | 256 |\n",
"| decoder.layers.2.encoder_attention.fc_o.weight | 65536 |\n",
"| decoder.layers.2.encoder_attention.fc_o.bias | 256 |\n",
"| decoder.layers.2.positionwise_feedforward.fc_1.weight | 131072 |\n",
"| decoder.layers.2.positionwise_feedforward.fc_1.bias | 512 |\n",
"| decoder.layers.2.positionwise_feedforward.fc_2.weight | 131072 |\n",
"| decoder.layers.2.positionwise_feedforward.fc_2.bias | 256 |\n",
"| decoder.fc_out.weight | 2380288 |\n",
"| decoder.fc_out.bias | 9298 |\n",
"+-------------------------------------------------------+------------+\n",
"Total Trainable Params: 10326866\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"10326866"
]
},
"metadata": {
"tags": []
},
"execution_count": 29
}
]
},
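{
"cell_type": "code",
"metadata": {},
"source": [
"# A minimal sketch (assumption: the AssertionError in the prediction cell below\n",
"# comes from device objects pickled inside the full model). map_location remaps\n",
"# the tensors to CPU, but plain attributes such as model.device still say 'cuda',\n",
"# so forward() tries to move tensors onto a CUDA device. Re-pointing them at the\n",
"# current device may avoid the error.\n",
"model.device = device\n",
"model.encoder.device = device\n",
"model.decoder.device = device"
],
"execution_count": null,
"outputs": []
},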
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QskLHNo5y5qh"
},
"source": [
"# Prediction"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "A_7Lj1iwy5qh",
"colab": {}
},
"source": [
"def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):\n",
" \n",
" model.eval()\n",
" \n",
" if isinstance(sentence, str):\n",
" nlp = spacy.load(lang2_name[1:])\n",
" tokens = [token.text.lower() for token in nlp(sentence)]\n",
" else:\n",
" tokens = [token.lower() for token in sentence]\n",
"\n",
" tokens = [src_field.init_token] + tokens + [src_field.eos_token]\n",
"\n",
" print(tokens)\n",
" \n",
" src_indexes = [src_field.vocab.stoi[token] for token in tokens]\n",
"\n",
" src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)\n",
" \n",
" src_mask = model.make_src_mask(src_tensor)\n",
" \n",
" with torch.no_grad():\n",
" enc_src = model.encoder(src_tensor, src_mask)\n",
"\n",
" trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]\n",
"\n",
" for i in range(max_len):\n",
"\n",
" trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)\n",
"\n",
" trg_mask = model.make_trg_mask(trg_tensor)\n",
" \n",
" with torch.no_grad():\n",
" output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)\n",
" \n",
" pred_token = output.argmax(2)[:,-1].item()\n",
" \n",
" trg_indexes.append(pred_token)\n",
"\n",
" if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:\n",
" break\n",
" \n",
" trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]\n",
" \n",
" return trg_tokens[1:], attention"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "MZDVMsewy5qj",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"outputId": "e668cc00-850a-4461-ca88-4f3d04aa5ef3"
},
"source": [
"src = \"AlwaysOn is an app for every device with an AMOLED or OLED display\"\n",
"\n",
"src = src.split(\" \")\n",
"\n",
"translation, attention = translate_sentence(src, SRC, TRG, model, device)\n",
"\n",
"print(f'predicted trg = {translation}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"['<sos>', 'alwayson', 'is', 'an', 'app', 'for', 'every', 'device', 'with', 'an', 'amoled', 'or', 'oled', 'display', '<eos>']\n"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "AssertionError",
"evalue": "Torch not compiled with CUDA enabled",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mAssertionError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-31-ff8428ee5832>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0msrc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msrc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\" \"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[0mtranslation\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mattention\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtranslate_sentence\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mSRC\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mTRG\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 6\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf'predicted trg = {translation}'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m<ipython-input-30-5d12c49b1bf1>\u001b[0m in \u001b[0;36mtranslate_sentence\u001b[1;34m(sentence, src_field, trg_field, model, device, max_len)\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 22\u001b[1;33m \u001b[0menc_src\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msrc_tensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msrc_mask\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 23\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 24\u001b[0m \u001b[0mtrg_indexes\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mtrg_field\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvocab\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstoi\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mtrg_field\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_token\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\anaconda\\envs\\lang_trans\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 548\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 549\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 550\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 551\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 552\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m<ipython-input-19-5ecaa062b686>\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, src, src_mask)\u001b[0m\n\u001b[0;32m 37\u001b[0m \u001b[0mdevice\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'cuda'\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_available\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;34m'cpu'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 38\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 39\u001b[1;33m \u001b[0mpos\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msrc_len\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrepeat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 40\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 41\u001b[0m \u001b[1;31m#pos = [batch size, src len]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\anaconda\\envs\\lang_trans\\lib\\site-packages\\torch\\cuda\\__init__.py\u001b[0m in \u001b[0;36m_lazy_init\u001b[1;34m()\u001b[0m\n\u001b[0;32m 147\u001b[0m raise RuntimeError(\n\u001b[0;32m 148\u001b[0m \"Cannot re-initialize CUDA in forked subprocess. \" + msg)\n\u001b[1;32m--> 149\u001b[1;33m \u001b[0m_check_driver\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 150\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m_cudart\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 151\u001b[0m raise AssertionError(\n",
"\u001b[1;32mc:\\anaconda\\envs\\lang_trans\\lib\\site-packages\\torch\\cuda\\__init__.py\u001b[0m in \u001b[0;36m_check_driver\u001b[1;34m()\u001b[0m\n\u001b[0;32m 45\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_check_driver\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'_cuda_isDriverSufficient'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 47\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Torch not compiled with CUDA enabled\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 48\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cuda_isDriverSufficient\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 49\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cuda_getDriverVersion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mAssertionError\u001b[0m: Torch not compiled with CUDA enabled"
]
}
]
},
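{
"cell_type": "markdown",
"metadata": {},
"source": [
"The traceback above comes from running on a CPU-only PyTorch build while the model still carries a CUDA device: `.to(self.device)` inside the encoder tries to initialise CUDA and fails. A minimal sketch of a CPU-safe reload, assuming the weights were saved to `'model.pt'` (the path is an assumption; since the encoder/decoder store `device` at construction, they may also need to be re-instantiated with the CPU device):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Pick whichever device is actually available in this environment.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"# map_location remaps CUDA tensors in the checkpoint onto the CPU if needed.\n",
"# 'model.pt' is an assumed path - point it at the saved checkpoint.\n",
"model.load_state_dict(torch.load('model.pt', map_location=device))\n",
"model = model.to(device)"
],
"execution_count": null,
"outputs": []
},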
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "CtlSMBWi6FDS"
},
"source": [
"# Visualization"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "Tn5Eo_-98ccc",
"colab": {}
},
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "sZDvG431-6cQ",
"colab": {}
},
"source": [
"li = ['binary', 'cool']"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "tdiewTRz6i68",
"colab": {}
},
"source": [
"def display_attention(sentence, translation, attention, n_heads = 8, n_rows = 4, n_cols = 2, clr=False):\n",
" \n",
" if clr == True:\n",
" for clr_code in li:\n",
" assert n_rows * n_cols == n_heads\n",
" \n",
" fig = plt.figure(figsize=(15,25))\n",
" \n",
" for i in range(n_heads):\n",
" \n",
" ax = fig.add_subplot(n_rows, n_cols, i+1)\n",
" \n",
" _attention = attention.squeeze(0)[i].cpu().detach().numpy()\n",
"\n",
"\n",
" cax = ax.matshow(_attention, cmap=clr_code)\n",
"\n",
" ax.tick_params(labelsize=12)\n",
" ax.set_xticklabels(['']+['<sos>']+[t.lower() for t in sentence]+['<eos>'], \n",
" rotation=45)\n",
" ax.set_yticklabels(['']+translation)\n",
"\n",
" ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
" ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
"\n",
" plt.show()\n",
" plt.close()\n",
" else:\n",
" assert n_rows * n_cols == n_heads\n",
" \n",
" fig = plt.figure(figsize=(10,20))\n",
" \n",
" for i in range(n_heads):\n",
" \n",
" ax = fig.add_subplot(n_rows, n_cols, i+1)\n",
" \n",
" _attention = attention.squeeze(0)[i].cpu().detach().numpy()\n",
"\n",
"\n",
" cax = ax.matshow(_attention, cmap='gist_heat')\n",
"\n",
" ax.tick_params(labelsize=12)\n",
" ax.set_xticklabels(['']+['<sos>']+[t.lower() for t in sentence]+['<eos>'], \n",
" rotation=45)\n",
" ax.set_yticklabels(['']+translation)\n",
"\n",
" ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
" ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
"\n",
" plt.show()\n",
" plt.close()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "0WzIjPxr7E3v",
"colab": {}
},
"source": [
"display_attention(src, translation, attention, clr=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "sep9LyC18XYF",
"colab": {}
},
"source": [
"example_idx = 6\n",
"\n",
"src = vars(valid_data.examples[example_idx])['src']\n",
"trg = vars(valid_data.examples[example_idx])['trg']\n",
"\n",
"print(f'src = {src}')\n",
"print(f'trg = {trg}')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "THcyVb-f8yfS",
"colab": {}
},
"source": [
"translation, attention = translate_sentence(src, SRC, TRG, model, device)\n",
"\n",
"print(f'predicted trg = {translation}')"
],
"execution_count": null,
"outputs": []
},
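{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, rough way to compare the prediction against the reference `trg` printed above (a sanity check only, not a proper metric such as BLEU):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Rough sanity check (not a proper translation metric): the fraction of\n",
"# reference tokens that also appear somewhere in the prediction.\n",
"pred_tokens = set(translation)\n",
"overlap = sum(t in pred_tokens for t in trg) / max(len(trg), 1)\n",
"print(f'token overlap with reference = {overlap:.2f}')"
],
"execution_count": null,
"outputs": []
},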
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "4msONCaR81on",
"colab": {}
},
"source": [
"display_attention(src, translation, attention, clr=True)"
],
"execution_count": null,
"outputs": []
}
]
}