This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torchtext | |
| from torchtext.datasets import TranslationDataset, Multi30k, IWSLT, WMT14 | |
| from torchtext.data import Field, BucketIterator | |
| import matplotlib.pyplot as plt | |
| import matplotlib.ticker as ticker |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random

import numpy as np

# Fixed seed shared by Python's `random`, NumPy and PyTorch so that code
# drawing from those RNGs behaves the same on every run. `random` and
# `numpy` were previously used here without being imported (NameError).
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# Ask cuDNN to select deterministic algorithms (can be slower).
torch.backends.cudnn.deterministic = True
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import importlib

import spacy

# spaCy v3 removed the 'de'/'en' shortcut links that `spacy.load('de')` /
# `spacy.load('en')` relied on; these full package names are accepted by
# both spaCy v2 and v3.
SPACY_DE_MODEL = 'de_core_news_sm'
SPACY_EN_MODEL = 'en_core_web_sm'


def _load_spacy_model(name):
    """Return the spaCy pipeline *name*, downloading it first if it is missing.

    Replaces the notebook-only ``!python -m spacy download`` shell escapes so
    the code also runs under plain Python, and downloads only when needed.
    """
    try:
        return spacy.load(name)
    except OSError:
        # spacy.load raises OSError when the pipeline package is not installed.
        from spacy.cli import download

        download(name)
        # The package was just installed by this process; refresh the import
        # system's directory caches so it can be found without a restart.
        importlib.invalidate_caches()
        return spacy.load(name)


spacy_de = _load_spacy_model(SPACY_DE_MODEL)
spacy_en = _load_spacy_model(SPACY_EN_MODEL)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def tokenize_de(text):
    """Split German *text* into a list of token strings.

    Only the tokenizer component of ``spacy_de`` is run (not the full
    pipeline); each token's ``.text`` is collected in order.
    """
    doc = spacy_de.tokenizer(text)
    return [token.text for token in doc]
| def tokenize_en(text): | |
| """ | |
| Tokenizes English text from a string into a list of strings | |
| """ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| SRC = Field(tokenize = tokenize_en, | |
| init_token = '<sos>', | |
| eos_token = '<eos>', | |
| lower = True, | |
| batch_first = True) | |
| TRG = Field(tokenize = tokenize_de, | |
| init_token = '<sos>', | |
| eos_token = '<eos>', | |
| lower = True, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the IWSLT splits; the fields pair positionally with the extensions,
# so SRC handles the '.en' files and TRG the '.de' files.
train_data, valid_data, test_data = IWSLT.splits(
    exts=('.en', '.de'),
    fields=(SRC, TRG),
)

# Vocabularies come from the training split only; tokens seen fewer than
# twice are excluded.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

BATCH_SIZE = 128

# BucketIterator batches examples of similar length together; one iterator
# is returned per split, all placed on `device`.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class EncoderLayer(nn.Module): | |
| def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): | |
| super().__init__() | |
| self.self_attn_layer_norm = nn.LayerNorm(hid_dim) | |
| self.ff_layer_norm = nn.LayerNorm(hid_dim) | |
| self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) | |
| self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) | |
| self.dropout = nn.Dropout(dropout) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class PositionwiseFeedforwardLayer(nn.Module): | |
| def __init__(self, hid_dim, pf_dim, dropout): | |
| super().__init__() | |
| self.fc_1 = nn.Linear(hid_dim, pf_dim) | |
| self.fc_2 = nn.Linear(pf_dim, hid_dim) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, x): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class MultiHeadAttentionLayer(nn.Module): | |
| def __init__(self, hid_dim, n_heads, dropout, device): | |
| super().__init__() | |
| assert hid_dim % n_heads == 0 | |
| self.hid_dim = hid_dim | |
| self.n_heads = n_heads | |
| self.head_dim = hid_dim // n_heads | |
Older / Newer