# Incomplete implementation of extended MARGE architecture
# (gist by @AranKomat, August 13, 2020)
import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
#from torch_scatter import scatter
from apex.normalization import FusedLayerNorm as LayerNorm
from copy import copy
from time import time, sleep
from .utils import update_dict2obj

def initialize(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.zero_()

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class Normalization(nn.Module):
    def __init__(self, config):
        super(Normalization, self).__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        if config.powernorm:
            self.norm = PowerNorm()  # PowerNorm is assumed to be provided elsewhere
        else:
            self.norm = LayerNorm(self.hidden_dim)

    def forward(self, x, **kwargs):
        return self.norm(x)

class Softmax(nn.Module):
    def __init__(self, config):
        super(Softmax, self).__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        if config.adaptive:
            self.softmax = nn.AdaptiveLogSoftmaxWithLoss(self.hidden_dim, config.vocab_size, config.cutoffs)
        else:
            self.linear = nn.Linear(self.hidden_dim, config.vocab_size, bias=False)

    def forward(self, x, tgt=None, decoding=False, **kwargs):
        if self.config.adaptive:
            tgt = tgt.contiguous().view(-1)
            if decoding:  # TODO: incomplete here
                return self.softmax.log_prob(x)
            else:
                loss = [self.softmax(x.view(-1, self.hidden_dim), tgt)[1]]
                if self.config.loss_end is not None:
                    tmp = self.softmax(x.view(-1, self.hidden_dim), tgt)[0].view(-1, self.config.seqlen)
                    loss += [tmp[:, n - 1::n].mean().item() for n in self.config.loss_end]
                return loss
        else:
            x = self.linear(x)
            #x = F.log_softmax(x, -1)
            return x

class Decoder(nn.Module):
    def __init__(self, config):
        super(Decoder, self).__init__()
        self.config = config
        self.decoder = Transformer(config, config.decoder_config)
        self.softmax = Softmax(config)
        self.norm = Normalization(config)
        self.emb = nn.Embedding(config.vocab_size, config.hidden_dim)

    def forward(self, inp, tgt=None, **kwargs):
        x = self.emb(inp)
        x, kwargs = self.decoder(x, **kwargs)
        x = self.norm(x, **kwargs)
        return self.softmax(x, tgt=tgt, **kwargs), kwargs

class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
        self.config = config
        self.encoder = Transformer(config, config.encoder_config, enc=True)
        self.emb = nn.Embedding(config.vocab_size, config.hidden_dim)

    def forward(self, inp, **kwargs):
        x = self.emb(inp)
        _, kwargs = self.encoder(x, **kwargs)
        return kwargs
def compute_embedding(x): return F.normalize(x[:, 0], dim=-1)

class EncoderDecoder(nn.Module):
    def __init__(self, config):
        super(EncoderDecoder, self).__init__()
        self.config = config
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def init_cache(self, **kwargs):
        kwargs['enc_cache'] = [None for _ in range(self.config.encoder_config.depth)]
        kwargs['emb'] = None
        return kwargs

    # this method is only for getting the embedding for knn graph construction;
    # for training, you don't need to use this
    def embed(self, inp, **kwargs):
        x = self.encoder.emb(inp)
        for block in self.encoder.encoder.blocks[:self.config.emb_depth]:
            x, _ = block(x, **kwargs)
        return compute_embedding(x)

    # input must be shaped as [(b*s), window_size] for enc-dec and causal dec,
    # but [b, seqlen] for full and local (see the shape sketch below this class);
    # (s) should be reordered so as to preserve causality along s for the cross decoder
    # TODO: fix full attn for fair comparison
    def forward(self, inp, **kwargs):
        kwargs = self.init_cache(**kwargs)
        kwargs = self.encoder(inp, **kwargs)
        out, kwargs = self.decoder(inp, **kwargs)
        kwargs = self.init_cache(**kwargs)
        return out, kwargs
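
# Shape sketch for EncoderDecoder.forward (illustrative only, not from the original
# gist; the numbers below are hypothetical): with batch_size b = 2, segments per
# sample s = 4 and window_size = 128, the enc-dec / causal-decoder input `inp` would
# be a LongTensor of token ids shaped [(b * s), window_size] = [8, 128], whereas a
# full or local decoder would instead take [b, seqlen] = [2, 512] (presumably
# seqlen = s * window_size). kwargs is expected to carry 'sample_idx' and
# 'segment_idx' tensors so the cross-attention can mask out evidence coming from the
# same sample at non-causal segments.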

class Transformer(nn.Module):
    def __init__(self, config, module_config, enc=False):
        super(Transformer, self).__init__()
        self.config = config
        self.module_config = module_config
        self.enc = enc
        self.blocks = nn.ModuleList([Layer(config, layer_idx, module_config) for layer_idx in range(module_config.depth)])
        if config.abs:
            self.pe = PositionalEncoding(config.hidden_dim)

    def forward(self, x, **kwargs):
        if self.config.abs:
            x = self.pe(x)
        x = F.dropout(x, p=self.config.dropout_prob, training=self.training)
        for layer_idx, block in enumerate(self.blocks):
            x, kwargs = block(x, **kwargs)
            if self.enc:
                kwargs['enc_cache'][layer_idx] = x
                if layer_idx == self.config.emb_depth - 1:
                    kwargs['emb'] = compute_embedding(x)
        return x, kwargs

class Layer(nn.Module):
    def __init__(self, config, layer_idx, module_config):
        super(Layer, self).__init__()
        self.config = config
        self.blocks = nn.ModuleList([Block(config, layer_idx, module) for module in module_config.modules])

    def forward(self, x, **kwargs):
        for idx, block in enumerate(self.blocks):
            x, kwargs = block(x, **kwargs)
        return x, kwargs

class Block(nn.Module):
    def __init__(self, config, layer_idx, module):
        super(Block, self).__init__()
        self.config = config
        self.norm = Normalization(config)
        print(module)
        if module in ['c', 'e', 'd', 'l']:
            module_config = {'causal': True, 'cross': False, 'local': False, 'decoder_cross': False}
            if module == 'c':  # enc-dec attn of decoder
                module_config = update_dict2obj(module_config, causal=False, cross=True)
            elif module == 'e':  # encoder self-attn
                module_config = update_dict2obj(module_config, causal=False)
            elif module == 'd':  # decoder self-attn
                module_config = update_dict2obj(module_config, decoder_cross=config.decoder_cross)
            elif module == 'l':  # local decoder self-attn
                module_config = update_dict2obj(module_config, local=True)
            self.module = MultiHeadAttention(config, layer_idx, module_config)
        elif module == 'f':  # ffn
            self.module = FFN(config, layer_idx)
        elif module == 'm':  # moe
            self.module = MoE(config, layer_idx)
        else:
            raise NotImplementedError

    def forward(self, x, **kwargs):
        res = x
        x = self.norm(x)
        x, kwargs = self.module(x, **kwargs)
        x = F.dropout(x, p=self.config.dropout_prob, training=self.training)
        x.add_(res)
        return x, kwargs

class GLU(nn.Module):
    def __init__(self):
        super(GLU, self).__init__()

    def forward(self, x):
        x, v = x.chunk(2, dim=-1)
        return F.gelu(x) * v

class FFN(nn.Module):
    def __init__(self, config, layer_idx):
        super(FFN, self).__init__()
        self.config = config
        inner_dim = config.d_ff
        if config.activation == 'glu':
            inner_dim2 = 8 * int(inner_dim / 12)
            inner_dim1 = inner_dim2 * 2
            activation = GLU()
        else:  # relu
            activation = nn.ReLU()
            inner_dim1 = inner_dim2 = inner_dim
        self.ffn = nn.Sequential(nn.Linear(config.hidden_dim, inner_dim1), activation,
                                 nn.Dropout(p=self.config.dropout_prob),
                                 nn.Linear(inner_dim2, config.hidden_dim))

    def forward(self, x, **kwargs):
        return self.ffn(x), kwargs

class MultiHeadAttention(nn.Module):
    def __init__(self, config, layer_idx, module_config):
        super(MultiHeadAttention, self).__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.hidden_dim = config.hidden_dim
        self.q_dim = self.num_heads * config.head_dim
        qkv_output_dim = 3 * self.q_dim if not module_config.cross else self.q_dim
        self.linear_in = nn.Linear(self.hidden_dim, qkv_output_dim, bias=False)
        self.linear_out = nn.Linear(self.q_dim, config.hidden_dim, bias=False)
        if module_config.local:
            self.attention = LocalAttention(config, layer_idx, module_config)
        else:
            self.attention = Attention(config, layer_idx, module_config)
        self.layer_idx = layer_idx
        self.module_config = module_config

    def forward(self, inp, **kwargs):
        qkv = self.linear_in(inp)
        batch, length, _ = list(qkv.size())
        qkv = qkv.view(batch, length, -1, self.config.head_dim).transpose(1, 2)
        if not self.module_config.cross:
            q, k, v = torch.split(qkv, self.num_heads, dim=1)
        else:
            q = qkv
            k = v = kwargs['enc_cache'][self.layer_idx]
        out, _ = self.attention(q, k, v, **kwargs)
        out = self.linear_out(out.transpose(1, 2).contiguous().view(batch, length, self.q_dim))
        return out, kwargs

def shift(x, causal=True):
    # x = [*, t_q, t_k]
    t_q = x.size()[-2]
    zero_pad = x.new_zeros(*x.size()[:-1], x.size(-2))
    x = torch.cat([x, zero_pad], -1)
    l = x.size(-1)
    x = x.view(*x.size()[:-2], -1)
    zero_pad = x.new_zeros(*x.size()[:-1], -x.size(-1) % (l - 1))
    tmp = torch.cat([x, zero_pad], -1).view(*x.size()[:-1], -1, l - 1)[..., :t_q, :]
    if not causal:
        return tmp[..., t_q - 1: 2 * t_q - 1]
    else:
        return tmp[..., t_q - 1:]
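
# Rough sanity check for shift() (illustrative, not from the original gist): with a
# causal window of size 2, the raw scores against the relative embeddings R are laid
# out as [query, R_index]; shift() realigns each row so that R[-1] lands on the key
# at relative distance 0 and earlier R entries on larger distances, e.g.
#   x = torch.arange(4.).view(1, 1, 2, 2)   # [[0., 1.], [2., 3.]]
#   shift(x, causal=True)                   # [[1., 0.], [2., 3.]]
# The zero that appears at position [0, 1] sits on a non-causal key that the causal
# mask removes anyway.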

class Attention(nn.Module):
    def __init__(self, config, layer_idx, module_config):
        super(Attention, self).__init__()
        self.config = config
        std = math.sqrt(1 / config.hidden_dim)
        self.module_config = module_config
        self.cross = module_config.cross
        self.decoder_cross = module_config.decoder_cross
        self.window_size = config.window_size
        if not config.abs:
            R_len = self.window_size if module_config.causal else 2 * self.window_size - 1
            self.R = nn.Parameter(
                torch.zeros(R_len, config.num_heads, config.head_dim).normal_(0, std))
        if self.cross or self.decoder_cross:
            self.beta = nn.Parameter(torch.ones(1, 1, 1, config.num_heads))

    def forward(self, q, k, v, **kwargs):
        cross = self.cross or self.decoder_cross
        if cross:
            q, k, v, emb = map(lambda x: x.view(-1, self.config.subbatch_size, *x.size()[1:]), (q, k, v, kwargs['emb']))  # [b, s, h, l, d]
            v = v.transpose(1, 2).view(*v.size()[:2], -1, v.size(-1))
            emb_attn = torch.einsum('bsd,btd->bst', emb, emb).unsqueeze(-1) * self.beta
            if self.cross:
                same_samples = kwargs['sample_idx'].unsqueeze(-2) == kwargs['sample_idx'].unsqueeze(-1)
                noncausals = kwargs['segment_idx'].unsqueeze(-2) <= kwargs['segment_idx'].unsqueeze(-1)
                mask = (same_samples * noncausals)  # bst
                nan_mask = mask.all(dim=-1)  # bs
                emb_attn += (mask.float() * (-1e9)).type_as(emb_attn).unsqueeze(-1)
            qk = torch.einsum('bshqd,bthkd->bsthqk', q, k) + emb_attn.unsqueeze(-1).unsqueeze(-1)
        else:
            qk = torch.einsum('bhqd,bhkd->bhqk', q, k)
        # relative positional encoding
        if not self.config.abs:
            pos_enc = shift(torch.einsum('...hqd,khd->...hqk', q, self.R), causal=self.module_config.causal)  # bhqk/bshqk
            if cross:
                qk += pos_enc.unsqueeze(2)  # bsthqk
            else:
                qk += pos_enc
        if cross:
            qk = qk.permute(0, 3, 1, 4, 2, 5).reshape_as(v)
        # causal masking
        if self.module_config.causal:
            mask = torch.ones(qk.size(-2), qk.size(-1), device=self.config.device).byte().triu_(1)
            qk += (mask.float() * (-1e9)).type_as(qk)
        qk *= self.config.head_dim ** -0.5
        sm_qk = F.softmax(qk, dim=-1)
        # deals with the rows that are all masked
        if self.cross and nan_mask.any():
            nan_mask = nan_mask.unsqueeze(-1).expand(-1, -1, self.config.window_size).view(nan_mask.size(0), -1)
            sm_qk = torch.where(nan_mask.unsqueeze(1).unsqueeze(-1), torch.zeros_like(sm_qk), sm_qk)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhqk,bhkd->bhqd', sm_qk, v)
        if cross:
            o = o.contiguous().view(*o.size()[:2], self.config.batch_size, -1, o.size(-1)).transpose(1, 2)  # bshqd
            o = o.contiguous().view(-1, *o.size()[2:])
        return o, kwargs

class LocalAttention(nn.Module):
    def __init__(self, config, layer_idx, module_config):
        super(LocalAttention, self).__init__()
        self.config = config
        std = math.sqrt(1 / config.hidden_dim)
        self.window_size = config.window_size
        self.R = nn.Parameter(
            torch.zeros(self.window_size, config.num_heads, config.head_dim).normal_(0, std))

    def forward(self, q, k, v, **kwargs):
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        window_size = self.config.window_size
        # split the sequence into non-overlapping chunks of half a window
        q = q.view(b_q, h_q, -1, window_size // 2, dim_q)
        k = k.view(b_k, h_k, -1, window_size // 2, dim_k)
        v = v.view(b_k, h_k, -1, window_size // 2, dim_k)

        def f(x):
            # prepend each chunk with the previous chunk so every query sees a full window of keys
            x_extra = F.pad(x[:, :, :-1, ...], pad=(0, 0, 0, 0, 1, 0))
            return torch.cat([x_extra, x], dim=3)

        k = f(k)
        v = f(v)
        if not self.config.multi_query:
            k_part = torch.einsum('bhcqd,bhckd->bhcqk', q, k)
        else:
            k_part = torch.einsum('bhcqd,bckd->bhcqk', q, k.squeeze(1))
        tmp = torch.einsum('bhcqd,khd->bhcqk', q, self.R)
        wr_part = shift(tmp)
        qk = k_part + wr_part
        qk *= dim_q ** -0.5
        pre_mask = torch.ones(window_size // 2, window_size, device=self.config.device).byte().triu_(
            window_size // 2 + 1)
        mask = (pre_mask.float() * (-1e9)).type_as(qk)
        qk += mask
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        if not self.config.multi_query:
            o = torch.einsum('bhcqk,bhckd->bhcqd', sm_qk, v)
        else:
            o = torch.einsum('bhcqk,bckd->bhcqd', sm_qk, v.squeeze(1))
        o = o.contiguous().view(b_q, h_q, t_q, dim_q)
        return o, kwargs
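
# Minimal smoke test (not part of the original gist). It only exercises the
# self-contained local-attention path; the config/module_config values below are
# made-up placeholders built with types.SimpleNamespace, not anything the gist
# prescribes, and the full EncoderDecoder is left untested because the
# implementation above is still incomplete.
if __name__ == '__main__':
    from types import SimpleNamespace
    config = SimpleNamespace(hidden_dim=64, num_heads=4, head_dim=16,
                             window_size=8, dropout_prob=0.1,
                             multi_query=False, device='cpu')
    module_config = SimpleNamespace(causal=True, cross=False, local=True,
                                    decoder_cross=False)
    attn = MultiHeadAttention(config, 0, module_config)
    x = torch.randn(2, 32, config.hidden_dim)  # [batch, seqlen, hidden_dim]
    out, _ = attn(x)
    print(out.shape)  # expected: torch.Size([2, 32, 64])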