Skip to content

Instantly share code, notes, and snippets.

@jayleicn
Created February 20, 2018 16:40
Show Gist options
  • Save jayleicn/4f935df94d2368c7cb459f9294bd3175 to your computer and use it in GitHub Desktop.
bidaf attention layer
import torch
import torch.nn as nn
import torch.nn.functional as F
class BidafAttnModule(nn.Module):
    """Bidirectional attention flow (BiDAF) layer.

    Fuses context tokens with a context-aware query summary and a
    query-aware context summary (Seo et al., "Bidirectional Attention
    Flow for Machine Comprehension"), then runs a GRU per sentence and
    max-pools over its output to produce one vector per sentence.
    """

    def __init__(self, hidden_size):
        super(BidafAttnModule, self).__init__()
        # Projects question vectors before the similarity product: S = h * W(u).
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.ReLU()
        )
        # Currently unused in forward(); kept so existing checkpoints still load.
        self.fc_final = nn.Sequential(
            nn.Linear(4 * hidden_size, hidden_size),
            nn.ReLU()
        )
        self.gru = nn.GRU(4 * hidden_size, hidden_size)

    def forward(self, contexts, questions):
        """
        contexts -> #batch, #sen, #token, #hidden -> #batch, #sen * #token, #hidden
        questions -> #batch, #seq, #hidden
        attn_c -> #batch, 1, #hidden
        attn_q -> #batch, #sen*#token, #hidden
        G -> #batch, #sen * #token, 4*#hidden
        returns -> #batch, #sen, #hidden
        """
        batch_num, sen_num, token_num, hidden_size = contexts.size()
        contexts = contexts.view(batch_num, sen_num * token_num, hidden_size)
        S = self.get_similarity_matrix(contexts, questions)
        attn_q = self.get_context_aware_query(S, questions)
        attn_c = self.get_query_aware_context(S, contexts)
        # Broadcast the single query-aware context vector to every token position.
        attn_c = attn_c.expand_as(attn_q)
        G = torch.cat([contexts, attn_q, contexts * attn_q, contexts * attn_c], -1)
        # NOTE(review): nn.GRU defaults to batch_first=False, i.e. it expects
        # (seq, batch, feature), but this view passes (batch*sen, token, feat).
        # Kept as-is to preserve the original behavior -- confirm intent.
        G, _ = self.gru(G.view(batch_num * sen_num, token_num, -1))
        G = torch.max(G, 1)[0]  # max-pool over dim 1
        G = G.view(batch_num, sen_num, hidden_size)
        return G

    def get_similarity_matrix(self, contexts, questions):
        """
        S_tj = h * W(u): dot product of each context token with each
        projected question vector.
        contexts -> #batch, #sen*#token, #hidden
        questions -> #batch, #seq, #hidden
        S -> #batch, #sen*#token, #seq
        """
        batch_num, _, hidden_size = contexts.size()
        # Flatten to 2-D for the Linear layer, then restore the batch dim.
        questions = self.fc(
            questions.contiguous().view(-1, hidden_size)
        ).view(batch_num, -1, hidden_size)
        S = torch.bmm(contexts, questions.transpose(1, 2))
        return S

    def get_context_aware_query(self, S, questions):
        """
        For each context token, a softmax-weighted mixture of question vectors.
        S -> #batch, #sen*#token, #seq
        score -> #batch, #sen*#token, #seq (softmax over #seq)
        questions -> #batch, #seq, #hidden
        returns -> #batch, #sen*#token, #hidden
        """
        score = F.softmax(S, dim=2)
        context_aware_questions = torch.bmm(score, questions)
        return context_aware_questions

    def get_query_aware_context(self, S, contexts):
        """
        One context summary per batch item, weighted by each token's best
        match against any question position.
        S -> #batch, #sen*#token, #seq
        S_max -> #batch, #sen*#token
        score -> #batch, #sen*#token -> #batch, 1, #sen*#token
        contexts -> #batch, #sen*#token, #hidden
        returns -> #batch, 1, #hidden
        """
        S_max = torch.max(S, 2)[0]
        score = F.softmax(S_max, dim=1)
        query_aware_contexts = torch.bmm(score.unsqueeze(1), contexts)
        return query_aware_contexts
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment