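# Decoder of a convolutional sequence-to-sequence model in the style of
# Gehring et al. (2017), "Convolutional Sequence to Sequence Learning":
# token + position embeddings feed a stack of causal 1d convolutions with
# GLU activations, and each layer attends over the encoder outputs.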
import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    def __init__(self,
                 output_dim,
                 emb_dim,
                 hid_dim,
                 n_layers,
                 kernel_size,
                 dropout,
                 trg_pad_idx,
                 device,
                 max_length = 100):
        super().__init__()

        self.kernel_size = kernel_size
        self.trg_pad_idx = trg_pad_idx
        self.device = device

        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)

        self.tok_embedding = nn.Embedding(output_dim, emb_dim)
        self.pos_embedding = nn.Embedding(max_length, emb_dim)

        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.hid2emb = nn.Linear(hid_dim, emb_dim)

        self.attn_hid2emb = nn.Linear(hid_dim, emb_dim)
        self.attn_emb2hid = nn.Linear(emb_dim, hid_dim)

        self.fc_out = nn.Linear(emb_dim, output_dim)

        self.convs = nn.ModuleList([nn.Conv1d(in_channels = hid_dim,
                                              out_channels = 2 * hid_dim, #GLU will halve this back to hid_dim
                                              kernel_size = kernel_size)
                                    for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
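
    # Illustrative aside (not from the original gist): F.glu splits its input
    # in half along the given dimension and gates one half with the sigmoid
    # of the other, which is why each conv outputs 2 * hid_dim channels only
    # to return to hid_dim after the activation. For example, with hid_dim = 64:
    #
    #     x = torch.randn(1, 2 * 64, 10)  # [batch size, 2 * hid dim, trg len]
    #     y = F.glu(x, dim = 1)
    #     y.shape                         # torch.Size([1, 64, 10])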
    def calculate_attention(self, embedded, conved, encoder_conved, encoder_combined):
        #embedded = [batch size, trg len, emb dim]
        #conved = [batch size, hid dim, trg len]
        #encoder_conved = encoder_combined = [batch size, src len, emb dim]

        #permute and convert back to emb dim
        conved_emb = self.attn_hid2emb(conved.permute(0, 2, 1))
        #conved_emb = [batch size, trg len, emb dim]

        combined = (conved_emb + embedded) * self.scale
        #combined = [batch size, trg len, emb dim]

        energy = torch.matmul(combined, encoder_conved.permute(0, 2, 1))
        #energy = [batch size, trg len, src len]

        attention = F.softmax(energy, dim = 2)
        #attention = [batch size, trg len, src len]

        attended_encoding = torch.matmul(attention, encoder_combined)
        #attended_encoding = [batch size, trg len, emb dim]

        #convert from emb dim -> hid dim
        attended_encoding = self.attn_emb2hid(attended_encoding)
        #attended_encoding = [batch size, trg len, hid dim]

        #apply residual connection
        attended_combined = (conved + attended_encoding.permute(0, 2, 1)) * self.scale
        #attended_combined = [batch size, hid dim, trg len]

        return attention, attended_combined
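
    # Illustrative shape walk-through for calculate_attention (the numbers
    # are assumptions, not from the gist): with batch size 2, trg len 5,
    # src len 7, hid dim 64, emb dim 32:
    #   conved            [2, 64, 5] --attn_hid2emb--> conved_emb [2, 5, 32]
    #   energy            = combined @ encoder_conved.T  -> [2, 5, 7]
    #   attention         = softmax(energy, dim = 2)     -> [2, 5, 7]
    #   attended_encoding = attention @ encoder_combined -> [2, 5, 32]
    #                       --attn_emb2hid-->               [2, 5, 64]
    #   attended_combined = residual with conved         -> [2, 64, 5]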
    def forward(self, trg, encoder_conved, encoder_combined):
        #trg = [batch size, trg len]
        #encoder_conved = encoder_combined = [batch size, src len, emb dim]

        batch_size = trg.shape[0]
        trg_len = trg.shape[1]

        #create position tensor
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        #pos = [batch size, trg len]

        #embed tokens and positions
        tok_embedded = self.tok_embedding(trg)
        pos_embedded = self.pos_embedding(pos)
        #tok_embedded = [batch size, trg len, emb dim]
        #pos_embedded = [batch size, trg len, emb dim]

        #combine embeddings by elementwise summing
        embedded = self.dropout(tok_embedded + pos_embedded)
        #embedded = [batch size, trg len, emb dim]

        #pass embedded through linear layer to go from emb dim -> hid dim
        conv_input = self.emb2hid(embedded)
        #conv_input = [batch size, trg len, hid dim]

        #permute for convolutional layer
        conv_input = conv_input.permute(0, 2, 1)
        #conv_input = [batch size, hid dim, trg len]

        batch_size = conv_input.shape[0]
        hid_dim = conv_input.shape[1]

        for i, conv in enumerate(self.convs):
            #apply dropout
            conv_input = self.dropout(conv_input)

            #pad on the left so the decoder can't "cheat" by seeing future tokens
            padding = torch.zeros(batch_size,
                                  hid_dim,
                                  self.kernel_size - 1).fill_(self.trg_pad_idx).to(self.device)
            padded_conv_input = torch.cat((padding, conv_input), dim = 2)
            #padded_conv_input = [batch size, hid dim, trg len + kernel size - 1]

            #pass through convolutional layer
            conved = conv(padded_conv_input)
            #conved = [batch size, 2 * hid dim, trg len]

            #pass through GLU activation function
            conved = F.glu(conved, dim = 1)
            #conved = [batch size, hid dim, trg len]

            #calculate attention
            attention, conved = self.calculate_attention(embedded,
                                                         conved,
                                                         encoder_conved,
                                                         encoder_combined)
            #attention = [batch size, trg len, src len]

            #apply residual connection
            conved = (conved + conv_input) * self.scale
            #conved = [batch size, hid dim, trg len]

            #set conv_input to conved for next loop iteration
            conv_input = conved

        conved = self.hid2emb(conved.permute(0, 2, 1))
        #conved = [batch size, trg len, emb dim]

        output = self.fc_out(self.dropout(conved))
        #output = [batch size, trg len, output dim]

        return output, attention
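

# Minimal usage sketch (an assumption, not part of the original gist): the
# hyperparameters and the random "encoder" tensors below are illustrative
# stand-ins for what a matching convolutional Encoder would produce.
if __name__ == "__main__":
    device = torch.device("cpu")
    batch_size, src_len, trg_len = 2, 7, 5
    output_dim, emb_dim, hid_dim = 100, 32, 64

    decoder = Decoder(output_dim, emb_dim, hid_dim,
                      n_layers = 3, kernel_size = 3, dropout = 0.25,
                      trg_pad_idx = 1, device = device)

    trg = torch.randint(0, output_dim, (batch_size, trg_len))
    #both encoder outputs are [batch size, src len, emb dim]
    encoder_conved = torch.randn(batch_size, src_len, emb_dim)
    encoder_combined = torch.randn(batch_size, src_len, emb_dim)

    output, attention = decoder(trg, encoder_conved, encoder_combined)
    print(output.shape)     #torch.Size([2, 5, 100])
    print(attention.shape)  #torch.Size([2, 5, 7])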