Location-aware attention (PyTorch implementation of https://arxiv.org/pdf/1506.07503.pdf)
import torch
import torch.nn as nn

class LocAwareAttnLayer(nn.Module):
    '''
    Location-aware attention, implementation of:
    https://arxiv.org/pdf/1506.07503.pdf
    '''
    def __init__(self, dec_hidden_dim, enc_feat_dim, conv_dim, attn_dim, smoothing=False):
        super(LocAwareAttnLayer, self).__init__()
        self.attn_dim = attn_dim
        self.dec_hidden_dim = dec_hidden_dim
        self.conv_dim = conv_dim
        # F in the paper: convolve the previous alignment into location features
        self.conv = nn.Conv1d(in_channels=1, out_channels=self.conv_dim, kernel_size=3, padding=1)
        self.W = nn.Linear(dec_hidden_dim, attn_dim, bias=False)
        self.V = nn.Linear(enc_feat_dim, attn_dim, bias=False)
        self.U = nn.Linear(conv_dim, attn_dim, bias=False)
        self.b = nn.Parameter(torch.rand(attn_dim))
        self.w = nn.Linear(attn_dim, 1, bias=False)
        self.smoothing = smoothing
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, decoder_state, enc_feat, last_align):
        # decoder_state: [batch, 1, dec_hidden_dim] (current decoder hidden state)
        # enc_feat:      [batch, enc_feat_len, enc_feat_dim] (encoder outputs)
        # last_align:    [batch, enc_feat_len] (attention weights from the previous decoder step)
        # conv_feat: [batch, enc_feat_len, conv_dim]
        conv_feat = torch.transpose(self.conv(last_align.unsqueeze(dim=1)), 1, 2)
        # energy: [batch, enc_feat_len]
        energy = self.w(self.tanh(
            self.W(decoder_state)
            + self.V(enc_feat)
            + self.U(conv_feat)
            + self.b
        )).squeeze(dim=-1)
        if self.smoothing:
            # smoothing variant from the paper: replace exp with sigmoid, then normalize
            energy = torch.sigmoid(energy)
            attn_weight = torch.div(energy, energy.sum(dim=-1).unsqueeze(dim=-1))
        else:
            attn_weight = self.softmax(energy)
        # context: [batch, enc_feat_dim]
        context = torch.bmm(attn_weight.unsqueeze(dim=1), enc_feat).squeeze(dim=1)
        return attn_weight, context
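For reference, a minimal usage sketch. The dimensions and batch size below are arbitrary placeholders, not values from the gist; note the initial alignment is usually seeded as a uniform distribution over the encoder timesteps.

# Hypothetical usage sketch with placeholder dimensions
batch, enc_feat_len = 4, 50
attn = LocAwareAttnLayer(dec_hidden_dim=256, enc_feat_dim=512, conv_dim=10, attn_dim=128)
decoder_state = torch.randn(batch, 1, 256)           # current decoder hidden state
enc_feat = torch.randn(batch, enc_feat_len, 512)     # encoder outputs
last_align = torch.full((batch, enc_feat_len), 1.0 / enc_feat_len)  # uniform initial alignment
attn_weight, context = attn(decoder_state, enc_feat, last_align)
# attn_weight: [4, 50], context: [4, 512]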
The previous attn_weight is the one calculated at the previous time step of the decoder, right? Would it make sense if I just stored the attention weight I calculate as a class attribute and kept updating it after each time step? However, I don't understand how the previous attention is added: won't it have a different sequence length?
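As a possible answer: the alignment is a distribution over encoder timesteps, so its length is always enc_feat_len at every decoder step; only the decoder state changes. The usual pattern is to carry last_align through the decoding loop rather than store it on the module. A minimal sketch, assuming a hypothetical step count and leaving the decoder update elided:

# Hypothetical decoding loop carrying the previous alignment
batch, enc_feat_len, max_dec_steps = 4, 50, 20
enc_feat = torch.randn(batch, enc_feat_len, 512)
attn = LocAwareAttnLayer(dec_hidden_dim=256, enc_feat_dim=512, conv_dim=10, attn_dim=128)
decoder_state = torch.zeros(batch, 1, 256)
# the alignment lives over encoder timesteps, so its shape never changes
last_align = torch.full((batch, enc_feat_len), 1.0 / enc_feat_len)
for t in range(max_dec_steps):
    attn_weight, context = attn(decoder_state, enc_feat, last_align)
    last_align = attn_weight  # feed this step's weights into the next step
    # ... update decoder_state from context and the previous output here ...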