Skip to content

Instantly share code, notes, and snippets.

@rpryzant
Last active Dec 5, 2018
Embed
What would you like to do?
"""
Usage (for our feedforward context):
make sure you initialize the layer with
score_fn='bahdanau'
and then when you use the module in your forward()
method, you can feed it a vector of zeros for your query:
query = torch.zeros(rnn_outputs[:, 0, :].shape)
src_summary, _, attn_probs = self.attn_mechanism(
query=query,
keys=rnn_outputs,
values=rnn_outputs,
mask=srcmask)
Last, the "srcmask" parameter is used to mask out
all the pad tokens in your input. If "lens" is
a list of sequence lengths in your batch, you can
make a srcmask with
mask = [
([1] * l) + ([0] * (max_len - l))
for l in lens
]
"""
class BilinearAttention(nn.Module):
    """Attention layer that scores learned projections of the keys and query.

    Keys and query each pass through their own ``nn.Linear(hidden, hidden)``
    projection (``self.key_in_projection`` / ``self.query_in_projection``)
    before scoring:

    * ``score_fn='dot'`` (default):
      ``score(H_j, q) = (W_k H_j + b_k)^T (W_q q + b_q)``
    * ``score_fn='bahdanau'``:
      ``score(H_j, q) = v^T tanh((W_k H_j + b_k) + (W_q q + b_q))``

    Any ``score_fn`` value other than ``'bahdanau'`` falls back to ``'dot'``.
    """

    def __init__(self, hidden, score_fn='dot'):
        super(BilinearAttention, self).__init__()
        self.query_in_projection = nn.Linear(hidden, hidden)
        self.key_in_projection = nn.Linear(hidden, hidden)
        # Explicit dim: nn.Softmax() without `dim` is deprecated and warns.
        # For the 2-D [batch, len] scores used here the implicit choice was
        # dim=1 (the last dim), so results are unchanged.
        self.softmax = nn.Softmax(dim=-1)
        # Mixes attended context with the raw query: [2 * hidden] -> [hidden].
        self.out_projection = nn.Linear(hidden * 2, hidden)
        self.tanh = nn.Tanh()
        self.score_fn = self.dot
        if score_fn == 'bahdanau':
            # v in v^T tanh(...): reduces each [hidden] vector to one score.
            self.v_att = nn.Linear(hidden, 1, bias=False)
            self.score_tanh = nn.Tanh()
            self.score_fn = self.bahdanau

    def forward(self, query, keys, mask=None, values=None):
        """Score ``keys`` against ``query`` and pool ``values`` by the scores.

        Args:
            query: [batch, hidden] query vectors.
            keys: [batch, len, hidden] vectors scored against the query.
            mask: optional [batch, len] tensor. Positions that are nonzero /
                True are EXCLUDED (score set to -inf before the softmax),
                e.g. padding tokens. Bool, integer and float tensors are
                accepted. If every position in a row is masked, that row's
                probabilities come out NaN.
            values: optional [batch, len, hidden] vectors to pool. When
                omitted, the *projected* keys
                (``self.key_in_projection(keys)``) are used, not the raw
                ``keys``.

        Returns:
            A tuple ``(weighted_context, context_query_mixed, attn_probs)``:
            weighted_context: [batch, hidden] attention-weighted sum of
                ``values``.
            context_query_mixed: [batch, hidden]
                ``tanh(out_projection([weighted_context; query]))``.
            attn_probs: [batch, len] attention distribution over positions.
        """
        att_keys = self.key_in_projection(keys)
        if values is None:
            values = att_keys
        # [batch, hidden]
        att_query = self.query_in_projection(query)
        # [batch, len]
        attn_scores = self.score_fn(att_keys, att_query)
        if mask is not None:
            # masked_fill requires a boolean mask in current PyTorch (uint8
            # and float masks are rejected); normalise so nonzero == masked.
            mask = mask != 0
            attn_scores = attn_scores.masked_fill(mask, -float('inf'))
        attn_probs = self.softmax(attn_scores)
        # [batch, 1, len] x [batch, len, hidden] -> [batch, hidden]
        weighted_context = torch.bmm(attn_probs.unsqueeze(1), values).squeeze(1)
        context_query_mixed = torch.cat((weighted_context, query), 1)
        context_query_mixed = self.tanh(self.out_projection(context_query_mixed))
        return weighted_context, context_query_mixed, attn_probs

    def dot(self, keys, query):
        """Dot-product scores.

        keys: [B, T, H]
        query: [B, H]
        returns: [B, T] with entry (b, t) = keys[b, t] . query[b]
        """
        return torch.bmm(keys, query.unsqueeze(2)).squeeze(2)

    def bahdanau(self, keys, query):
        """Additive (Bahdanau-style) scores: v^T tanh(keys + query).

        keys: [B, T, H]
        query: [B, H] (broadcast over the T axis)
        returns: [B, T]
        """
        return self.v_att(self.score_tanh(keys + query.unsqueeze(1))).squeeze(2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment