@attitudechunfeng
Created December 19, 2019 08:35
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.special import binom, beta


def gen_probs(n, a, b):
    """Beta-binomial pmf over k = 0..n, used as the static alignment prior filter."""
    probs = np.zeros(n + 1)
    for k in range(n + 1):
        probs[k] = binom(n, k) * beta(k + a, n - k + b) / beta(a, b)
    return probs
class DCA(nn.Module):
    """Dynamic Convolution Attention: static + dynamic location filters plus a prior bias."""

    def __init__(self, attn_dim):
        super().__init__()
        # Static location filters: 8 channels, kernel size 21, applied to the previous alignment.
        self.conv_static = nn.Conv1d(1, 8, kernel_size=21, padding=(21 - 1) // 2, bias=True)
        self.U = nn.Linear(8, attn_dim, bias=True)
        # Dynamic filters predicted from the query: 8 filters of length 21 per batch element.
        self.Wg = nn.Linear(attn_dim, attn_dim, bias=True)
        self.Vg = nn.Linear(attn_dim, 8 * 21, bias=False)
        self.T = nn.Linear(8, attn_dim, bias=True)
        # Beta-binomial prior filter, built lazily on the first forward pass.
        self.P = None
        self.v = nn.Linear(attn_dim, 1, bias=False)
        self.attention = None

    def init_attention(self, encoder_seq_proj):
        b, t, c = encoder_seq_proj.size()
        # Start with (almost) all attention mass on the first encoder step.
        self.attention = torch.zeros(b, t).cuda() + 1e-6
        self.attention[:, 0] = 1 - (t - 1) * 1e-6
        # Gather indices that select, for each batch element, its own 8 dynamic-filter
        # channels out of the (b, b*8, t) convolution output produced in forward().
        self.index_tensor = torch.ones((b, 8, t)).long().cuda()
        for i in range(b):
            for k in range(8):
                self.index_tensor[i, k, :] = i * 8 + k

    def forward(self, encoder_seq_proj, encoder_seq, query, t):
        if self.P is None:
            self.P = torch.from_numpy(gen_probs(10, 0.1, 0.9).reshape(1, -1)).float().cuda()
        if t == 0:
            self.init_attention(encoder_seq_proj)

        location = self.attention.unsqueeze(1)  # previous alignment, (b, 1, t)
        batch_size = location.size(0)
        win = location.size(-1)

        # Static location features.
        static_part = self.U(self.conv_static(location).transpose(1, 2))

        # Dynamic filters predicted from the query, applied to the previous alignment.
        # conv1d with a (b*8, 1, 21) weight convolves every batch element with every
        # predicted filter; the gather keeps each element's own 8 channels.
        dym_kernel = self.Vg(torch.tanh(self.Wg(query))).view(-1, 1, 21).contiguous()
        dynamic_part = F.conv1d(location, dym_kernel, bias=None, stride=1, padding=10)
        dynamic_part = self.T(torch.gather(dynamic_part, 1, self.index_tensor).transpose(1, 2))
        # Equivalent per-element version:
        #   dym_kernel = self.Vg(torch.tanh(self.Wg(query))).view(-1, 8, 1, 21).contiguous()
        #   dynamic_part = torch.zeros((batch_size, 8, win)).cuda()
        #   for i in range(batch_size):
        #       dynamic_part[i:i+1] = F.conv1d(location[i:i+1], dym_kernel[i], padding=10)
        #   dynamic_part = self.T(dynamic_part.transpose(1, 2))

        u = self.v(torch.tanh(static_part + dynamic_part))

        # Beta-binomial prior bias in the log domain, clamped so positions where the
        # convolved prior is exactly zero do not produce -inf.
        p_bias = torch.log(F.conv1d(location, self.P.view(1, 1, -1).contiguous(),
                                    bias=None, stride=1, padding=(11 - 1) // 2).transpose(1, 2))
        p_bias = torch.clamp(p_bias, min=-1e6)
        u = (u + p_bias).squeeze(-1)

        # Normalise the scores and keep them as the alignment for the next step.
        scores = F.softmax(u, dim=-1)
        self.attention = scores

        # Context vector as the score-weighted sum of encoder states.
        context = torch.bmm(scores.unsqueeze(1), encoder_seq).squeeze(1)
        return context, scores.unsqueeze(1)
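

# --- Usage sketch (not part of the original gist) ---
# A minimal, hypothetical driver showing how the module above might be called.
# The dimension choices (attn_dim, enc_dim, B, T) are illustrative assumptions,
# not values from the original code; a CUDA device is required because the
# module allocates its buffers with .cuda().
if __name__ == "__main__":
    attn_dim, enc_dim, B, T = 128, 256, 2, 50
    print("prior filter:", gen_probs(10, 0.1, 0.9), "sum:", gen_probs(10, 0.1, 0.9).sum())
    attn = DCA(attn_dim).cuda()
    encoder_seq = torch.randn(B, T, enc_dim).cuda()         # raw encoder states
    encoder_seq_proj = torch.randn(B, T, attn_dim).cuda()   # states projected to attn_dim
    for step in range(3):                                    # a few decoder steps
        query = torch.randn(B, attn_dim).cuda()              # decoder state at this step
        context, scores = attn(encoder_seq_proj, encoder_seq, query, step)
        print(context.shape, scores.shape)                   # (B, enc_dim), (B, 1, T)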