import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ImportError:
    from torch.nn import LayerNorm
from copy import copy
from time import time
# from torch_scatter import scatter


def float_half(config, x):
    if config.fp16:
        return x.half()
    else:
        return x.float()


def initialize(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.zero_()


def positional_embedding(config):
    ''' Sinusoid position encoding table '''
    n_position = 10000
    d_hid = config.hidden_dim

    def cal_angle(position, hid_idx):
        return position / np.power(n_position, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return float_half(config, torch.tensor(sinusoid_table, device=config.device))
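

# Hedged usage sketch (added for illustration, not part of the original gist): builds
# the sinusoid table above with a stand-in config. The SimpleNamespace stub only
# mimics the attributes positional_embedding/float_half actually read (hidden_dim,
# device, fp16); the real config object is assumed to carry many more fields.
def _example_positional_embedding():
    from types import SimpleNamespace
    cfg = SimpleNamespace(hidden_dim=64, device='cpu', fp16=False)
    table = positional_embedding(cfg)
    return table.shape  # expected: torch.Size([10000, 64])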


class Embedding(nn.Module):
    def __init__(self, config):
        super(Embedding, self).__init__()
        self.config = config
        self.hd = config.hidden_dim
        if config.adaptive:
            self.emb = AdaptiveInput(config, self.hd, config.vocab_size, config.cutoffs)
        else:
            self.emb = nn.Embedding(config.vocab_size, self.hd)

    def forward(self, x, decoding=False):
        if self.config.adaptive:
            if decoding:  # TODO: incomplete here
                return
            else:
                return self.emb(x.contiguous().view(1, -1)).squeeze().view(x.size(0), -1, self.hd)
        else:
            return self.emb(x)
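

# Hedged example (added for illustration, not part of the original gist): the
# non-adaptive branch of Embedding is a plain nn.Embedding lookup. The stub config
# only carries the fields read in __init__/forward (adaptive, hidden_dim, vocab_size).
def _example_embedding():
    from types import SimpleNamespace
    cfg = SimpleNamespace(adaptive=False, hidden_dim=32, vocab_size=100)
    emb = Embedding(cfg)
    return emb(torch.randint(0, 100, (2, 5))).shape  # torch.Size([2, 5, 32])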


class Softmax(nn.Module):
    def __init__(self, config):
        super(Softmax, self).__init__()
        self.config = config
        self.hd = config.hidden_dim
        if config.adaptive:
            # self.softmax = AdaptiveLogSoftmaxWithLoss(self.hd, config.vocab_size, config.cutoffs)
            self.softmax = nn.AdaptiveLogSoftmaxWithLoss(self.hd, config.vocab_size, config.cutoffs)
        else:
            self.linear = nn.Linear(self.hd, config.vocab_size, bias=False)

    def forward(self, x, tgt=None, decoding=False):
        if self.config.adaptive:
            # print(x.size(), tgt.size())
            # tgt = tgt[:, 70:].contiguous()
            tgt = tgt.contiguous().view(-1)
            if decoding:  # TODO: incomplete here
                return self.softmax.log_prob(x, tgt)
            else:
                # x = x[:, 70:].contiguous()
                return self.softmax(x.view(-1, self.hd), tgt)
        else:
            x = self.linear(x)
            x = F.log_softmax(x, -1)
            return x
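

# Hedged example (added for illustration, not part of the original gist): the
# non-adaptive branch of Softmax is just an untied projection followed by
# log_softmax; the adaptive branch additionally needs targets. The stub config is
# an assumption covering only the attributes that branch reads.
def _example_softmax():
    from types import SimpleNamespace
    cfg = SimpleNamespace(adaptive=False, hidden_dim=32, vocab_size=100)
    sm = Softmax(cfg)
    return sm(torch.randn(2, 5, 32)).shape  # torch.Size([2, 5, 100]) of log-probabilities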


class ED(nn.Module):
    def __init__(self, config):
        super(ED, self).__init__()
        self.config = config
        c1 = copy(config)
        c2 = copy(config)
        c1.encoder = True
        c2.attend = True
        # c2.cached = True
        self.emb = Embedding(config)
        self.e = Transformer(c1)
        self.d = Transformer(c2)
        self.softmax = Softmax(config)

    def encoding(self, encoder_inp):
        return self.e(self.emb(encoder_inp))[1]

    def forward(self, encoder_cache, inp, cache=None):
        x, cache, l = self.d(self.emb(inp), cache=cache, encoder_cache=encoder_cache)
        x = self.softmax(x)
        return x, cache, l


class Decoder(nn.Module):
    def __init__(self, config):
        super(Decoder, self).__init__()
        self.config = config
        self.td = Transformer(config)
        self.softmax = Softmax(config)
        self.emb = Embedding(config)

    def forward(self, inp, cache=None, decoding=False, cur_time=0, tgt=None, **kwargs):
        x = inp
        x = self.emb(x)
        x, c, kwargs = self.td(x, cache=cache, decoding=decoding, cur_time=cur_time, **kwargs)
        return self.softmax(x, tgt=tgt)[1], c, kwargs


class Transformer(nn.Module):
    def __init__(self, config, weights=None):
        super(Transformer, self).__init__()
        self.config = config
        std = math.sqrt(1 / config.hidden_dim)
        self.u = nn.Parameter(
            torch.zeros(1, config.num_heads, 1, config.hidden_dim // config.num_heads,
                        device=self.config.device).normal_(0, std))
        self.v = nn.Parameter(
            torch.zeros(1, config.num_heads, 1, config.hidden_dim // config.num_heads,
                        device=self.config.device).normal_(0, std))
        self.R = positional_embedding(config)
        shared_params = [self.R, self.u, self.v]
        self.blocks = nn.ModuleList([Block(config, shared_params, i)
                                     for i in range(config.depth)])

    def forward(self, x, cache=None, cur_time=0, decoding=False, **kwargs):
        total_loss = 0
        cache = [None for _ in range(self.config.depth)] if cache is None else cache
        new_cache = [None for _ in range(self.config.depth)]
        if not self.config.rel_enc:
            x += self.R[:x.size(1)]
        for i in range(self.config.depth):
            x = F.dropout(x, p=self.config.dropout_prob, training=self.training)
            x, new_cache[i], kwargs = self.blocks[i](x, cache[i], decoding=decoding, **kwargs)
        return x, new_cache, kwargs


def shift_(x):
    # x = [*, t_q, t_k]; Transformer-XL style relative shift via zero-padding and reshaping.
    zero_pad = torch.zeros(*x.size()[:-1], x.size(-2), device=x.device, dtype=x.dtype)
    x = torch.cat([x, zero_pad], -1)
    l = x.size(-1)
    x = x.view(*x.size()[:-2], -1)
    zero_pad = torch.zeros(*x.size()[:-1], -x.size(-1) % (l - 1), device=x.device, dtype=x.dtype)
    return torch.cat([x, zero_pad], -1).view(*x.size()[:-1], -1, l - 1)


def shift(x):
    t_q = x.size()[-2]
    return shift_(x)[..., :t_q, t_q - 1:]
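

# Hedged example (added for illustration, not part of the original gist): the
# relative shift above keeps the [*, t_q, t_k] shape while realigning each query row
# from absolute key positions to relative distances. The tiny tensor below is only a
# shape check, not taken from the gist.
def _example_shift():
    scores = torch.arange(2 * 3, dtype=torch.float).view(1, 1, 2, 3)  # [b, h, t_q, t_k]
    shifted = shift(scores)
    assert shifted.shape == scores.shape
    return shifted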


class MultiHeadAttention(nn.Module):
    def __init__(self, config, shared_params=None, idx=0):
        super(MultiHeadAttention, self).__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.num_heads_k = 1 if config.multi_query else config.num_heads
        self.lin_q = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        if config.share_qk:
            self.lin_v = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        elif config.multi_query:
            self.lin_kv = nn.Linear(config.hidden_dim, (config.hidden_dim * 2) // config.num_heads, bias=False)
        else:
            self.lin_kv = nn.Linear(config.hidden_dim, config.hidden_dim * 2, bias=False)
        self.linear_out = nn.Linear(config.hidden_dim, config.hidden_dim)
        if self.config.mode in [3, 4]:
            self.attention = Attention(config, shared_params=shared_params, idx=idx)
        else:
            self.attention = LSHAttention(config, shared_params=shared_params, idx=idx)
        assert (config.hidden_dim % config.num_heads == 0)

    def forward(self, inp, cache=None, decoding=False, **kwargs):
        ctxt_length = self.config.ctxt_length if self.training else self.config.ctxt_length_test
        inp_ = inp[:, ctxt_length:] if self.config.recurrent else inp
        if self.config.share_qk:
            q = self.lin_q(inp)
            k = q / (torch.norm(q, dim=-1, keepdim=True) + 1e-6)
            q = q[:, ctxt_length:]
            v = self.lin_v(inp)
        else:
            q = self.lin_q(inp_)
            kv = self.lin_kv(inp)
            if self.config.efficient_xl:  # TODO: check this
                kv = torch.cat([cache, kv], 1)
                cache = kv[:, -ctxt_length:].detach()
            # Per-head split size; in the multi-query case lin_kv already projects
            # down to a single head.
            split_size = (self.config.hidden_dim // self.config.num_heads
                          if self.config.multi_query else self.config.hidden_dim)
            k, v = torch.split(kv, split_size, dim=-1)
        if decoding:
            if cache is None or cache[0] is None:
                cache = [k, v]
            else:
                k = torch.cat([cache[0], k], 1)
                v = torch.cat([cache[1], v], 1)
                cache = [k, v]
        batch, length, dim = list(q.size())
        q = q.view(batch, q.size(1), self.num_heads, -1).transpose(1, 2)
        k = k.view(batch, k.size(1), self.num_heads_k, -1).transpose(1, 2)
        v = v.view(batch, v.size(1), self.num_heads_k, -1).transpose(1, 2)
        out, kwargs = self.attention(q, k, v, decoding=decoding, **kwargs)
        out = out.transpose(1, 2).contiguous().view(batch, length, dim)
        out = self.linear_out(out)
        return out, cache, kwargs


class Attention(nn.Module):
    def __init__(self, config, shared_params=None, idx=0):
        super(Attention, self).__init__()
        self.config = config
        self.R = shared_params[0]
        if self.config.rel_enc:
            self.w_r = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.u = shared_params[1]
        self.v = shared_params[2]
        self.idx = idx
        std = 1 / config.hidden_dim
        self.means1 = torch.zeros(config.batch_size, config.num_heads, config.num_clusters,
                                  config.hidden_dim // config.num_heads,
                                  device=self.config.device).normal_(0, std).half()
        self.means1 = self.means1 / self.means1.norm(dim=-1, keepdim=True)
        self.means2 = torch.zeros(config.num_heads, config.num_clusters, config.hidden_dim // config.num_heads,
                                  device=self.config.device).normal_(0, std).half()
        self.means2 = self.means2 / self.means2.norm(dim=-1, keepdim=True)

    def R_(self, t_k):
        return self.R[:t_k].flip(0)

    def forward(self, q, k, v, decoding=False, **kwargs):
        aux = 0
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        if self.idx == 3 and False:  # disabled diagnostic branch: compares LSH vs. k-means bucket sizes
            # random_rotations = k.new_zeros(h_k, dim_k, self.config.num_clusters // 2).normal_(0, 1)
            random_rotations = torch.randn((h_k, dim_k, self.config.num_clusters // 2), device=self.config.device,
                                           dtype=k.dtype)
            rotated_vecs = torch.einsum('bhkd,hdc->bhkc', k, random_rotations)
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            one_hot = F.one_hot(buckets, num_classes=self.config.num_clusters).half()
            print(one_hot[0, 0].long().sum(0))
            rotated_vecs = torch.einsum('bhkd,hcd->bhkc', k, self.means2)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            one_hot = F.one_hot(buckets, num_classes=self.config.num_clusters).half()
            print(one_hot[0, 0].long().sum(0))
            pre_means = torch.einsum('bhkd,bhkc->bhkcd', k, one_hot).transpose(0, 1).contiguous().view(
                h_k, -1, self.config.num_clusters, dim_k).float().sum(1)
            self.means2 = (pre_means / pre_means.float().norm(dim=-1, keepdim=True)).half().detach()
            pre_mask = 1 - torch.einsum('bhkc, bhlc -> bhkl', one_hot, one_hot)
            mask = float_half(self.config, pre_mask.float() * (-1e9))
        if self.config.rel_enc:
            k_part = torch.einsum('bhqd,bhkd->bhqk', q + self.u, k) if not self.config.multi_query else torch.einsum(
                'bhqd,bkd->bhqk', q + self.u, k.squeeze(1))
            tmp = torch.einsum('bhqd,khd->bhqk', q + self.v, self.w_r(self.R_(t_k)).view(-1, h_q, dim_q))
            wr_part = shift(tmp)
            qk = k_part + wr_part
        else:
            qk = torch.einsum('bhqd,bhkd->bhqk', q, k)
        qk *= dim_q ** -0.5
        pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q + 1)
        mask = float_half(self.config, pre_mask.float() * (-1e9))
        if self.config.share_qk:
            pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q).tril_(t_k - t_q)
            mask += float_half(self.config, pre_mask.float() * (-1e3))
        qk += mask
        if self.config.mode == 3:  # mode 3 = topk, mode 4 = full
            qk = torch.where(qk >= qk.topk(32, dim=-1)[0][..., -1:], qk,
                             float_half(self.config, qk.new_ones(1).float() * (-1e9)))
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhqk,bhkd->bhqd', sm_qk, v) if not self.config.multi_query else torch.einsum(
            'bhqk,bkd->bhqd', sm_qk, v.squeeze(1))
        return o, kwargs


def batched_index_select(values, indices):
    # values: [b, h, t, d]; indices: [b, h, t'] -> gathers rows along the position axis.
    last_dim = values.shape[-1]
    return values.gather(2, indices[:, :, :, None].expand(-1, -1, -1, last_dim))
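

# Hedged example (added for illustration, not part of the original gist): shows the
# gather pattern above on a tiny tensor. values is [batch, heads, positions, dim] and
# indices selects positions per (batch, head); the concrete shapes are assumptions.
def _example_batched_index_select():
    values = torch.arange(2 * 1 * 4 * 3, dtype=torch.float).view(2, 1, 4, 3)
    indices = torch.tensor([[[3, 0]], [[1, 2]]])  # [2, 1, 2]
    return batched_index_select(values, indices)  # -> [2, 1, 2, 3]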


class SparseAttention(nn.Module):
    def __init__(self, config, shared_params=None, idx=0):
        super(SparseAttention, self).__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.num_heads_v = 1 if config.multi_query else config.num_heads
        self.lin_q = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.lin_v = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.linear_out = nn.Linear(config.hidden_dim, config.hidden_dim)
        std = 1 / np.sqrt(config.hidden_dim)
        if config.all:
            self.seqlen = 2 * self.config.tgt_length
            assert self.config.num_weights * self.config.num_clusters == self.seqlen
        else:
            self.seqlen = self.config.tgt_length
        if config.rel_enc:
            self.rel_enc = nn.Parameter(
                torch.zeros(config.num_heads, self.seqlen, device=config.device).normal_(0, std))
        if config.mode == 2:
            self.means = torch.zeros(config.num_heads, config.num_clusters, config.hidden_dim // config.num_heads,
                                     device=self.config.device).normal_(0, std).half()
            self.means = F.normalize(self.means, dim=-1)
            if config.hierarchical:
                self.means2 = torch.zeros(config.num_heads, config.num_clusters2,
                                          config.hidden_dim // config.num_heads,
                                          device=self.config.device).normal_(0, std).half()
                self.means2 = F.normalize(self.means2, dim=-1)
                NotImplementedError  # TODO: hierarchical clustering path is unfinished
        if config.all:
            self.M_k = nn.Parameter(torch.zeros(config.num_heads, config.num_weights * config.num_clusters,
                                                config.hidden_dim // config.num_heads,
                                                device=self.config.device).normal_(0, std))
            self.M_v = nn.Parameter(torch.zeros(config.num_heads, config.num_weights * config.num_clusters,
                                                config.hidden_dim // config.num_heads,
                                                device=self.config.device).normal_(0, std))

    def forward(self, inp, cache=None, decoding=False, **kwargs):
        ctxt_length = self.config.ctxt_length if self.training else self.config.ctxt_length_test
        inp_ = inp[:, ctxt_length:] if self.config.recurrent else inp
        q = self.lin_q(inp_)
        v = self.lin_v(inp)
        batch, length, dim = list(q.size())
        q = q.view(batch, q.size(1), self.num_heads, -1).transpose(1, 2)
        v = v.view(batch, v.size(1), self.num_heads_v, -1).transpose(1, 2)
        out, kwargs = self.attention(q, v, decoding=decoding, **kwargs)
        out = out.transpose(1, 2).contiguous().view(batch, length, dim)
        out = self.linear_out(out)
        return out, cache, kwargs

    def attention(self, qk, v, decoding=False, **kwargs):
        b_q, h_q, t_q, dim_q = list(qk.size())
        q = qk
        k = F.normalize(qk, dim=-1)
        with torch.no_grad():
            if self.config.mode == 1:  # lsh
                random_rotations = k.new_zeros(h_q, dim_q, self.config.num_clusters // 2).normal_(0, 1)
                rotated_vecs = torch.einsum('bhkd,hdc->bhkc', k, random_rotations)
                rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            if self.config.mode == 2:  # k-means
                rotated_vecs = torch.einsum('bhkd,hcd->bhkc', k, self.means)
            buckets = torch.argmax(rotated_vecs, dim=-1)  # bhk
            # print(torch.bincount(buckets.view(-1))[:10])
            if self.config.mode == 2:  # [b,h,k,d] x [b,h,k] -> [b,h,c,d]: recompute cluster means
                k_float = k.float()
                src = k_float.new_zeros(b_q, h_q, self.config.num_clusters, dim_q)
                src.scatter_add_(2, buckets.unsqueeze(-1).expand(-1, -1, -1, dim_q), k_float)
                # pre_means = scatter(k, buckets, dim=2, reduce='mean')  # mean version of scatter_add
                # self.means = (src / src.float().norm(dim=-1, keepdim=True)).half()
                self.means = F.normalize(src.mean(0), dim=-1).half()
                del src
        return self.attention_(q, k, v, buckets, **kwargs)

    def attention_(self, q, k, v, buckets, **kwargs):
        with torch.no_grad():
            seqlen = self.seqlen
            cluster_size = seqlen // self.config.num_clusters
            if self.config.all:
                cluster_size //= 2
                q = torch.cat([self.M_k.unsqueeze(0).expand(q.size(0), -1, -1, -1), q], -2)
                k = torch.cat([self.M_k.unsqueeze(0).expand(k.size(0), -1, -1, -1), k], -2)
                v = torch.cat([self.M_v.unsqueeze(0).expand(v.size(0), -1, -1, -1), v], -2)
                M_buckets = torch.arange(self.config.num_clusters, device=self.config.device).unsqueeze(0).expand(
                    self.config.num_weights, -1).view(-1).unsqueeze(0).unsqueeze(0)
                buckets = torch.cat([M_buckets.expand(q.size(0), q.size(1), -1), buckets], -1)
            batch_size, h_q, h_v = q.size(0), q.size(1), v.size(1)
            ticker = torch.arange(seqlen, device=self.config.device)
            buckets_and_t = seqlen * buckets + ticker
            # Hash-based sort ("s" at the start of variable names means "sorted")
            _, sticker = buckets_and_t.sort(-1)
            _, undo_sort = sticker.sort(-1)
            st = sticker
        # TODO: see if view gather faster than naive gather
        sq = batched_index_select(q, st)  # q [b, h, k, d], st [b, h, k]
        sk = batched_index_select(k, st)
        sv = batched_index_select(v, st)
        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = torch.reshape(st, (batch_size, h_q, self.config.num_clusters, cluster_size))  # [b, h, nc, cs]
        bq = torch.reshape(sq, (batch_size, h_q, self.config.num_clusters, cluster_size, sq.shape[-1]))  # [b, h, nc, cs, d]
        bk = torch.reshape(sk, (batch_size, h_q, self.config.num_clusters, cluster_size, sk.shape[-1]))
        bv = torch.reshape(sv, (batch_size, h_v, self.config.num_clusters, cluster_size, sv.shape[-1]))

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        def look_one_back(x):
            if self.config.all:
                x_extra = torch.cat([x[:, :, -1:, ...], x[:, :, :-1, ...]], dim=2)
                x_extra2 = torch.cat([x[:, :, -2:, ...], x[:, :, :-2, ...]], dim=2)
                return torch.cat([x_extra2, x_extra, x], dim=3)
            else:
                x_extra = torch.cat([x[:, :, -1:, ...], x[:, :, :-1, ...]], dim=2)
                return torch.cat([x_extra, x], dim=3)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)

        # Dot-product attention.
        dots = torch.einsum('bhnid,bhnjd->bhnij', bq, bk) * (bq.shape[-1] ** -0.5)
        if self.config.rel_enc:
            idx_diff = torch.clamp(bq_t[:, :, :, :, None] - bkv_t[:, :, :, None, :], min=0)  # [b, h, nc, cs, cs2]
            dots += 1 - (2 / np.log(self.seqlen + 1)) * torch.log(idx_diff.float() + 1).half()
            del idx_diff
        # Causal mask: a query may not attend to keys that come later in the sequence.
        mask = bq_t[:, :, :, :, None] < bkv_t[:, :, :, None, :]
        dots += float_half(self.config, mask.float() * (-1e9))
        del mask
        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, :, None] == bkv_t[:, :, :, None, :]
        dots.masked_fill_(self_mask, -1e3)
        del self_mask
        dots = F.softmax(dots, -1)
        dots = F.dropout(dots, p=self.config.dropout_prob, training=self.training)
        bo = torch.einsum('bhcij,bhcjd->bhcid', dots, bv)
        so = torch.reshape(bo, (batch_size, h_q, -1, bo.shape[-1]))

        class UnsortLogits(Function):
            @staticmethod
            def forward(ctx, so):
                so = so.detach()
                o = batched_index_select(so, undo_sort)
                return o

            @staticmethod
            def backward(ctx, grad_x):
                so_grad = batched_index_select(grad_x, sticker)
                return so_grad

        out = UnsortLogits.apply(so)
        if self.config.all:
            out = out[..., seqlen // 2:, :]
        return out, kwargs
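

# Hedged sketch (added for illustration, not part of the original gist): the
# bucket-sort bookkeeping used in SparseAttention.attention_ above, shown standalone.
# Positions are sorted by (seqlen * bucket + position) so members of the same bucket
# become contiguous, and undo_sort inverts the permutation; the toy bucket ids are
# assumptions.
def _example_bucket_sort():
    seqlen = 8
    buckets = torch.tensor([[[1, 0, 1, 0, 2, 2, 0, 1]]])  # [b=1, h=1, t=8]
    ticker = torch.arange(seqlen)
    buckets_and_t = seqlen * buckets + ticker
    _, sticker = buckets_and_t.sort(-1)
    _, undo_sort = sticker.sort(-1)
    # Gathering the sorted positions with undo_sort recovers the original order.
    restored = sticker.gather(-1, undo_sort)
    assert torch.equal(restored, torch.arange(seqlen).expand_as(sticker))
    return sticker, undo_sort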


class LSHAttention(nn.Module):
    def __init__(self, config, shared_params=None, idx=0):
        super(LSHAttention, self).__init__()
        self.config = config
        self.idx = idx
        std = 1 / config.hidden_dim
        if config.mode == 6:
            self.means = torch.zeros(config.num_heads, config.num_clusters, config.hidden_dim // config.num_heads,
                                     device=self.config.device).normal_(0, std).half()
            self.means = self.means / self.means.norm(dim=-1, keepdim=True)
        if config.all:
            self.weights = nn.Parameter(
                torch.zeros(config.num_heads, config.hidden_dim // config.num_heads, config.num_clusters,
                            device=self.config.device).normal_(0, std))

    def forward(self, q, k, v, decoding=False, **kwargs):
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        if self.config.mode == 5:  # lsh
            random_rotations = k.new_zeros(h_k, dim_k, self.config.num_clusters // 2).normal_(0, 1)
            rotated_vecs = torch.einsum('bhkd,hdc->bhkc', k, random_rotations)
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
        if self.config.mode == 6:  # k-means
            rotated_vecs = torch.einsum('bhkd,hcd->bhkc', k, self.means)
        with torch.no_grad():
            buckets = torch.argmax(rotated_vecs, dim=-1)
            one_hot = F.one_hot(buckets, num_classes=self.config.num_clusters).half()
            pre_mask = 1 - torch.einsum('bhkc, bhlc -> bhkl', one_hot, one_hot)
            mask = float_half(self.config, pre_mask.float() * (-1e9))
            if self.config.mode == 6:
                pre_means = torch.einsum('bhkd,bhkc->bhkcd', k, one_hot).transpose(0, 1).contiguous().view(
                    h_k, -1, self.config.num_clusters, dim_k).float().sum(1)
                self.means = (pre_means / pre_means.float().norm(dim=-1, keepdim=True)).half().detach()
        qk = torch.einsum('bhqd,bhkd->bhqk', q, k)
        qk *= dim_q ** -0.5
        pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q + 1)
        mask += float_half(self.config, pre_mask.float() * (-1e9))
        pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q).tril_(t_k - t_q)
        mask += float_half(self.config, pre_mask.float() * (-1e3))
        qk += mask
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhqk,bhkd->bhqd', sm_qk, v) if not self.config.multi_query else torch.einsum(
            'bhqk,bkd->bhqd', sm_qk, v.squeeze(1))
        return o, kwargs


class Block(nn.Module):
    def __init__(self, config, shared_params, idx):
        super(Block, self).__init__()
        self.config = config
        hidden_size = config.hidden_dim
        inner_linear = 4 * config.hidden_dim
        self.lnorm1 = LayerNorm(hidden_size)
        if config.mode in [3, 4, 5, 6]:
            self.masked_attention = MultiHeadAttention(config, shared_params=shared_params, idx=idx)
        else:
            self.masked_attention = SparseAttention(config, shared_params=shared_params, idx=idx)
        if not self.config.all:
            self.lnorm2 = LayerNorm(hidden_size)
            self.trans = nn.Sequential(nn.Linear(hidden_size, inner_linear), nn.ReLU(),
                                       nn.Dropout(p=self.config.dropout_prob),
                                       nn.Linear(inner_linear, hidden_size))

    def forward(self, x, cache=None, decoding=False, **kwargs):
        length = self.config.ctxt_length if self.training else self.config.ctxt_length_test
        res = x
        if cache is None and self.config.recurrent:
            bs, _, dim = list(x.size())
            if self.config.efficient_xl:
                dim *= 2
            cache = x.new_zeros(bs, length, dim)
        x = self.lnorm1(x)
        if self.config.recurrent and not self.config.efficient_xl:
            x = torch.cat([cache, x], 1)
            cache = x[:, -length:].detach()
        x, cache_, kwargs = self.masked_attention(x, decoding=decoding, cache=cache, **kwargs)
        if self.config.efficient_xl:
            cache = cache_
        if not self.config.all:
            x = F.dropout(x, p=self.config.dropout_prob, training=self.training)
            x.add_(res)
            res = x
            x = self.lnorm2(x)
            x = self.trans(x)
            x.add_(res)
        return x, cache, kwargs


from collections import namedtuple
from torch.nn.modules import Module
from torch.nn.functional import log_softmax


class AdaptiveInput(Module):
    def __init__(self, config, in_features, n_classes, cutoffs, div_value=4., head_bias=False):
        super(AdaptiveInput, self).__init__()
        cutoffs = list(cutoffs)
        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) >= (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any([int(c) != c for c in cutoffs]):
            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")
        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]
        self.div_value = div_value
        self.head_bias = head_bias
        self.config = config
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.cutoffs[0]
        self.head = nn.Sequential(nn.Embedding(self.head_size, self.in_features),
                                  nn.Linear(self.in_features, self.in_features, bias=self.head_bias))
        self.tail = nn.ModuleList()
        for i in range(self.n_clusters):
            hsz = int(self.in_features // (self.div_value ** (i + 1)))
            osz = self.cutoffs[i + 1] - self.cutoffs[i]
            projection = nn.Sequential(
                nn.Embedding(osz, hsz),
                nn.Linear(hsz, self.in_features, bias=False),
            )
            self.tail.append(projection)

    def forward(self, input):
        used_rows = 0
        input_size = list(input.size())  # [bs, len]
        output = float_half(self.config, input.new_zeros(input_size + [self.in_features])).squeeze()
        input = input.squeeze()
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            low_idx = cutoff_values[i]
            high_idx = cutoff_values[i + 1]
            input_mask = (input >= low_idx) & (input < high_idx)
            row_indices = input_mask.nonzero().squeeze()
            if row_indices.numel() == 0:
                continue
            out = self.head(input[input_mask] - low_idx) if i == 0 else self.tail[i - 1](input[input_mask] - low_idx)
            output.index_copy_(0, row_indices, out)
            used_rows += row_indices.numel()
        if used_rows != input_size[1]:
            raise RuntimeError("Target values should be in [0, {}], "
                               "but values in range [{}, {}] "
                               "were found. ".format(self.n_classes - 1,
                                                     input.min().item(),
                                                     input.max().item()))
        return output.unsqueeze(0)
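

# Hedged example (added for illustration, not part of the original gist): the adaptive
# input embedding above expects a [1, N] index tensor (Embedding.forward flattens the
# batch before calling it). The stub config only carries the fp16 flag read via
# float_half; the token ids and sizes are assumptions chosen so every cluster is hit.
def _example_adaptive_input():
    from types import SimpleNamespace
    cfg = SimpleNamespace(fp16=False)
    emb = AdaptiveInput(cfg, in_features=16, n_classes=100, cutoffs=[20, 60])
    tokens = torch.tensor([[1, 2, 25, 30, 70, 80, 3, 65]])  # ids spread over all three clusters
    return emb(tokens).shape  # expected: torch.Size([1, 8, 16])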


class LAttention(nn.Module):
    def __init__(self, config, shared_params=None, idx=0):
        super(LAttention, self).__init__()
        self.config = config
        self.R = shared_params[0]
        if self.config.rel_enc:
            self.w_r = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.u = shared_params[1]
        self.v = shared_params[2]
        self.idx = idx

    def R_(self, t_k):
        return self.R[:t_k].flip(0)

    def forward(self, q, k, v, decoding=False):
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        # Split the sequence into chunks of length tgt_length.
        q = q.view(b_q, h_q, t_q // self.config.tgt_length, self.config.tgt_length, dim_q)
        k = k.view(b_k, h_k, t_k // self.config.tgt_length, self.config.tgt_length, dim_k)
        v = v.view(b_k, h_k, t_k // self.config.tgt_length, self.config.tgt_length, dim_k)

        def f(x):
            # Concatenate each chunk with the following chunk along the length axis so a
            # query chunk can attend across the chunk boundary.
            return torch.cat([x[:, :, :-1], x[:, :, 1:]], -2)

        k = f(k)
        v = f(v)
        if self.config.rel_enc:
            k_part = (torch.einsum('bhcqd,bhckd->bhcqk', q + self.u.unsqueeze(2), k)
                      if not self.config.multi_query
                      else torch.einsum('bhqd,bkd->bhqk', q + self.u, k.squeeze(1)))
            tmp = torch.einsum('bhcqd,khd->bhcqk', q + self.v.unsqueeze(2),
                               self.w_r(self.R_(q.size(-2) * 2)).view(-1, h_q, dim_q))
            wr_part = shift(tmp)
            qk = k_part + wr_part
        else:
            qk = torch.einsum('bhcqd,bhckd->bhcqk', q, k)
        qk *= dim_q ** -0.5
        pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q + 1)
        mask = float_half(self.config, pre_mask.float() * (-1e9))
        if self.config.share_qk:
            pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q).tril_(t_k - t_q)
            mask += float_half(self.config, pre_mask.float() * (-1e3))
        qk += mask
        if True:
            # random_rotations = k.new_zeros(h_k, dim_k, self.config.num_clusters // 2).normal_(0, 1)
            random_rotations = torch.randn((h_k, dim_k, self.config.num_clusters // 2), device=self.config.device,
                                           dtype=k.dtype)
            rotated_vecs = torch.einsum('bhkd,hdc->bhkc', k, random_rotations)
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            one_hot = F.one_hot(buckets, num_classes=self.config.num_clusters).half()
            print(one_hot[0, 0].long().sum(0))
        if self.config.mode == 3:  # mode 3 = topk, mode 4 = full
            qk = torch.where(qk >= qk.topk(32, dim=-1)[0][..., -1:], qk,
                             float_half(self.config, qk.new_ones(1).float() * (-1e9)))
        pre_mask = torch.ones(t_q, t_q * 2, device=self.config.device).byte().triu_(t_q + 1)
        mask = float_half(self.config, pre_mask.float() * (-1e9))
        qk += mask
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhcqk,bhckd->bhcqd', sm_qk, v) if not self.config.multi_query else torch.einsum(
            'bhqk,bkd->bhqd', sm_qk, v.squeeze(1))
        o = o.view(b_q, h_q, t_q, dim_q)
        return o