@AranKomat · Created March 17, 2020
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def shift_(x):
    # x: [*, t_q, t_k]. Pad-and-reshape trick for relative position logits:
    # after the crop in shift() below, row q ends up shifted left by
    # (t_q - 1 - q) columns, with zeros where the shift runs off the end.
    zero_pad = torch.zeros(*x.size()[:-1], x.size(-2), device=x.device, dtype=x.dtype)
    x = torch.cat([x, zero_pad], -1)
    l = x.size(-1)
    x = x.view(*x.size()[:-2], -1)
    zero_pad = torch.zeros(*x.size()[:-1], -x.size(-1) % (l - 1), device=x.device, dtype=x.dtype)
    return torch.cat([x, zero_pad], -1).view(*x.size()[:-1], -1, l - 1)


def shift(x):
    # Crop back to [*, t_q, t_k]; overall shift(x)[q, j] = x[q, j + t_q - 1 - q].
    t_q = x.size()[-2]
    return shift_(x)[..., :t_q, t_q - 1:]
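
# Hedged sanity check (not part of the original gist): with t_q = 2 queries and
# t_k = 4 relative positions, shift() moves row q left by (t_q - 1 - q) columns,
# so the same column index refers to the same key slot in every row:
#   shift([[0, 1, 2, 3],        [[1, 2, 3, 0],
#          [4, 5, 6, 7]])   ->   [4, 5, 6, 7]]
def _shift_demo():
    x = torch.arange(8, dtype=torch.float32).view(1, 2, 4)  # [*, t_q=2, t_k=4]
    return shift(x)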
class LAttention(nn.Module):
    """Local (windowed) causal attention with learned relative position embeddings."""

    def __init__(self, config):
        super(LAttention, self).__init__()
        self.config = config
        std = math.sqrt(1 / config.hidden_dim)
        self.window_size = config.window_size  # say 64
        # Per-head relative position embeddings, one row per displacement:
        # R[r] scores attending to a key 2*window_size - 1 - r steps back.
        self.R = nn.Parameter(torch.zeros(2 * self.window_size, config.num_heads,
                                          config.hidden_dim // config.num_heads,
                                          device=self.config.device).normal_(0, std))
    def forward(self, q, k, v, decoding=False, **kwargs):
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        tgt_length = self.window_size
        # Split the sequence into non-overlapping chunks: [b, h, chunks, window, dim].
        q = q.view(b_q, h_q, t_q // tgt_length, tgt_length, dim_q)
        k = k.view(b_k, h_k, t_k // tgt_length, tgt_length, dim_k)
        v = v.view(b_k, h_k, t_k // tgt_length, tgt_length, dim_k)
        if self.config.share_qk:
            k = F.normalize(k, dim=-1)

        def f(x):
            # Prepend each chunk's predecessor along the window axis so every
            # query can also attend to the previous window (zeros for chunk 0).
            x_extra = F.pad(x[:, :, :-1, ...], pad=(0, 0, 0, 0, 1, 0))
            return torch.cat([x_extra, x], dim=3)

        k = f(k)
        v = f(v)
        # Content term plus (shifted) relative-position term of the logits.
        k_part = torch.einsum('bhcqd,bhckd->bhcqk', q, k)
        tmp = torch.einsum('bhcqd,khd->bhcqk', q, self.R)
        wr_part = shift(tmp)
        qk = k_part + wr_part
        qk *= dim_q ** -0.5
        # Causal mask over the 2*window keys: query q may not attend past its
        # own position. float_half is a helper assumed to be defined elsewhere
        # in this codebase (casts to fp16 when the config asks for it).
        pre_mask = torch.ones(tgt_length, tgt_length * 2, device=self.config.device).byte().triu_(tgt_length + 1)
        mask = float_half(self.config, pre_mask.float() * (-1e9))
        if self.config.share_qk:
            # With shared queries/keys, down-weight each position's attention to itself.
            pre_mask = torch.ones(tgt_length, tgt_length * 2, device=self.config.device).byte().triu_(tgt_length).tril_(tgt_length)
            mask += float_half(self.config, pre_mask.float() * (-1e3))
        qk += mask
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhcqk,bhckd->bhcqd', sm_qk, v)
        o = o.view(b_q, h_q, t_q, dim_q)
        return o, kwargs
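
# Hedged usage sketch (not part of the original gist). LAttention depends on an
# external `config` object and a `float_half` helper that the gist does not
# define; the stubs below are assumptions about their shape, just enough to run
# the forward pass on CPU in fp32.
from types import SimpleNamespace

def float_half(config, x):
    # Assumed behavior: cast to fp16 only when the config requests it.
    return x.half() if getattr(config, 'fp16', False) else x

if __name__ == '__main__':
    config = SimpleNamespace(hidden_dim=64, num_heads=4, window_size=8,
                             device='cpu', share_qk=False, dropout_prob=0.1,
                             fp16=False)
    attn = LAttention(config)
    b, t = 2, 32  # t must be a multiple of window_size
    d = config.hidden_dim // config.num_heads
    q = torch.randn(b, config.num_heads, t, d)
    k = torch.randn(b, config.num_heads, t, d)
    v = torch.randn(b, config.num_heads, t, d)
    o, _ = attn(q, k, v)
    print(o.shape)  # torch.Size([2, 4, 32, 16])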