import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self,
                 input_dim,
                 emb_dim,
                 hid_dim,
                 n_layers,
                 kernel_size,
                 dropout,
                 device,
                 max_length = 100):
        super().__init__()

        assert kernel_size % 2 == 1, "Kernel size must be odd!"

        self.device = device

        ### Residual sums are scaled by sqrt(0.5) to keep their variance from growing
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)

        ### Token embedding layer
        self.tok_embedding = nn.Embedding(input_dim, emb_dim)
        ### Positional embedding layer
        self.pos_embedding = nn.Embedding(max_length, emb_dim)

        ### Applied before the conv blocks
        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        ### Applied after the conv blocks
        self.hid2emb = nn.Linear(hid_dim, emb_dim)

        ### Conv blocks
        self.convs = nn.ModuleList([nn.Conv1d(in_channels = hid_dim,
                                              out_channels = 2 * hid_dim,        ### GLU halves the channel dimension
                                              kernel_size = kernel_size,
                                              padding = (kernel_size - 1) // 2)  ### keeps the output length equal to src len
                                    for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
    def forward(self, src):
        #src = [batch size, src len]

        batch_size = src.shape[0]
        src_len = src.shape[1]

        #create position tensor
        pos = (torch.arange(0, src_len)
                    .unsqueeze(0)
                    .repeat(batch_size, 1)
                    .to(self.device))
        #pos = [[0, 1, 2, 3, ..., src len - 1], ...]
        #pos = [batch size, src len]

        #embed tokens and positions
        tok_embedded = self.tok_embedding(src)
        pos_embedded = self.pos_embedding(pos)
        #tok_embedded = pos_embedded = [batch size, src len, emb dim]

        #combine embeddings by elementwise summing
        embedded = self.dropout(tok_embedded + pos_embedded)
        #embedded = [batch size, src len, emb dim]

        #pass embedded through linear layer to convert from emb dim to hid dim
        conv_input = self.emb2hid(embedded)
        #conv_input = [batch size, src len, hid dim]

        #permute for convolutional layer
        conv_input = conv_input.permute(0, 2, 1)
        #conv_input = [batch size, hid dim, src len]

        #begin convolutional blocks...
        for i, conv in enumerate(self.convs):
            #pass through convolutional layer
            conved = conv(self.dropout(conv_input))
            #conved = [batch size, 2 * hid dim, src len]

            #pass through GLU activation function
            conved = F.glu(conved, dim = 1)
            #conved = [batch size, hid dim, src len]

            #apply residual connection
            conved = (conved + conv_input) * self.scale
            #conved = [batch size, hid dim, src len]

            #set conv_input to conved for next loop iteration
            conv_input = conved
        #...end convolutional blocks

        #permute and convert back to emb dim
        conved = self.hid2emb(conved.permute(0, 2, 1))
        #conved = [batch size, src len, emb dim]

        #elementwise sum output (conved) and input (embedded) to be used for attention
        combined = (conved + embedded) * self.scale
        #combined = [batch size, src len, emb dim]

        return conved, combined
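
A minimal usage sketch follows; every hyperparameter value in it is an illustrative assumption, not something taken from the gist itself.

# Hypothetical smoke test: the values below are assumed for illustration only.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

INPUT_DIM = 1000      # source vocabulary size (assumed)
EMB_DIM = 256         # embedding dimension (assumed)
HID_DIM = 512         # conv channel dimension (assumed)
N_LAYERS = 10         # number of conv blocks (assumed)
KERNEL_SIZE = 3       # must be odd so padding preserves the sequence length
DROPOUT = 0.25        # dropout probability (assumed)

encoder = Encoder(INPUT_DIM, EMB_DIM, HID_DIM, N_LAYERS,
                  KERNEL_SIZE, DROPOUT, device).to(device)

# batch of 4 random token sequences, each 20 tokens long (within max_length = 100)
src = torch.randint(0, INPUT_DIM, (4, 20)).to(device)

conved, combined = encoder(src)
print(conved.shape)    # torch.Size([4, 20, 256])  -> [batch size, src len, emb dim]
print(combined.shape)  # torch.Size([4, 20, 256])  -> [batch size, src len, emb dim]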