@acetylSv
Created March 18, 2019 11:25
Location aware attention (pytorch implementation of https://arxiv.org/pdf/1506.07503.pdf)
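For reference, the score this layer computes, written with symbols matched to the code below (s = decoder_state, h_j = the j-th row of enc_feat, \alpha_{t-1} = last_align, f = conv_feat, c = context):

    f_t = \mathrm{conv}(\alpha_{t-1})
    e_{t,j} = w^\top \tanh(W s_t + V h_j + U f_{t,j} + b)
    \alpha_{t,j} = \frac{\exp(e_{t,j})}{\sum_{j'} \exp(e_{t,j'})} \quad \text{(with smoothing: replace } \exp \text{ by } \sigma\text{)}
    c_t = \sum_j \alpha_{t,j} h_j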
import torch
import torch.nn as nn
class LocAwareAttnLayer(nn.Module):
    '''
    Implementation of: https://arxiv.org/pdf/1506.07503.pdf
    '''
    def __init__(self, dec_hidden_dim, enc_feat_dim, conv_dim, attn_dim, smoothing=False):
        super(LocAwareAttnLayer, self).__init__()
        self.attn_dim = attn_dim
        self.dec_hidden_dim = dec_hidden_dim
        self.conv_dim = conv_dim
        # location feature: 1-D convolution over the previous alignment
        self.conv = nn.Conv1d(in_channels=1, out_channels=self.conv_dim, kernel_size=3, padding=1)
        self.W = nn.Linear(dec_hidden_dim, attn_dim, bias=False)
        self.V = nn.Linear(enc_feat_dim, attn_dim, bias=False)
        self.U = nn.Linear(conv_dim, attn_dim, bias=False)
        self.b = nn.Parameter(torch.rand(attn_dim))
        self.w = nn.Linear(attn_dim, 1, bias=False)
        self.smoothing = smoothing
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, decoder_state, enc_feat, last_align):
        # decoder_state: [batch, 1, dec_hidden_dim] (must broadcast against enc_feat over the time axis)
        # enc_feat:      [batch, enc_feat_len, enc_feat_dim]
        # last_align:    [batch, enc_feat_len], attention weights from the previous decoder step

        # conv_feat: [batch, enc_feat_len, conv_dim]
        conv_feat = torch.transpose(self.conv(last_align.unsqueeze(dim=1)), 1, 2)

        # energy: [batch, enc_feat_len]
        energy = self.w(self.tanh(
            self.W(decoder_state)
            + self.V(enc_feat)
            + self.U(conv_feat)
            + self.b
        )).squeeze(dim=-1)

        if self.smoothing:
            # smoothed normalization: sigmoid instead of exp before normalizing
            energy = torch.sigmoid(energy)
            attn_weight = torch.div(energy, energy.sum(dim=-1).unsqueeze(dim=-1))
        else:
            attn_weight = self.softmax(energy)

        # context: [batch, enc_feat_dim]
        context = torch.bmm(attn_weight.unsqueeze(dim=1), enc_feat).squeeze(dim=1)
        return attn_weight, context
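A minimal usage sketch (the dimensions below are arbitrary examples, not part of the original gist):

# smoke test with made-up dimensions
batch, enc_feat_len = 4, 50
dec_hidden_dim, enc_feat_dim, conv_dim, attn_dim = 256, 128, 10, 128

attn = LocAwareAttnLayer(dec_hidden_dim, enc_feat_dim, conv_dim, attn_dim)

decoder_state = torch.rand(batch, 1, dec_hidden_dim)  # extra middle dim so W(decoder_state) broadcasts over enc_feat_len
enc_feat = torch.rand(batch, enc_feat_len, enc_feat_dim)
last_align = torch.full((batch, enc_feat_len), 1.0 / enc_feat_len)  # e.g. uniform weights at the first step

attn_weight, context = attn(decoder_state, enc_feat, last_align)
print(attn_weight.shape)  # torch.Size([4, 50])
print(context.shape)      # torch.Size([4, 128])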
@sooftware

What is 'last_align'?
I think 'last_align' means the previous 'attn_weight'. Is that right?

@30stomercury

What is 'last_align'?
I think 'last_align' means the previous 'attn_weight'. Is that right?

Yes, it is.
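To illustrate with a rough sketch (continuing from the example after the class above; the loop body and max_dec_steps are placeholders, not from this gist):

max_dec_steps = 10  # placeholder
last_align = torch.full((batch, enc_feat_len), 1.0 / enc_feat_len)  # e.g. uniform weights at step 0
for t in range(max_dec_steps):
    attn_weight, context = attn(decoder_state, enc_feat, last_align)
    # ... feed `context` to the decoder cell here and update decoder_state ...
    last_align = attn_weight  # this step's weights become last_align at the next step

Since attn_weight (and therefore last_align) always has shape [batch, enc_feat_len], one weight per encoder frame, its length stays the same at every decoder step.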

@vishhvak commented Apr 28, 2022

The previous attn_weight is the one calculated at the previous time step of the decoder, right? Would it make sense if I just stored the attention weight I calculate as a class attribute and kept updating it at each time step after the first one?

However, I don't understand how the previous attention is added - won't it have a different sequence length?
