import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self,
                 input_dim,
                 emb_dim,
                 hid_dim,
                 n_layers,
                 kernel_size,
                 dropout,
                 device,
                 max_length=100):
        super().__init__()
        assert kernel_size % 2 == 1, "Kernel size must be odd!"
        self.device = device
        ### Scale factor sqrt(0.5) keeps the variance of the residual sums roughly constant
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
        ### Token embedding layer
        self.tok_embedding = nn.Embedding(input_dim, emb_dim)
        ### Positional embedding layer
        self.pos_embedding = nn.Embedding(max_length, emb_dim)
        ### Applied before the conv blocks
        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        ### Applied after the conv blocks
        self.hid2emb = nn.Linear(hid_dim, emb_dim)
        ### Conv blocks
        self.convs = nn.ModuleList([nn.Conv1d(in_channels=hid_dim,
                                              out_channels=2 * hid_dim,  ### GLU will halve the channel dimension
                                              kernel_size=kernel_size,
                                              padding=(kernel_size - 1) // 2)  ### Output length must remain the same
                                    for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
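    # Note: with the default stride of 1 and padding = (kernel_size - 1) // 2,
    # each convolution preserves the sequence length:
    # out_len = src_len + 2 * padding - (kernel_size - 1) = src_len.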
    def forward(self, src):
        #src = [batch size, src len]
        batch_size = src.shape[0]
        src_len = src.shape[1]

        #create position tensor
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        #pos = [[0, 1, 2, 3, ..., src len - 1], ...]
        #pos = [batch size, src len]
        #embed tokens and positions
        tok_embedded = self.tok_embedding(src)
        pos_embedded = self.pos_embedding(pos)
        #tok_embedded = pos_embedded = [batch size, src len, emb dim]

        #combine embeddings by elementwise summing
        embedded = self.dropout(tok_embedded + pos_embedded)
        #embedded = [batch size, src len, emb dim]

        #pass embedded through linear layer to convert from emb dim to hid dim
        conv_input = self.emb2hid(embedded)
        #conv_input = [batch size, src len, hid dim]

        #permute for convolutional layer
        conv_input = conv_input.permute(0, 2, 1)
        #conv_input = [batch size, hid dim, src len]
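        # nn.Conv1d expects input as [batch size, channels, length], hence the permute above.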
        #begin convolutional blocks...
        for i, conv in enumerate(self.convs):
            #pass through convolutional layer
            conved = conv(self.dropout(conv_input))
            #conved = [batch size, 2 * hid dim, src len]
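            # F.glu splits the channel dimension in half and returns a * sigmoid(b),
            # which is why each conv outputs 2 * hid dim channels.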
            #pass through GLU activation function
            conved = F.glu(conved, dim=1)
            #conved = [batch size, hid dim, src len]

            #apply residual connection
            conved = (conved + conv_input) * self.scale
            #conved = [batch size, hid dim, src len]

            #set conv_input to conved for next loop iteration
            conv_input = conved
        #...end convolutional blocks

        #permute and convert back to emb dim
        conved = self.hid2emb(conved.permute(0, 2, 1))
        #conved = [batch size, src len, emb dim]

        #elementwise sum output (conved) and input (embedded) to be used for attention
        combined = (conved + embedded) * self.scale
        #combined = [batch size, src len, emb dim]

        return conved, combined
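# Minimal usage sketch, assuming arbitrary hyperparameters and a random batch of
# token indices. Both returned tensors have shape [batch size, src len, emb dim];
# in a ConvS2S decoder, `conved` is typically used to compute attention scores
# and `combined` for the attention-weighted sum.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = Encoder(input_dim=1000, emb_dim=256, hid_dim=512, n_layers=3,
                      kernel_size=3, dropout=0.25, device=device).to(device)
    src = torch.randint(0, 1000, (8, 20)).to(device)  # [batch size, src len]
    conved, combined = encoder(src)
    print(conved.shape, combined.shape)  # both torch.Size([8, 20, 256])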