import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function

try:
    from torch_scatter import scatter
except ImportError:
    # In a Colab/Jupyter notebook, install with:
    # !pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
    raise
class Config(object):
    def __init__(self, dim, length, device):
        self.device = device
        self.hidden_dim = dim
        self.seqlen = length
        self.num_heads = 8
        self.head_dim = self.hidden_dim // self.num_heads
        # number of global (clustered) attention heads per layer; the remaining heads are local
        self.global_heads = [4, 4, 4, 4, 4, 4]
        # self.global_heads = [0, 0, 0, 4, 4, 4]
        self.num_updates_train = 1
        self.num_updates_test = 1  # the code should be slightly modified for this to be 0
        self.window_size = 64  # 256
        self.dropout_prob = 0
        self.global_window_size = self.window_size
        self.num_clusters = self.seqlen // self.global_window_size
def shift_(x):
    # x = [*, t_q, t_k]
    # standard relative-attention shift trick: pad each row with t_q zeros, flatten,
    # re-pad and fold so the relative-position logits line up per query
    zero_pad = torch.zeros(*x.size()[:-1], x.size(-2), device=x.device, dtype=x.dtype)
    x = torch.cat([x, zero_pad], -1)
    l = x.size(-1)
    x = x.view(*x.size()[:-2], -1)
    zero_pad = torch.zeros(*x.size()[:-1], -x.size(-1) % (l - 1), device=x.device, dtype=x.dtype)
    return torch.cat([x, zero_pad], -1).view(*x.size()[:-1], -1, l - 1)


def shift(x):
    # realigned relative-position logits; output keeps the input's [*, t_q, t_k] shape
    t_q = x.size()[-2]
    return shift_(x)[..., :t_q, t_q - 1:]
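

# Minimal sanity-check sketch (not part of the original gist): shift() should keep
# the [*, t_q, t_k] shape, e.g. t_q = window_size // 2 and t_k = window_size as used
# by the local attention below. The `_shift_demo` name is illustrative only.
_shift_demo = torch.randn(2, 4, 8)           # [*, t_q=4, t_k=8]
assert shift(_shift_demo).shape == (2, 4, 8)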
class MultiHeadAttention(nn.Module):
    def __init__(self, config, idx=0):
        super(MultiHeadAttention, self).__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.num_heads_k = config.num_heads
        self.linear_qkv = nn.Linear(config.hidden_dim, config.hidden_dim * 3, bias=False)
        self.linear_out = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.num_global_heads = config.global_heads[idx]
        self.local_attention = LocalAttention(config, config.num_heads - self.num_global_heads, idx)
        self.idx = idx
        if self.num_global_heads > 0:
            self.global_attention = GlobalAttention(config, self.num_global_heads, idx)
        assert (config.hidden_dim % config.num_heads == 0)

    def forward(self, inp, cache=None, decoding=False, **kwargs):
        qkv = self.linear_qkv(inp)
        batch, length, _ = list(qkv.size())
        qkv = qkv.view(batch, length, self.num_heads, -1).transpose(1, 2)
        q, k, v = torch.split(qkv, self.config.head_dim, dim=-1)
        # the first num_global_heads heads attend globally via k-means routing,
        # the remaining heads attend within a sliding local window
        out, _ = self.local_attention(q[:, self.num_global_heads:], k[:, self.num_global_heads:],
                                      v[:, self.num_global_heads:], decoding=decoding, **kwargs)
        if self.num_global_heads > 0:
            out_, kwargs = self.global_attention(q[:, :self.num_global_heads], v[:, :self.num_global_heads], **kwargs)
            out = torch.cat([out_, out], 1)
        out = out.transpose(1, 2).contiguous()
        out = self.linear_out(out.view(batch, length, self.config.hidden_dim))
        return out, cache, kwargs
def batched_bincount(index, num_classes, dim=-1):
    # per-(batch, head) histogram of cluster assignments
    shape = list(index.size())
    shape[dim] = num_classes
    out = index.new_zeros(shape).float()
    out.scatter_add_(dim, index, torch.ones_like(index).float())
    return out.type(index.dtype)


def similarity(x, means):
    return torch.einsum('bhld,hcd->bhlc', x, means)


def squared_distance(x, means):  # L2 dist ^ 2 (x and means are unit-normalized)
    return 2 - 2 * similarity(x, means)


def dists_and_buckets(x, means):
    dists = similarity(x, means)
    dists, buckets = torch.max(dists, dim=-1)
    return dists, buckets


def compute_se(x, means):
    # sum of squared errors of each point to its nearest centroid
    dists = squared_distance(x, means)
    dists, buckets = torch.min(dists, dim=-1)
    dists_dtype = dists.dtype
    dists_float = dists.float()
    return dists_float.sum(-1).type(dists_dtype)


def buckets_to_means(x, buckets, num_clusters):
    b, h, l, d = list(x.size())
    means = buckets.new_zeros(b, h, num_clusters, d).float()  # [b, h, c, d]
    means.scatter_add_(-2, buckets.unsqueeze(-1).expand(-1, -1, -1, d), x.float())
    return F.normalize(means.sum(0, keepdim=True).type(x.dtype), dim=-1)


def kmeans_iter(x, means, num_clusters, compute_se_):  # hard k-means, single iteration
    dists, buckets = dists_and_buckets(x, means)
    bins = batched_bincount(buckets, num_clusters).sum(0, keepdim=True)
    zero_mask = bins.long() == 0
    means_ = buckets_to_means(x, buckets, num_clusters)
    # keep the old centroid wherever a cluster received no points
    means = torch.where(zero_mask.unsqueeze(-1), means, means_)
    means = means.squeeze(0)
    se = compute_se(x, means) if compute_se_ else None
    return means, se


def kmeans(config, x, means, num_clusters, training=True, reset=False, compute_se=False):
    max_iters = config.num_updates_train if training else config.num_updates_test
    if reset:
        # (re)initialize centroids from randomly chosen points of x and run more iterations
        max_iters = 10
        means = x.transpose(0, 1).contiguous().view(x.size(1), -1, x.size(-1))
        means = means[:, torch.randperm(means.size(1), device=x.device)[:num_clusters]]
    for idx in range(max_iters):
        means, se = kmeans_iter(x, means, num_clusters, compute_se)
    dists = similarity(x, means)
    return means, dists, se


# pick the k points closest to each centroid and sort their indices in ascending order for causal attention
def distribution(config, dists, means):
    indices = dists.topk(k=config.window_size, dim=-2)[1].sort(dim=-2)[0].transpose(-2, -1)  # [b, h, c, window_size]
    return indices, means
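

# Hedged shape-check sketch (not part of the original gist): kmeans() followed by
# distribution(). With length 256 and window_size 64 there are 4 clusters, so each
# cluster gathers window_size sorted token indices per head. Names prefixed with
# `_km_` are illustrative only.
_km_config = Config(64, 256, torch.device('cpu'))
_km_x = F.normalize(torch.randn(2, _km_config.num_heads, _km_config.seqlen,
                                _km_config.head_dim), dim=-1)
_km_means, _km_dists, _ = kmeans(_km_config, _km_x, None,
                                 _km_config.num_clusters, reset=True)
_km_idx, _ = distribution(_km_config, _km_dists, _km_means)
assert _km_means.shape == (_km_config.num_heads, _km_config.num_clusters, _km_config.head_dim)
assert _km_idx.shape == (2, _km_config.num_heads, _km_config.num_clusters, _km_config.window_size)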
class GlobalAttention(nn.Module):
    def __init__(self, config, num_heads, idx=0):
        super(GlobalAttention, self).__init__()
        self.config = config
        self.num_heads = num_heads
        self.window_size = self.config.global_window_size
        self.num_clusters = config.num_clusters
        std = 1 / np.sqrt(config.hidden_dim)
        # relative position embeddings within a cluster window
        self.R = nn.Parameter(
            torch.zeros(self.window_size, num_heads, config.head_dim,
                        device=self.config.device).normal_(0, std))
        # running centroids shared across batches
        self.register_buffer('means', torch.zeros(num_heads, config.num_clusters, config.head_dim))
        self.register_buffer('init', torch.ones(1))  # 1 if the model was just initialized

    def forward(self, qk, v, decoding=False, **kwargs):
        # cluster assignments are computed without gradients; only the attention itself is differentiable
        with torch.no_grad():
            k = F.normalize(qk, dim=-1)
            if self.init == 1:
                reset = True
                self.init.copy_(torch.zeros_like(self.init))
            else:
                reset = False
            means, dists, se = kmeans(self.config, k, self.means, self.num_clusters,
                                      training=self.training, reset=reset)
            indices, means = distribution(self.config, dists, means)
            self.means.copy_(means)
            # TODO: do something here to make more k_i's to be included ... e.g. increase num_clusters for the missing ones?
            indices = indices.contiguous().view(*indices.size()[:2], -1)  # [b, h, l]
        return self.attention(qk, v, indices, **kwargs)

    def attention(self, qk, v, indices, **kwargs):
        num_clusters = qk.size(-2) // self.window_size
        batch_size, h_q, h_v, dim = qk.size(0), qk.size(1), v.size(1), qk.size(-1)

        def batched_index_select(values, indices):
            last_dim = values.shape[-1]
            return values.gather(2, indices[:, :, :, None].expand(-1, -1, -1, last_dim))

        # gather the tokens routed to each cluster and fold them into windows
        qk = batched_index_select(qk, indices)
        v = batched_index_select(v, indices)
        qk = torch.reshape(qk,
                           (batch_size, h_q, num_clusters, self.window_size, qk.shape[-1]))  # [b, h, nc, cs, d]
        v = torch.reshape(v, (batch_size, h_v, num_clusters, self.window_size, v.shape[-1]))
        q = qk
        k = F.normalize(qk, dim=-1)
        # Relative attention
        k_part = torch.einsum('bhnid,bhnjd->bhnij', q, k)
        tmp = torch.einsum('bhnid,jhd->bhnij', q, self.R)
        wr_part = shift(tmp)
        dots = k_part + wr_part
        # Causal mask within each window; mask out attention to self except when no other targets are available.
        pre_mask = torch.ones(self.window_size, self.window_size, device=self.config.device).byte().triu_(1)
        dots += (pre_mask.float() * (-1e9)).type(dots.dtype)
        self_mask = torch.eye(self.window_size, device=dots.device)
        dots += (self_mask.float() * (-1e3)).type(dots.dtype)
        dots = F.softmax(dots, -1)  # [b, h, c, q, k]
        dots = F.dropout(dots, p=self.config.dropout_prob, training=self.training)
        bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v)
        so = torch.reshape(bo, (batch_size, h_q, -1, bo.shape[-1]))
        # scatter window outputs back to their original sequence positions (mean over duplicates)
        out = scatter(so.float(), indices.unsqueeze(-1).expand_as(so), -2,
                      dim_size=self.config.seqlen, reduce='mean').type(so.dtype)
        return out, kwargs
class LocalAttention(nn.Module):
    def __init__(self, config, num_heads, idx=0):
        super(LocalAttention, self).__init__()
        self.config = config
        std = math.sqrt(1 / config.hidden_dim)
        # relative position embeddings over the local window
        self.R = nn.Parameter(
            torch.zeros(config.window_size, num_heads, config.head_dim,
                        device=self.config.device).normal_(0, std))
        self.idx = idx

    def forward(self, q, k, v, decoding=False, **kwargs):
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        window_size = self.config.window_size
        # split the sequence into blocks of window_size // 2 positions
        q = q.view(b_q, h_q, -1, window_size // 2, dim_q)
        k = k.view(b_k, h_k, -1, window_size // 2, dim_k)
        v = v.view(b_k, h_k, -1, window_size // 2, dim_k)

        def f(x):
            # prepend the previous block so each query sees a full window of keys
            x_extra = F.pad(x[:, :, :-1, ...], pad=(0, 0, 0, 0, 1, 0))
            return torch.cat([x_extra, x], dim=3)

        k = f(k)
        v = f(v)
        k_part = torch.einsum('bhcqd,bhckd->bhcqk', q, k)
        tmp = torch.einsum('bhcqd,khd->bhcqk', q, self.R)
        wr_part = shift(tmp)
        qk = k_part + wr_part
        qk *= dim_q ** -0.5
        # causal mask over the concatenated [previous block, current block] keys
        pre_mask = torch.ones(window_size // 2, window_size, device=self.config.device).byte().triu_(
            window_size // 2 + 1)
        mask = (pre_mask.float() * (-1e9)).type(qk.dtype)
        qk += mask
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhcqk,bhckd->bhcqd', sm_qk, v)
        o = o.view(b_q, h_q, t_q, dim_q)
        return o, kwargs
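

# Hedged shape-check sketch (not part of the original gist): LocalAttention on CPU.
# The output keeps the [batch, heads, length, head_dim] shape of the queries.
# Names prefixed with `_loc_` are illustrative only.
_loc_config = Config(64, 256, torch.device('cpu'))
_loc_attn = LocalAttention(_loc_config, num_heads=4)
_loc_q = torch.randn(2, 4, _loc_config.seqlen, _loc_config.head_dim)
_loc_o, _ = _loc_attn(_loc_q, _loc_q, _loc_q)
assert _loc_o.shape == (2, 4, _loc_config.seqlen, _loc_config.head_dim)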
dev = torch.device('cuda:0')
print(torch.cuda.get_device_name(0))
batch_size, length, dim = 1, 4096, 512
input = torch.randn(batch_size, length, dim, dtype=torch.float32, device=dev)
config = Config(dim, length, dev)
attention_layer = MultiHeadAttention(config, idx=0).to(dev)
attention_layer(input)
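
# Hedged usage sketch (not part of the original gist): the layer returns
# (output, cache, kwargs); the output keeps the input's [batch, length, dim] shape.
out, cache, extra = attention_layer(input)
print(out.shape)  # expected: torch.Size([1, 4096, 512])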