import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ImportError:
    from torch.nn import LayerNorm
from copy import copy
from time import time
# from torch_scatter import scatter

def float_half(config, x):
    if config.fp16:
        return x.half()
    else:
        return x.float()


def initialize(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.zero_()

def positional_embedding(config):
    ''' Sinusoid position encoding table '''
    n_position = 10000
    d_hid = config.hidden_dim

    def cal_angle(position, hid_idx):
        return position / np.power(n_position, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return float_half(config, torch.tensor(sinusoid_table, device=config.device))
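
# Hedged usage sketch for the table above (assumes a `config` object with `hidden_dim`,
# `device`, and `fp16` attributes, which this file never defines):
#   R = positional_embedding(config)   # [10000, hidden_dim] sin/cos table
#   x = x + R[:x.size(1)]              # absolute encoding, as in Transformer.forward below
# When config.rel_enc is set, the same table is instead passed around as shared_params[0]
# and read back-to-front via R_() to build relative position scores.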

class Embedding(nn.Module):
    def __init__(self, config):
        super(Embedding, self).__init__()
        self.config = config
        self.hd = config.hidden_dim
        if config.adaptive:
            self.emb = AdaptiveInput(config, self.hd, config.vocab_size, config.cutoffs)
        else:
            self.emb = nn.Embedding(config.vocab_size, self.hd)

    def forward(self, x, decoding=False):
        if self.config.adaptive:
            if decoding:  # TODO: incomplete here
                return
            else:
                return self.emb(x.contiguous().view(1, -1)).squeeze().view(x.size(0), -1, self.hd)
        else:
            return self.emb(x)

class Softmax(nn.Module):
    def __init__(self, config):
        super(Softmax, self).__init__()
        self.config = config
        self.hd = config.hidden_dim
        if config.adaptive:
            # self.softmax = AdaptiveLogSoftmaxWithLoss(self.hd, config.vocab_size, config.cutoffs)
            self.softmax = nn.AdaptiveLogSoftmaxWithLoss(self.hd, config.vocab_size, config.cutoffs)
        else:
            self.linear = nn.Linear(self.hd, config.vocab_size, bias=False)

    def forward(self, x, tgt=None, decoding=False):
        if self.config.adaptive:
            # print(x.size(), tgt.size())
            # tgt = tgt[:, 70:].contiguous()
            tgt = tgt.contiguous().view(-1)
            if decoding:  # TODO: incomplete here
                return self.softmax.log_prob(x, tgt)
            else:
                # x = x[:, 70:].contiguous()
                return self.softmax(x.view(-1, self.hd), tgt)
        else:
            x = self.linear(x)
            x = F.log_softmax(x, -1)
            return x

class ED(nn.Module):
    def __init__(self, config):
        super(ED, self).__init__()
        self.config = config
        c1 = copy(config)
        c2 = copy(config)
        c1.encoder = True
        c2.attend = True
        # c2.cached = True
        self.emb = Embedding(config)
        self.e = Transformer(c1)
        self.d = Transformer(c2)
        self.softmax = Softmax(config)

    def encoding(self, encoder_inp):
        return self.e(self.emb(encoder_inp))[1]

    def forward(self, encoder_cache, inp, cache=None):
        x, cache, l = self.d(self.emb(inp), cache=cache, encoder_cache=encoder_cache)
        x = self.softmax(x)
        return x, cache, l

class Decoder(nn.Module):
    def __init__(self, config):
        super(Decoder, self).__init__()
        self.config = config
        self.td = Transformer(config)
        self.softmax = Softmax(config)
        self.emb = Embedding(config)

    def forward(self, inp, cache=None, decoding=False, cur_time=0, tgt=None, **kwargs):
        x = inp
        x = self.emb(x)
        x, c, kwargs = self.td(x, cache=cache, decoding=decoding, cur_time=cur_time, **kwargs)
        return self.softmax(x, tgt=tgt)[1], c, kwargs

class Transformer(nn.Module):
    def __init__(self, config, weights=None):
        super(Transformer, self).__init__()
        self.config = config
        std = math.sqrt(1 / config.hidden_dim)
        self.u = nn.Parameter(
            torch.zeros(1, config.num_heads, 1, config.hidden_dim // config.num_heads,
                        device=self.config.device).normal_(0, std))
        self.v = nn.Parameter(
            torch.zeros(1, config.num_heads, 1, config.hidden_dim // config.num_heads,
                        device=self.config.device).normal_(0, std))
        self.R = positional_embedding(config)
        shared_params = [self.R, self.u, self.v]
        self.blocks = nn.ModuleList([Block(config, shared_params, i) for i in range(config.depth)])

    def forward(self, x, cache=None, cur_time=0, decoding=False, **kwargs):
        total_loss = 0
        cache = [None for _ in range(self.config.depth)] if cache is None else cache
        new_cache = [None for _ in range(self.config.depth)]
        if not self.config.rel_enc:
            x += self.R[:x.size(1)]
        for i in range(self.config.depth):
            x = F.dropout(x, p=self.config.dropout_prob, training=self.training)
            x, new_cache[i], kwargs = self.blocks[i](x, cache[i], decoding=decoding, **kwargs)
        return x, new_cache, kwargs

def shift_(x):
    # x = [*, t_q, t_k]
    zero_pad = torch.zeros(*x.size()[:-1], x.size(-2), device=x.device, dtype=x.dtype)
    x = torch.cat([x, zero_pad], -1)
    l = x.size(-1)
    x = x.view(*x.size()[:-2], -1)
    zero_pad = torch.zeros(*x.size()[:-1], -x.size(-1) % (l - 1), device=x.device, dtype=x.dtype)
    return torch.cat([x, zero_pad], -1).view(*x.size()[:-1], -1, l - 1)


def shift(x):
    t_q = x.size()[-2]
    return shift_(x)[..., :t_q, t_q - 1:]
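
# Note on the two helpers above: they implement the Transformer-XL style "relative shift"
# that realigns position scores computed against the reversed table R_(t_k). The operation
# is shape-preserving on the last two axes. A quick sanity check with hypothetical shapes
# (not part of the original gist):
#   x = torch.randn(2, 4, 5, 8)        # [batch, heads, t_q, t_k]
#   assert shift(x).shape == x.shape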

class MultiHeadAttention(nn.Module):
    def __init__(self, config, shared_params=None, idx=0):
        super(MultiHeadAttention, self).__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.num_heads_k = 1 if config.multi_query else config.num_heads
        self.lin_q = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        if config.share_qk:
            self.lin_v = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        elif config.multi_query:
            self.lin_kv = nn.Linear(config.hidden_dim, (config.hidden_dim * 2) // config.num_heads, bias=False)
        else:
            self.lin_kv = nn.Linear(config.hidden_dim, config.hidden_dim * 2, bias=False)
        self.linear_out = nn.Linear(config.hidden_dim, config.hidden_dim)
        if self.config.mode in [3, 4]:
            self.attention = Attention(config, shared_params=shared_params, idx=idx)
        else:
            self.attention = LSHAttention(config, shared_params=shared_params, idx=idx)
        assert (config.hidden_dim % config.num_heads == 0)

    def forward(self, inp, cache=None, decoding=False, **kwargs):
        ctxt_length = self.config.ctxt_length if self.training else self.config.ctxt_length_test
        inp_ = inp[:, ctxt_length:] if self.config.recurrent else inp
        if self.config.share_qk:
            q = self.lin_q(inp)
            k = q / (torch.norm(q, dim=-1, keepdim=True) + 1e-6)
            q = q[:, ctxt_length:]
            v = self.lin_v(inp)
        else:
            q = self.lin_q(inp_)
            kv = self.lin_kv(inp)
            if self.config.efficient_xl:  # TODO: check this
                kv = torch.cat([cache, kv], 1)
                cache = kv[:, -ctxt_length:].detach()
            # multi-query attention shares one k/v head, so each split is hidden_dim // num_heads wide
            split_size = self.config.hidden_dim // self.config.num_heads if self.config.multi_query \
                else self.config.hidden_dim
            k, v = torch.split(kv, split_size, dim=-1)
        if decoding:
            if cache[0] is None:
                cache = [k, v]
            else:
                k = torch.cat([cache[0], k], 1)
                v = torch.cat([cache[1], v], 1)
                cache = [k, v]
        batch, length, dim = list(q.size())
        q = q.view(batch, q.size(1), self.num_heads, -1).transpose(1, 2)
        k = k.view(batch, k.size(1), self.num_heads_k, -1).transpose(1, 2)
        v = v.view(batch, v.size(1), self.num_heads_k, -1).transpose(1, 2)
        out, kwargs = self.attention(q, k, v, decoding=decoding, **kwargs)
        out = out.transpose(1, 2).contiguous().view(batch, length, dim)
        out = self.linear_out(out)
        return out, cache, kwargs

class Attention(nn.Module):
    def __init__(self, config, shared_params=None, idx=0):
        super(Attention, self).__init__()
        self.config = config
        self.R = shared_params[0]
        if self.config.rel_enc:
            self.w_r = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
            self.u = shared_params[1]
            self.v = shared_params[2]
        self.idx = idx
        std = 1 / config.hidden_dim
        self.means1 = torch.zeros(config.batch_size, config.num_heads, config.num_clusters,
                                  config.hidden_dim // config.num_heads,
                                  device=self.config.device).normal_(0, std).half()
        self.means1 = self.means1 / self.means1.norm(dim=-1, keepdim=True)
        self.means2 = torch.zeros(config.num_heads, config.num_clusters, config.hidden_dim // config.num_heads,
                                  device=self.config.device).normal_(0, std).half()
        self.means2 = self.means2 / self.means2.norm(dim=-1, keepdim=True)

    def R_(self, t_k):
        return self.R[:t_k].flip(0)

    def forward(self, q, k, v, decoding=False, **kwargs):
        aux = 0
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        if self.idx == 3 and False:
            # random_rotations = k.new_zeros(h_k, dim_k, self.config.num_clusters // 2).normal_(0, 1)
            random_rotations = torch.randn((h_k, dim_k, self.config.num_clusters // 2), device=self.config.device,
                                           dtype=k.dtype)
            rotated_vecs = torch.einsum('bhkd,hdc->bhkc', k, random_rotations)
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            one_hot = F.one_hot(buckets, num_classes=self.config.num_clusters).half()
            print(one_hot[0, 0].long().sum(0))
            rotated_vecs = torch.einsum('bhkd,hcd->bhkc', k, self.means2)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            one_hot = F.one_hot(buckets, num_classes=self.config.num_clusters).half()
            print(one_hot[0, 0].long().sum(0))
            pre_means = torch.einsum('bhkd,bhkc->bhkcd', k, one_hot).transpose(0, 1).contiguous().view(
                h_k, -1, self.config.num_clusters, dim_k).float().sum(1)
            self.means2 = (pre_means / pre_means.float().norm(dim=-1, keepdim=True)).half().detach()
            pre_mask = 1 - torch.einsum('bhkc, bhlc -> bhkl', one_hot, one_hot)
            mask = float_half(self.config, pre_mask.float() * (-1e9))
        if self.config.rel_enc:
            k_part = torch.einsum('bhqd,bhkd->bhqk', q + self.u, k) if not self.config.multi_query \
                else torch.einsum('bhqd,bkd->bhqk', q + self.u, k.squeeze(1))
            tmp = torch.einsum('bhqd,khd->bhqk', q + self.v, self.w_r(self.R_(t_k)).view(-1, h_q, dim_q))
            wr_part = shift(tmp)
            qk = k_part + wr_part
        else:
            qk = torch.einsum('bhqd,bhkd->bhqk', q, k)
        qk *= dim_q ** -0.5
        pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q + 1)
        mask = float_half(self.config, pre_mask.float() * (-1e9))
        if self.config.share_qk:
            pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q).tril_(t_k - t_q)
            mask += float_half(self.config, pre_mask.float() * (-1e3))
        qk += mask
        if self.config.mode == 3:  # mode 3 = topk, mode 4 = full
            qk = torch.where(qk >= qk.topk(32, dim=-1)[0][..., -1:], qk,
                             float_half(self.config, qk.new_ones(1).float() * (-1e9)))
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhqk,bhkd->bhqd', sm_qk, v) if not self.config.multi_query \
            else torch.einsum('bhqk,bkd->bhqd', sm_qk, v.squeeze(1))
        return o, kwargs

def batched_index_select(values, indices):
    last_dim = values.shape[-1]
    return values.gather(2, indices[:, :, :, None].expand(-1, -1, -1, last_dim))
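
# Hedged shape note: `values` is [batch, heads, seq, dim] and `indices` is [batch, heads, seq],
# so one vector is gathered per (batch, head, position). Hypothetical example, not from the gist:
#   values = torch.randn(2, 4, 6, 8)
#   idx = torch.randint(0, 6, (2, 4, 6))
#   batched_index_select(values, idx).shape   # torch.Size([2, 4, 6, 8])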

class SparseAttention(nn.Module):
    def __init__(self, config, shared_params=None, idx=0):
        super(SparseAttention, self).__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.num_heads_v = 1 if config.multi_query else config.num_heads
        self.lin_q = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.lin_v = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.linear_out = nn.Linear(config.hidden_dim, config.hidden_dim)
        std = 1 / np.sqrt(config.hidden_dim)
        if config.all:
            self.seqlen = 2 * self.config.tgt_length
            assert self.config.num_weights * self.config.num_clusters == self.seqlen
        else:
            self.seqlen = self.config.tgt_length
        if config.rel_enc:
            self.rel_enc = nn.Parameter(
                torch.zeros(config.num_heads, self.seqlen, device=config.device).normal_(0, std))
        if config.mode == 2:
            self.means = torch.zeros(config.num_heads, config.num_clusters, config.hidden_dim // config.num_heads,
                                     device=self.config.device).normal_(0, std).half()
            self.means = F.normalize(self.means, dim=-1)
            if config.hierarchical:
                self.means2 = torch.zeros(config.num_heads, config.num_clusters2,
                                          config.hidden_dim // config.num_heads,
                                          device=self.config.device).normal_(0, std).half()
                self.means2 = F.normalize(self.means2, dim=-1)
                raise NotImplementedError  # hierarchical clustering is not implemented yet
        if config.all:
            self.M_k = nn.Parameter(torch.zeros(config.num_heads, config.num_weights * config.num_clusters,
                                                config.hidden_dim // config.num_heads,
                                                device=self.config.device).normal_(0, std))
            self.M_v = nn.Parameter(torch.zeros(config.num_heads, config.num_weights * config.num_clusters,
                                                config.hidden_dim // config.num_heads,
                                                device=self.config.device).normal_(0, std))

    def forward(self, inp, cache=None, decoding=False, **kwargs):
        ctxt_length = self.config.ctxt_length if self.training else self.config.ctxt_length_test
        inp_ = inp[:, ctxt_length:] if self.config.recurrent else inp
        q = self.lin_q(inp_)
        v = self.lin_v(inp)
        batch, length, dim = list(q.size())
        q = q.view(batch, q.size(1), self.num_heads, -1).transpose(1, 2)
        v = v.view(batch, v.size(1), self.num_heads_v, -1).transpose(1, 2)
        out, kwargs = self.attention(q, v, decoding=decoding, **kwargs)
        out = out.transpose(1, 2).contiguous().view(batch, length, dim)
        out = self.linear_out(out)
        return out, cache, kwargs

    def attention(self, qk, v, decoding=False, **kwargs):
        b_q, h_q, t_q, dim_q = list(qk.size())
        q = qk
        k = F.normalize(qk, dim=-1)
        with torch.no_grad():
            if self.config.mode == 1:  # lsh
                random_rotations = k.new_zeros(h_q, dim_q, self.config.num_clusters // 2).normal_(0, 1)
                rotated_vecs = torch.einsum('bhkd,hdc->bhkc', k, random_rotations)
                rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            if self.config.mode == 2:  # k-means
                rotated_vecs = torch.einsum('bhkd,hcd->bhkc', k, self.means)
            buckets = torch.argmax(rotated_vecs, dim=-1)  # bhk
            # print(torch.bincount(buckets.view(-1))[:10])
            if self.config.mode == 2:  # [b,h,k,d] x [b,h,k] -> [b,h,c,d]
                k_float = k.float()
                src = k_float.new_zeros(b_q, h_q, self.config.num_clusters, dim_q)
                src.scatter_add_(2, buckets.unsqueeze(-1).expand(-1, -1, -1, dim_q), k_float)
                # pre_means = scatter(k, buckets, dim=2, reduce='mean')  # mean version of scatter_add
                # self.means = (src / src.float().norm(dim=-1, keepdim=True)).half()
                self.means = F.normalize(src.mean(0), dim=-1).half()
                del src
        return self.attention_(q, k, v, buckets, **kwargs)

    def attention_(self, q, k, v, buckets, **kwargs):
        with torch.no_grad():
            seqlen = self.seqlen
            cluster_size = seqlen // self.config.num_clusters
            if self.config.all:
                cluster_size //= 2
                q = torch.cat([self.M_k.unsqueeze(0).expand(q.size(0), -1, -1, -1), q], -2)
                k = torch.cat([self.M_k.unsqueeze(0).expand(k.size(0), -1, -1, -1), k], -2)
                v = torch.cat([self.M_v.unsqueeze(0).expand(v.size(0), -1, -1, -1), v], -2)
                M_buckets = torch.arange(self.config.num_clusters, device=self.config.device).unsqueeze(0).expand(
                    self.config.num_weights, -1).view(-1).unsqueeze(0).unsqueeze(0)
                buckets = torch.cat([M_buckets.expand(q.size(0), q.size(1), -1), buckets], -1)
            batch_size, h_q, h_v = q.size(0), q.size(1), v.size(1)
            ticker = torch.arange(seqlen, device=self.config.device)
            buckets_and_t = seqlen * buckets + ticker
            # Hash-based sort ("s" at the start of variable names means "sorted")
            _, sticker = buckets_and_t.sort(-1)
            _, undo_sort = sticker.sort(-1)
            st = sticker
        # TODO: see if view gather faster than naive gather
        sq = batched_index_select(q, st)  # q [b, h, k, d], st [b, h, k]
        sk = batched_index_select(k, st)
        sv = batched_index_select(v, st)
        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = torch.reshape(st, (batch_size, h_q, self.config.num_clusters, cluster_size))  # [b, h, nc, cs]
        bq = torch.reshape(sq, (batch_size, h_q, self.config.num_clusters, cluster_size, sq.shape[-1]))  # [b, h, nc, cs, d]
        bk = torch.reshape(sk, (batch_size, h_q, self.config.num_clusters, cluster_size, sk.shape[-1]))
        bv = torch.reshape(sv, (batch_size, h_v, self.config.num_clusters, cluster_size, sv.shape[-1]))

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        def look_one_back(x):
            if self.config.all:
                x_extra = torch.cat([x[:, :, -1:, ...], x[:, :, :-1, ...]], dim=2)
                x_extra2 = torch.cat([x[:, :, -2:, ...], x[:, :, :-2, ...]], dim=2)
                return torch.cat([x_extra2, x_extra, x], dim=3)
            else:
                x_extra = torch.cat([x[:, :, -1:, ...], x[:, :, :-1, ...]], dim=2)
                return torch.cat([x_extra, x], dim=3)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        # Dot-product attention.
        dots = torch.einsum('bhnid,bhnjd->bhnij', bq, bk) * (bq.shape[-1] ** -0.5)
        if self.config.rel_enc:
            idx_diff = torch.clamp(bq_t[:, :, :, :, None] - bkv_t[:, :, :, None, :], min=0)  # [b, h, nc, cs, cs2]
            dots += 1 - (2 / np.log(self.seqlen + 1)) * torch.log(idx_diff.float() + 1).half()
            del idx_diff
        mask = bq_t[:, :, :, :, None] < bkv_t[:, :, :, None, :]
        dots += float_half(self.config, mask.float() * (-1e9))
        del mask
        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, :, None] == bkv_t[:, :, :, None, :]
        dots.masked_fill_(self_mask, -1e3)
        del self_mask
        dots = F.softmax(dots, -1)
        dots = F.dropout(dots, p=self.config.dropout_prob, training=self.training)
        bo = torch.einsum('bhcij,bhcjd->bhcid', dots, bv)
        so = torch.reshape(bo, (batch_size, h_q, -1, bo.shape[-1]))

        class UnsortLogits(Function):
            @staticmethod
            def forward(ctx, so):
                so = so.detach()
                o = batched_index_select(so, undo_sort)
                return o

            @staticmethod
            def backward(ctx, grad_x):
                so_grad = batched_index_select(grad_x, sticker)
                return so_grad

        out = UnsortLogits.apply(so)
        if self.config.all:
            out = out[..., seqlen // 2:, :]
        return out, kwargs

class LSHAttention(nn.Module):
    def __init__(self, config, shared_params=None, idx=0):
        super(LSHAttention, self).__init__()
        self.config = config
        self.idx = idx
        std = 1 / config.hidden_dim
        if config.mode == 6:
            self.means = torch.zeros(config.num_heads, config.num_clusters, config.hidden_dim // config.num_heads,
                                     device=self.config.device).normal_(0, std).half()
            self.means = self.means / self.means.norm(dim=-1, keepdim=True)
        if config.all:
            self.weights = nn.Parameter(
                torch.zeros(config.num_heads, config.hidden_dim // config.num_heads, config.num_clusters,
                            device=self.config.device).normal_(0, std))

    def forward(self, q, k, v, decoding=False, **kwargs):
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        if self.config.mode == 5:  # lsh
            random_rotations = k.new_zeros(h_k, dim_k, self.config.num_clusters // 2).normal_(0, 1)
            rotated_vecs = torch.einsum('bhkd,hdc->bhkc', k, random_rotations)
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
        if self.config.mode == 6:  # k-means
            rotated_vecs = torch.einsum('bhkd,hcd->bhkc', k, self.means)
        with torch.no_grad():
            buckets = torch.argmax(rotated_vecs, dim=-1)
            one_hot = F.one_hot(buckets, num_classes=self.config.num_clusters).half()
            pre_mask = 1 - torch.einsum('bhkc, bhlc -> bhkl', one_hot, one_hot)
            mask = float_half(self.config, pre_mask.float() * (-1e9))
            if self.config.mode == 6:
                pre_means = torch.einsum('bhkd,bhkc->bhkcd', k, one_hot).transpose(0, 1).contiguous().view(
                    h_k, -1, self.config.num_clusters, dim_k).float().sum(1)
                self.means = (pre_means / pre_means.float().norm(dim=-1, keepdim=True)).half().detach()
        qk = torch.einsum('bhqd,bhkd->bhqk', q, k)
        qk *= dim_q ** -0.5
        pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q + 1)
        mask += float_half(self.config, pre_mask.float() * (-1e9))
        pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q).tril_(t_k - t_q)
        mask += float_half(self.config, pre_mask.float() * (-1e3))
        qk += mask
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhqk,bhkd->bhqd', sm_qk, v) if not self.config.multi_query \
            else torch.einsum('bhqk,bkd->bhqd', sm_qk, v.squeeze(1))
        return o, kwargs

class Block(nn.Module):
    def __init__(self, config, shared_params, idx):
        super(Block, self).__init__()
        self.config = config
        hidden_size = config.hidden_dim
        inner_linear = 4 * config.hidden_dim
        self.lnorm1 = LayerNorm(hidden_size)
        if config.mode in [3, 4, 5, 6]:
            self.masked_attention = MultiHeadAttention(config, shared_params=shared_params, idx=idx)
        else:
            self.masked_attention = SparseAttention(config, shared_params=shared_params, idx=idx)
        if not self.config.all:
            self.lnorm2 = LayerNorm(hidden_size)
            self.trans = nn.Sequential(nn.Linear(hidden_size, inner_linear), nn.ReLU(),
                                       nn.Dropout(p=self.config.dropout_prob),
                                       nn.Linear(inner_linear, hidden_size))

    def forward(self, x, cache=None, decoding=False, **kwargs):
        length = self.config.ctxt_length if self.training else self.config.ctxt_length_test
        res = x
        if cache is None and self.config.recurrent:
            bs, _, dim = list(x.size())
            if self.config.efficient_xl:
                dim *= 2
            cache = x.new_zeros(bs, length, dim)
        x = self.lnorm1(x)
        if self.config.recurrent and not self.config.efficient_xl:
            x = torch.cat([cache, x], 1)
            cache = x[:, -length:].detach()
        x, cache_, kwargs = self.masked_attention(x, decoding=decoding, cache=cache, **kwargs)
        if self.config.efficient_xl:
            cache = cache_
        if not self.config.all:
            x = F.dropout(x, p=self.config.dropout_prob, training=self.training)
            x.add_(res)
            res = x
            x = self.lnorm2(x)
            x = self.trans(x)
            x.add_(res)
        return x, cache, kwargs
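
# Rough map of the `config.mode` dispatch above, as far as it can be read off this file:
# modes 3/4 use the dense Attention module (3 = top-k truncated scores, 4 = full attention),
# modes 5/6 use LSHAttention (5 = random-rotation LSH buckets, 6 = k-means buckets), and any
# other mode falls through to SparseAttention (1 = LSH, 2 = k-means) with chunked,
# sorted-bucket attention.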

from collections import namedtuple
from torch.nn.modules import Module
from torch.nn.functional import log_softmax


class AdaptiveInput(Module):
    def __init__(self, config, in_features, n_classes, cutoffs, div_value=4., head_bias=False):
        super(AdaptiveInput, self).__init__()
        cutoffs = list(cutoffs)
        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) >= (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any([int(c) != c for c in cutoffs]):
            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")
        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]
        self.div_value = div_value
        self.head_bias = head_bias
        self.config = config
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.cutoffs[0]
        self.head = nn.Sequential(nn.Embedding(self.head_size, self.in_features),
                                  nn.Linear(self.in_features, self.in_features, bias=self.head_bias))
        self.tail = nn.ModuleList()
        for i in range(self.n_clusters):
            hsz = int(self.in_features // (self.div_value ** (i + 1)))
            osz = self.cutoffs[i + 1] - self.cutoffs[i]
            projection = nn.Sequential(
                nn.Embedding(osz, hsz),
                nn.Linear(hsz, self.in_features, bias=False),
            )
            self.tail.append(projection)

    def forward(self, input):
        used_rows = 0
        input_size = list(input.size())  # [bs, len]
        output = float_half(self.config, input.new_zeros(input_size + [self.in_features])).squeeze()
        input = input.squeeze()
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            low_idx = cutoff_values[i]
            high_idx = cutoff_values[i + 1]
            input_mask = (input >= low_idx) & (input < high_idx)
            row_indices = input_mask.nonzero().squeeze()
            if row_indices.numel() == 0:
                continue
            out = self.head(input[input_mask] - low_idx) if i == 0 else self.tail[i - 1](input[input_mask] - low_idx)
            output.index_copy_(0, row_indices, out)
            used_rows += row_indices.numel()
        if used_rows != input_size[1]:
            raise RuntimeError("Target values should be in [0, {}], "
                               "but values in range [{}, {}] "
                               "were found. ".format(self.n_classes - 1,
                                                     input.min().item(),
                                                     input.max().item()))
        return output.unsqueeze(0)
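
# Hedged usage sketch for AdaptiveInput (hypothetical numbers; assumes a config with fp16/device set):
#   emb = AdaptiveInput(config, in_features=512, n_classes=vocab_size, cutoffs=[2000, 10000])
#   tokens = torch.randint(0, vocab_size, (1, 128))   # Embedding.forward above flattens input to [1, bs*len]
#   vecs = emb(tokens)                                 # -> [1, 128, 512]
# The same config.cutoffs also feeds nn.AdaptiveLogSoftmaxWithLoss in the Softmax module, so the
# embedding clusters and the output softmax clusters stay aligned.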

class LAttention(nn.Module):
    def __init__(self, config, shared_params=None, idx=0):
        super(LAttention, self).__init__()
        self.config = config
        self.R = shared_params[0]
        if self.config.rel_enc:
            self.w_r = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
            self.u = shared_params[1]
            self.v = shared_params[2]
        self.idx = idx

    def R_(self, t_k):
        return self.R[:t_k].flip(0)

    def forward(self, q, k, v, decoding=False):
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        q = q.view(b_q, h_q, t_q // self.config.tgt_length, self.config.tgt_length, dim_q)
        k = k.view(b_k, h_k, t_k // self.config.tgt_length, self.config.tgt_length, dim_k)
        v = v.view(b_k, h_k, t_k // self.config.tgt_length, self.config.tgt_length, dim_k)

        def f(x):
            return torch.cat([x[:, :, :-1], x[:, :, 1:]], -2)

        k = f(k)
        v = f(v)
        if self.config.rel_enc:
            k_part = torch.einsum('bhcqd,bhckd->bhcqk', q + self.u.unsqueeze(2), k) \
                if not self.config.multi_query else torch.einsum('bhqd,bkd->bhqk', q + self.u, k.squeeze(1))
            tmp = torch.einsum('bhcqd,khd->bhcqk', q + self.v.unsqueeze(2),
                               self.w_r(self.R_(q.size(-2) * 2)).view(-1, h_q, dim_q))
            wr_part = shift(tmp)
            qk = k_part + wr_part
        else:
            qk = torch.einsum('bhcqd,bhckd->bhcqk', q, k)
        qk *= dim_q ** -0.5
        pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q + 1)
        mask = float_half(self.config, pre_mask.float() * (-1e9))
        if self.config.share_qk:
            pre_mask = torch.ones(t_q, t_k, device=self.config.device).byte().triu_(t_k - t_q).tril_(t_k - t_q)
            mask += float_half(self.config, pre_mask.float() * (-1e3))
        qk += mask
        if True:
            # random_rotations = k.new_zeros(h_k, dim_k, self.config.num_clusters // 2).normal_(0, 1)
            random_rotations = torch.randn((h_k, dim_k, self.config.num_clusters // 2), device=self.config.device,
                                           dtype=k.dtype)
            rotated_vecs = torch.einsum('bhkd,hdc->bhkc', k, random_rotations)
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            one_hot = F.one_hot(buckets, num_classes=self.config.num_clusters).half()
            print(one_hot[0, 0].long().sum(0))
        if self.config.mode == 3:  # mode 3 = topk, mode 4 = full
            qk = torch.where(qk >= qk.topk(32, dim=-1)[0][..., -1:], qk,
                             float_half(self.config, qk.new_ones(1).float() * (-1e9)))
        pre_mask = torch.ones(t_q, t_q * 2, device=self.config.device).byte().triu_(t_q + 1)
        mask = float_half(self.config, pre_mask.float() * (-1e9))
        qk += mask
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhcqk,bhckd->bhcqd', sm_qk, v) if not self.config.multi_query \
            else torch.einsum('bhqk,bkd->bhqd', sm_qk, v.squeeze(1))
        o = o.view(b_q, h_q, t_q, dim_q)
        return o