@AranKomat
Created August 13, 2020 09:38
import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
try:
    from torch_scatter import scatter
except ImportError:
    # In a notebook this was `!pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html`
    import subprocess, sys
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'torch-scatter==latest+cu101',
                           '-f', 'https://pytorch-geometric.com/whl/torch-1.4.0.html'])
    from torch_scatter import scatter
class Config(object):
    def __init__(self, dim, length, device):
        self.device = device
        self.hidden_dim = dim
        self.seqlen = length
        self.num_heads = 8
        self.head_dim = self.hidden_dim // self.num_heads
        self.global_heads = [4, 4, 4, 4, 4, 4]  # number of global (clustered) heads per layer
        # self.global_heads = [0, 0, 0, 4, 4, 4]
        self.num_updates_train = 1
        self.num_updates_test = 1  # the code needs a slight modification for this to be 0
        self.window_size = 64  # 256
        self.dropout_prob = 0
        self.global_window_size = self.window_size
        self.num_clusters = self.seqlen // self.global_window_size
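
# Illustrative note (not part of the original gist): for the smoke test at the bottom of this
# file (dim=512, length=4096), the derived Config attributes work out as below. A hedged
# sketch of the arithmetic only, assuming the defaults above are left unchanged.
def _config_example():
    cfg = Config(512, 4096, torch.device('cpu'))
    assert cfg.head_dim == 64       # 512 hidden dims / 8 heads
    assert cfg.num_clusters == 64   # 4096 tokens / 64 tokens per global window
    return cfg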
def shift_(x):
    # x = [*, t_q, t_k]
    zero_pad = torch.zeros(*x.size()[:-1], x.size(-2), device=x.device, dtype=x.dtype)
    x = torch.cat([x, zero_pad], -1)
    l = x.size(-1)
    x = x.view(*x.size()[:-2], -1)
    zero_pad = torch.zeros(*x.size()[:-1], -x.size(-1) % (l - 1), device=x.device, dtype=x.dtype)
    return torch.cat([x, zero_pad], -1).view(*x.size()[:-1], -1, l - 1)

def shift(x):
    t_q = x.size()[-2]
    return shift_(x)[..., :t_q, t_q - 1:]
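
# Hedged illustration (not from the original gist): shift() implements the usual "skewing"
# trick for relative-position scores. For an input of shape [*, t_q, t_k], row i of the
# output is row i of the input shifted left by (t_q - 1 - i) positions, zero-padded on the right.
def _shift_example():
    x = torch.arange(1., 7.).view(1, 2, 3)  # [[1, 2, 3], [4, 5, 6]] with a leading batch dim
    y = shift(x)
    # row 0 is shifted left by 1, row 1 is unshifted
    assert torch.equal(y, torch.tensor([[[2., 3., 0.], [4., 5., 6.]]]))
    return y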
class MultiHeadAttention(nn.Module):
    def __init__(self, config, idx=0):
        super(MultiHeadAttention, self).__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.num_heads_k = config.num_heads
        self.linear_qkv = nn.Linear(config.hidden_dim, config.hidden_dim * 3, bias=False)
        self.linear_out = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.num_global_heads = config.global_heads[idx]
        self.local_attention = LocalAttention(config, config.num_heads - self.num_global_heads, idx)
        self.idx = idx
        if self.num_global_heads > 0:
            self.global_attention = GlobalAttention(config, self.num_global_heads, idx)
        assert (config.hidden_dim % config.num_heads == 0)

    def forward(self, inp, cache=None, decoding=False, **kwargs):
        qkv = self.linear_qkv(inp)
        batch, length, _ = list(qkv.size())
        qkv = qkv.view(batch, length, self.num_heads, -1).transpose(1, 2)
        q, k, v = torch.split(qkv, self.config.head_dim, dim=-1)
        # the first num_global_heads heads use clustered global attention; the rest are local
        out, _ = self.local_attention(q[:, self.num_global_heads:], k[:, self.num_global_heads:],
                                      v[:, self.num_global_heads:], decoding=decoding, **kwargs)
        if self.num_global_heads > 0:
            out_, kwargs = self.global_attention(q[:, :self.num_global_heads], v[:, :self.num_global_heads], **kwargs)
            out = torch.cat([out_, out], 1)
        out = out.transpose(1, 2).contiguous()
        out = self.linear_out(out.view(batch, length, self.config.hidden_dim))
        return out, cache, kwargs
def batched_bincount(index, num_classes, dim=-1):
    shape = list(index.size())
    shape[dim] = num_classes
    out = index.new_zeros(shape).float()
    out.scatter_add_(dim, index, torch.ones_like(index).float())
    return out.type(index.dtype)

def similarity(x, means):
    return torch.einsum('bhld,hcd->bhlc', x, means)

def squared_distance(x, means):  # L2 dist ^ 2
    return 2 - 2 * similarity(x, means)

def dists_and_buckets(x, means):
    dists = similarity(x, means)
    dists, buckets = torch.max(dists, dim=-1)
    return dists, buckets
def compute_se(x, means):  # sum of squared errors of the current assignment
    dists = squared_distance(x, means)
    dists, buckets = torch.min(dists, dim=-1)
    # accumulate in float32, then cast back to the input dtype
    return dists.float().sum(-1).type(dists.dtype)
def buckets_to_means(x, buckets, num_clusters):
    b, h, l, d = list(x.size())
    means = buckets.new_zeros(b, h, num_clusters, d).float()  # [b, h, c, d]
    means.scatter_add_(-2, buckets.unsqueeze(-1).expand(-1, -1, -1, d), x.float())
    return F.normalize(means.sum(0, keepdim=True).type(x.dtype), dim=-1)

def kmeans_iter(x, means, num_clusters, compute_se_):  # a single hard k-means iteration
    dists, buckets = dists_and_buckets(x, means)
    bins = batched_bincount(buckets, num_clusters).sum(0, keepdim=True)
    zero_mask = bins.long() == 0
    means_ = buckets_to_means(x, buckets, num_clusters)
    # keep the old centroid wherever a cluster received no points
    means = torch.where(zero_mask.unsqueeze(-1), means, means_)
    means = means.squeeze(0)
    se = compute_se(x, means) if compute_se_ else None
    return means, se

def kmeans(config, x, means, num_clusters, training=True, reset=False, compute_se=False):
    max_iters = config.num_updates_train if training else config.num_updates_test
    if reset:
        # re-initialize the centroids from randomly chosen points
        max_iters = 10
        means = x.transpose(0, 1).contiguous().view(x.size(1), -1, x.size(-1))
        means = means[:, torch.randperm(means.size(1), device=x.device)[:num_clusters]]
    for idx in range(max_iters):
        means, se = kmeans_iter(x, means, num_clusters, compute_se)
    dists = similarity(x, means)
    return means, dists, se
# pick the window_size points closest to each centroid and sort their indices in ascending order for causal attention
def distribution(config, dists, means):
    indices = dists.topk(k=config.window_size, dim=-2)[1].sort(dim=-2)[0].transpose(-2, -1)  # [b, h, c, window_size]
    return indices, means
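
# Hedged usage sketch (not from the original gist): the helpers above cluster the
# L2-normalized keys of each head with hard k-means and then, per centroid, pick the
# window_size closest positions, sorted so causal masking inside a cluster stays valid.
def _kmeans_example():
    cfg = Config(512, 1024, torch.device('cpu'))  # 1024 // 64 = 16 clusters
    x = F.normalize(torch.randn(2, cfg.num_heads, cfg.seqlen, cfg.head_dim), dim=-1)
    init_means = torch.zeros(cfg.num_heads, cfg.num_clusters, cfg.head_dim)
    means, dists, _ = kmeans(cfg, x, init_means, cfg.num_clusters, training=True, reset=True)
    indices, means = distribution(cfg, dists, means)
    assert indices.shape == (2, cfg.num_heads, cfg.num_clusters, cfg.window_size)
    # token indices are sorted in ascending order within each cluster
    assert torch.all(indices[..., 1:] >= indices[..., :-1])
    return indices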
class GlobalAttention(nn.Module):
    def __init__(self, config, num_heads, idx=0):
        super(GlobalAttention, self).__init__()
        self.config = config
        self.num_heads = num_heads
        self.window_size = self.config.global_window_size
        self.num_clusters = config.num_clusters
        std = 1 / np.sqrt(config.hidden_dim)
        self.R = nn.Parameter(
            torch.zeros(self.window_size, num_heads, config.head_dim,
                        device=self.config.device).normal_(0, std))
        self.register_buffer('means', torch.zeros(num_heads, config.num_clusters, config.head_dim))
        self.register_buffer('init', torch.ones(1))  # 1 if the model was just initialized
    def forward(self, qk, v, decoding=False, **kwargs):
        # the centroid update is not differentiated through; only the attention itself gets gradients
        with torch.no_grad():
            k = F.normalize(qk, dim=-1)
            if self.init == 1:
                reset = True
                self.init.copy_(torch.zeros_like(self.init))
            else:
                reset = False
            means, dists, se = kmeans(self.config, k, self.means, self.num_clusters, training=self.training, reset=reset)
            indices, means = distribution(self.config, dists, means)
            self.means.copy_(means)
            # TODO: do something here to make more k_i's to be included ... e.g. increase num_clusters for the missing ones?
            indices = indices.contiguous().view(*indices.size()[:2], -1)  # [b, h, l]
        return self.attention(qk, v, indices, **kwargs)
    def attention(self, qk, v, indices, **kwargs):
        num_clusters = qk.size(-2) // self.window_size
        batch_size, h_q, h_v, dim = qk.size(0), qk.size(1), v.size(1), qk.size(-1)

        def batched_index_select(values, indices):
            last_dim = values.shape[-1]
            return values.gather(2, indices[:, :, :, None].expand(-1, -1, -1, last_dim))

        # gather the tokens of each cluster and attend within it
        qk = batched_index_select(qk, indices)
        v = batched_index_select(v, indices)
        qk = torch.reshape(qk,
                           (batch_size, h_q, num_clusters, self.window_size, qk.shape[-1]))  # [b, h, nc, cs, d]
        v = torch.reshape(v, (batch_size, h_v, num_clusters, self.window_size, v.shape[-1]))
        q = qk
        k = F.normalize(qk, dim=-1)
        # Relative attention
        k_part = torch.einsum('bhnid,bhnjd->bhnij', q, k)
        tmp = torch.einsum('bhnid,jhd->bhnij', q, self.R)
        wr_part = shift(tmp)
        dots = k_part + wr_part
        # Causal mask within the cluster; mask out attention to self except when no other targets are available.
        pre_mask = torch.ones(self.window_size, self.window_size, device=self.config.device).byte().triu_(1)
        dots += (pre_mask.float() * (-1e9)).type(dots.dtype)
        self_mask = torch.eye(self.window_size, device=dots.device)
        dots += (self_mask.float() * (-1e3)).type(dots.dtype)
        dots = F.softmax(dots, -1)  # [b, h, c, q, k]
        dots = F.dropout(dots, p=self.config.dropout_prob, training=self.training)
        bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v)
        so = torch.reshape(bo, (batch_size, h_q, -1, bo.shape[-1]))
        # scatter the per-cluster outputs back to their original sequence positions, averaging duplicates
        out = scatter(so.float(), indices.unsqueeze(-1).expand_as(so), -2,
                      dim_size=self.config.seqlen, reduce='mean').type(so.dtype)
        return out, kwargs
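
# Hedged usage sketch (not from the original gist): GlobalAttention can be exercised on its
# own with per-head tensors of shape [batch, heads, seqlen, head_dim]. On the first call the
# `init` buffer triggers a k-means reset; the output is scattered back to the original
# sequence positions, so it keeps the input's shape.
def _global_attention_example():
    cfg = Config(512, 1024, torch.device('cpu'))
    attn = GlobalAttention(cfg, num_heads=4)
    qk = torch.randn(2, 4, cfg.seqlen, cfg.head_dim)
    v = torch.randn(2, 4, cfg.seqlen, cfg.head_dim)
    out, _ = attn(qk, v)
    assert out.shape == (2, 4, cfg.seqlen, cfg.head_dim)
    return out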
class LocalAttention(nn.Module):
    def __init__(self, config, num_heads, idx=0):
        super(LocalAttention, self).__init__()
        self.config = config
        std = math.sqrt(1 / config.hidden_dim)
        self.R = nn.Parameter(
            torch.zeros(config.window_size, num_heads, config.head_dim,
                        device=self.config.device).normal_(0, std))
        self.idx = idx
    def forward(self, q, k, v, decoding=False, **kwargs):
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        window_size = self.config.window_size
        # split the sequence into blocks of half a window
        q = q.view(b_q, h_q, -1, window_size // 2, dim_q)
        k = k.view(b_k, h_k, -1, window_size // 2, dim_k)
        v = v.view(b_k, h_k, -1, window_size // 2, dim_k)

        def f(x):
            # prepend the previous block, so each query sees its own block and the one before it
            x_extra = F.pad(x[:, :, :-1, ...], pad=(0, 0, 0, 0, 1, 0))
            return torch.cat([x_extra, x], dim=3)

        k = f(k)
        v = f(v)
        k_part = torch.einsum('bhcqd,bhckd->bhcqk', q, k)
        tmp = torch.einsum('bhcqd,khd->bhcqk', q, self.R)
        wr_part = shift(tmp)
        qk = k_part + wr_part
        qk *= dim_q ** -0.5
        pre_mask = torch.ones(window_size // 2, window_size, device=self.config.device).byte().triu_(
            window_size // 2 + 1)
        mask = (pre_mask.float() * (-1e9)).type(qk.dtype)
        qk += mask
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhcqk,bhckd->bhcqd', sm_qk, v)
        o = o.view(b_q, h_q, t_q, dim_q)
        return o, kwargs
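
# Hedged usage sketch (not from the original gist): LocalAttention expects separate q, k, v
# tensors of shape [batch, heads, seqlen, head_dim], with seqlen a multiple of
# window_size // 2; each half-window block attends causally to itself and the previous block.
def _local_attention_example():
    cfg = Config(512, 1024, torch.device('cpu'))
    attn = LocalAttention(cfg, num_heads=4)
    q = torch.randn(2, 4, cfg.seqlen, cfg.head_dim)
    k = torch.randn(2, 4, cfg.seqlen, cfg.head_dim)
    v = torch.randn(2, 4, cfg.seqlen, cfg.head_dim)
    out, _ = attn(q, k, v)
    assert out.shape == (2, 4, cfg.seqlen, cfg.head_dim)
    return out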
dev = torch.device('cuda:0')
print(torch.cuda.get_device_name(0))
batch_size, length, dim = 1, 4096, 512
input = torch.randn(batch_size, length, dim, dtype=torch.float32, device=dev)
config = Config(dim, length, dev)
attention_layer = MultiHeadAttention(config, idx=0).to(dev)
attention_layer(input)
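
# Hedged sanity check (an assumption, not part of the original gist): the layer returns a
# tuple (output, cache, kwargs); the output keeps the input shape, and gradients flow
# through both the local heads and the clustered global heads.
out, _, _ = attention_layer(input)
assert out.shape == (batch_size, length, dim)
out.mean().backward()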