# Machine Translation Transformer — code snippets from GitHub gists by @ammesatyajit.

# --- imports.py (last active February 18, 2021 23:53) ---
import random

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
from torchtext.data import Field, BucketIterator
from torchtext.datasets import TranslationDataset, Multi30k, IWSLT, WMT14
# --- random_seed.py (gist by @ammesatyajit, created February 18, 2021 23:51)
#     Machine Translation Transformer ---
# Seed every RNG this pipeline draws from (Python, NumPy, PyTorch CPU and
# CUDA) so that repeated runs start from the same random state.
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # silently ignored when CUDA is unavailable

# cuDNN: request deterministic kernels, and keep the autotuner off -- with
# benchmark enabled cuDNN may pick a different algorithm on each run, which
# defeats the deterministic flag.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# --- spacy_load.py (gist by @ammesatyajit, created February 18, 2021 23:52)
#     Machine Translation Transformer ---
# Notebook cell: download the spaCy models, then load one pipeline per language.
# NOTE(review): the two "!" lines are IPython/Jupyter shell escapes, not plain
# Python; outside a notebook run these commands in a shell instead.
!python -m spacy download en
!python -m spacy download de
# NOTE(review): `spacy` is not imported in the visible import block -- an
# `import spacy` is needed before this cell. The 'en'/'de' shortcut names only
# resolve on spaCy 2.x; spaCy 3 removed shortcut links, so there a package
# name such as 'de_core_news_sm' / 'en_core_web_sm' must be loaded instead --
# confirm the installed spaCy version.
spacy_de = spacy.load('de')  # German pipeline; its tokenizer backs tokenize_de
spacy_en = spacy.load('en')  # English pipeline; intended for tokenize_en
# --- tokenize_methods.py (gist by @ammesatyajit, created February 18, 2021 23:54)
#     Machine Translation Transformer ---
def tokenize_de(text):
    """Split a German string into a list of token strings.

    Only the tokenizer of the module-level ``spacy_de`` pipeline is invoked
    (not the full ``spacy_de(...)`` pipeline).
    """
    doc = spacy_de.tokenizer(text)
    return [token.text for token in doc]
def tokenize_en(text):
    """Split an English string into a list of token strings.

    Only the tokenizer of the module-level ``spacy_en`` pipeline is invoked
    (not the full ``spacy_en(...)`` pipeline), matching ``tokenize_de``.
    """
    return [tok.text for tok in spacy_en.tokenizer(text)]
# --- fields.py (gist by @ammesatyajit, created February 18, 2021 23:55)
#     Machine Translation Transformer ---
# Source-side field: text is split with tokenize_en, wrapped in <sos>/<eos>
# markers, lower-cased, and batched with the batch dimension first.
SRC = Field(
    tokenize=tokenize_en,
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
    batch_first=True,
)
# Target-side field: text is split with tokenize_de, wrapped in <sos>/<eos>
# markers and lower-cased. batch_first=True keeps target batches in the same
# (batch-first) layout as SRC, since both are iterated together.
TRG = Field(
    tokenize=tokenize_de,
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
    batch_first=True,
)
# --- load_data.py (gist by @ammesatyajit, last active February 19, 2021 00:00)
#     Machine Translation Transformer ---
# Load the IWSLT splits with '.en' files as the source (SRC) side and '.de'
# files as the target (TRG) side.
train_data, valid_data, test_data = IWSLT.splits(
    exts=('.en', '.de'),
    fields=(SRC, TRG),
)

# Build both vocabularies from the training split only; tokens seen fewer
# than 2 times there are left out of the vocabulary.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)
# --- iterator_init.py (gist by @ammesatyajit, created February 18, 2021 23:59)
#     Machine Translation Transformer ---
# Run on the GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

BATCH_SIZE = 128

# BucketIterator groups examples of similar length into the same batch to
# limit padding; batches from all three iterators are placed on `device`.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
)
# --- encoder_layer.py (gist by @ammesatyajit, created February 19, 2021 00:28)
#     Machine Translation Transformer ---
class EncoderLayer(nn.Module):
    """One Transformer encoder layer built from a multi-head self-attention
    sub-layer, a position-wise feed-forward sub-layer, two LayerNorms (named
    for those two sub-layers) and a dropout module.

    NOTE(review): only ``__init__`` is visible here; no ``forward`` is shown,
    so how the sub-layers, residual connections and dropout are wired together
    cannot be confirmed from this view.
    """
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        """Construct the layer's sub-modules.

        Args:
            hid_dim: model width; size of both LayerNorms, and passed to the
                attention and feed-forward sub-layers.
            n_heads: number of attention heads, forwarded to
                MultiHeadAttentionLayer.
            pf_dim: inner width of the position-wise feed-forward sub-layer.
            dropout: dropout probability, used by this layer's own Dropout and
                forwarded to both sub-layers.
            device: forwarded to MultiHeadAttentionLayer (presumably a
                torch.device -- confirm against that class).
        """
        super().__init__()
        # Keep this construction order stable: it fixes parameter order
        # (state_dict keys) and the sequence of RNG draws used to initialise
        # the sub-layers' weights under a fixed seed.
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)
# --- positionwise_ff_layer.py (gist by @ammesatyajit, last active February 19, 2021 00:31)
#     Machine Translation Transformer ---
class PositionwiseFeedforwardLayer(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

    The last dimension is expanded from ``hid_dim`` to ``pf_dim`` and
    projected back to ``hid_dim``; every leading position is transformed
    independently, so the output has the same shape as the input.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        """Construct the two projections and the dropout module.

        Args:
            hid_dim: size of the last dimension of the input and output.
            pf_dim: inner (expanded) width of the hidden activation.
            dropout: dropout probability applied after the ReLU.
        """
        super().__init__()
        # Keep this construction order: it fixes parameter order and the RNG
        # draws used to initialise fc_1 / fc_2 under a fixed seed.
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward network along the last dimension of ``x``.

        Args:
            x: tensor whose last dimension is ``hid_dim``; any leading
                dimensions (e.g. batch and sequence) pass through unchanged.

        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = self.dropout(torch.relu(self.fc_1(x)))
        return self.fc_2(hidden)
# --- multihead_attn.py (gist by @ammesatyajit, last active February 19, 2021 07:37)
#     Machine Translation Transformer ---
class MultiHeadAttentionLayer(nn.Module):
def __init__(self, hid_dim, n_heads, dropout, device):
super().__init__()
assert hid_dim % n_heads == 0
self.hid_dim = hid_dim
self.n_heads = n_heads
self.head_dim = hid_dim // n_heads