Created August 13, 2020 09:41
Incomplete implementation of extended MARGE architecture
import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
#from torch_scatter import scatter
from apex.normalization import FusedLayerNorm as LayerNorm
from copy import copy
from time import time, sleep
from .utils import update_dict2obj
def initialize(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.zero_()
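# Sinusoidal absolute positional encoding (Vaswani et al., 2017): even channels hold
# sin(pos / 10000^(2i/d)) and odd channels the corresponding cos; the table is
# precomputed once for max_len positions and simply added to the input embeddings.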
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
class Normalization(nn.Module):
    def __init__(self, config):
        super(Normalization, self).__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        if config.powernorm:
            # PowerNorm is assumed to be provided elsewhere; it is not defined in this file
            self.norm = PowerNorm()
        else:
            self.norm = LayerNorm(self.hidden_dim)

    def forward(self, x, **kwargs):
        return self.norm(x)
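# Output softmax: either a plain linear projection to the vocabulary, or
# nn.AdaptiveLogSoftmaxWithLoss, which clusters the vocabulary by frequency using
# config.cutoffs and returns the mean negative log-likelihood directly.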
class Softmax(nn.Module):
    def __init__(self, config):
        super(Softmax, self).__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        if config.adaptive:
            self.softmax = nn.AdaptiveLogSoftmaxWithLoss(self.hidden_dim, config.vocab_size, config.cutoffs)
        else:
            self.linear = nn.Linear(self.hidden_dim, config.vocab_size, bias=False)

    def forward(self, x, tgt=None, decoding=False, **kwargs):
        if self.config.adaptive:
            tgt = tgt.contiguous().view(-1)
            if decoding:  # TODO: incomplete here
                return self.softmax.log_prob(x)
            else:
                loss = [self.softmax(x.view(-1, self.hidden_dim), tgt)[1]]
                if self.config.loss_end is not None:
                    tmp = self.softmax(x.view(-1, self.hidden_dim), tgt)[0].view(-1, self.config.seqlen)
                    loss += [tmp[:, n - 1::n].mean().item() for n in self.config.loss_end]
                return loss
        else:
            x = self.linear(x)
            #x = F.log_softmax(x, -1)
            return x
class Decoder(nn.Module):
    def __init__(self, config):
        super(Decoder, self).__init__()
        self.config = config
        self.decoder = Transformer(config, config.decoder_config)
        self.softmax = Softmax(config)
        self.norm = Normalization(config)
        self.emb = nn.Embedding(config.vocab_size, config.hidden_dim)

    def forward(self, inp, tgt=None, **kwargs):
        x = self.emb(inp)
        x, kwargs = self.decoder(x, **kwargs)
        x = self.norm(x, **kwargs)
        return self.softmax(x, tgt=tgt, **kwargs), kwargs
class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
        self.config = config
        self.encoder = Transformer(config, config.encoder_config, enc=True)
        self.emb = nn.Embedding(config.vocab_size, config.hidden_dim)

    def forward(self, inp, **kwargs):
        x = self.emb(inp)
        _, kwargs = self.encoder(x, **kwargs)
        return kwargs
def compute_embedding(x): return F.normalize(x[:, 0], dim=-1)
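# The retrieval embedding is the L2-normalized representation of the first token,
# analogous to MARGE's document embedding used to score target/evidence relevance;
# downstream, pairwise dot products of these embeddings bias the cross-attention logits.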
class EncoderDecoder(nn.Module):
    def __init__(self, config):
        super(EncoderDecoder, self).__init__()
        self.config = config
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def init_cache(self, **kwargs):
        kwargs['enc_cache'] = [None for _ in range(self.config.encoder_config.depth)]
        kwargs['emb'] = None
        return kwargs

    # this method is only for getting the embedding for knn graph construction;
    # for training, you don't need to use this
    def embed(self, inp, **kwargs):
        x = self.encoder.emb(inp)
        for block in self.encoder.encoder.blocks[:self.config.emb_depth]:
            x, _ = block(x, **kwargs)
        return compute_embedding(x)

    # input must be shaped as [(b*s), window_size] for enc-dec and causal dec,
    # but [b, seqlen] for full and local;
    # (s) should be reordered so as to preserve causality along s for cross decoder
    # TODO: fix full attn for fair comparison
    def forward(self, inp, **kwargs):
        kwargs = self.init_cache(**kwargs)
        kwargs = self.encoder(inp, **kwargs)
        out, kwargs = self.decoder(inp, **kwargs)
        kwargs = self.init_cache(**kwargs)
        return out, kwargs
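# A Transformer stack of `depth` layers. When run as the encoder it caches every
# layer's output in kwargs['enc_cache'] (consumed by the decoder's cross-attention)
# and stores the retrieval embedding taken at layer `emb_depth`.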
class Transformer(nn.Module):
    def __init__(self, config, module_config, enc=False):
        super(Transformer, self).__init__()
        self.config = config
        self.module_config = module_config
        self.enc = enc
        self.blocks = nn.ModuleList([Layer(config, layer_idx, module_config) for layer_idx in range(module_config.depth)])
        if config.abs:
            self.pe = PositionalEncoding(config.hidden_dim)

    def forward(self, x, **kwargs):
        if self.config.abs:
            x = self.pe(x)
        x = F.dropout(x, p=self.config.dropout_prob, training=self.training)
        for layer_idx, block in enumerate(self.blocks):
            x, kwargs = block(x, **kwargs)
            if self.enc:
                kwargs['enc_cache'][layer_idx] = x
                if layer_idx == self.config.emb_depth - 1:
                    kwargs['emb'] = compute_embedding(x)
        return x, kwargs
class Layer(nn.Module):
    def __init__(self, config, layer_idx, module_config):
        super(Layer, self).__init__()
        self.config = config
        self.blocks = nn.ModuleList([Block(config, layer_idx, module) for module in module_config.modules])

    def forward(self, x, **kwargs):
        for block in self.blocks:
            x, kwargs = block(x, **kwargs)
        return x, kwargs
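# Pre-norm residual block: normalize, apply the module selected by its one-letter
# code, dropout, then add the residual. Module codes:
#   'c' enc-dec cross-attention, 'e' encoder self-attention, 'd' decoder self-attention,
#   'l' local decoder self-attention, 'f' feed-forward, 'm' mixture-of-experts.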
class Block(nn.Module):
    def __init__(self, config, layer_idx, module):
        super(Block, self).__init__()
        self.config = config
        self.norm = Normalization(config)
        if module in ['c', 'e', 'd', 'l']:
            module_config = {'causal': True, 'cross': False, 'local': False, 'decoder_cross': False}
            if module == 'c':  # enc-dec attn of decoder
                module_config = update_dict2obj(module_config, causal=False, cross=True)
            elif module == 'e':  # encoder self-attn
                module_config = update_dict2obj(module_config, causal=False)
            elif module == 'd':  # decoder self-attn
                module_config = update_dict2obj(module_config, decoder_cross=config.decoder_cross)
            elif module == 'l':  # local decoder self-attn
                module_config = update_dict2obj(module_config, local=True)
            self.module = MultiHeadAttention(config, layer_idx, module_config)
        elif module == 'f':  # ffn
            self.module = FFN(config, layer_idx)
        elif module == 'm':  # moe
            # MoE is assumed to be provided elsewhere; it is not defined in this file
            self.module = MoE(config, layer_idx)
        else:
            raise NotImplementedError

    def forward(self, x, **kwargs):
        res = x
        x = self.norm(x)
        x, kwargs = self.module(x, **kwargs)
        x = F.dropout(x, p=self.config.dropout_prob, training=self.training)
        x.add_(res)
        return x, kwargs
class GLU(nn.Module):
    def __init__(self):
        super(GLU, self).__init__()

    def forward(self, x):
        x, v = x.chunk(2, dim=-1)
        return F.gelu(x) * v
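# GEGLU-style gated activation: the projection is split in half, one half gated by
# GELU of the other. In the FFN below the inner width is rescaled (8/12 of d_ff for
# the value path, doubled for the gate) so the parameter count roughly matches a
# plain ReLU FFN of width d_ff.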
class FFN(nn.Module):
    def __init__(self, config, layer_idx):
        super(FFN, self).__init__()
        self.config = config
        inner_dim = config.d_ff
        if config.activation == 'glu':
            inner_dim2 = 8 * int(inner_dim / 12)
            inner_dim1 = inner_dim2 * 2
            activation = GLU()
        else:  # relu
            activation = nn.ReLU()
            inner_dim1 = inner_dim2 = inner_dim
        self.ffn = nn.Sequential(nn.Linear(config.hidden_dim, inner_dim1), activation,
                                 nn.Dropout(p=self.config.dropout_prob),
                                 nn.Linear(inner_dim2, config.hidden_dim))

    def forward(self, x, **kwargs):
        return self.ffn(x), kwargs
class MultiHeadAttention(nn.Module):
    def __init__(self, config, layer_idx, module_config):
        super(MultiHeadAttention, self).__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.hidden_dim = config.hidden_dim
        self.q_dim = self.num_heads * config.head_dim
        # cross-attention only projects queries; keys/values come from the encoder cache
        qkv_output_dim = 3 * self.q_dim if not module_config.cross else self.q_dim
        self.linear_in = nn.Linear(self.hidden_dim, qkv_output_dim, bias=False)
        self.linear_out = nn.Linear(self.q_dim, config.hidden_dim, bias=False)
        if module_config.local:
            self.attention = LocalAttention(config, layer_idx)
        else:
            self.attention = Attention(config, layer_idx, module_config)
        self.layer_idx = layer_idx
        self.module_config = module_config

    def forward(self, inp, **kwargs):
        qkv = self.linear_in(inp)
        batch, length, _ = list(qkv.size())
        qkv = qkv.view(batch, length, -1, self.config.head_dim).transpose(1, 2)
        if not self.module_config.cross:
            q, k, v = torch.split(qkv, self.num_heads, dim=1)
        else:
            q = qkv
            k = v = kwargs['enc_cache'][self.layer_idx]
        out, _ = self.attention(q, k, v, **kwargs)
        out = self.linear_out(out.transpose(1, 2).contiguous().view(batch, length, self.q_dim))
        return out, kwargs
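# Relative-position "shift" trick (as in Transformer-XL): the query/R product is
# computed for every relative offset at once, then padded and re-viewed so that each
# query row ends up aligned with its own window of relative offsets.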
def shift(x, causal=True):
    # x = [*, t_q, t_k]
    t_q = x.size()[-2]
    zero_pad = x.new_zeros(*x.size()[:-1], x.size(-2))
    x = torch.cat([x, zero_pad], -1)
    l = x.size(-1)
    x = x.view(*x.size()[:-2], -1)
    zero_pad = x.new_zeros(*x.size()[:-1], -x.size(-1) % (l - 1))
    tmp = torch.cat([x, zero_pad], -1).view(*x.size()[:-1], -1, l - 1)[..., :t_q]
    if not causal:
        return tmp[..., t_q - 1: 2 * t_q - 1]
    else:
        return tmp[..., t_q - 1:]
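# Full attention. For cross / decoder-cross modules, the logits additionally receive
# a per-head scaled (beta) bias equal to the dot product of the target and evidence
# retrieval embeddings, in the spirit of MARGE's relevance-weighted cross-attention;
# a sample/segment mask removes disallowed target-evidence pairs within the same sample.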
class Attention(nn.Module):
    def __init__(self, config, layer_idx, module_config):
        super(Attention, self).__init__()
        self.config = config
        std = math.sqrt(1 / config.hidden_dim)
        self.module_config = module_config
        self.cross = module_config.cross
        self.decoder_cross = module_config.decoder_cross
        self.window_size = config.window_size
        if not config.abs:
            R_len = self.window_size if module_config.causal else 2 * self.window_size - 1
            self.R = nn.Parameter(
                torch.zeros(R_len, config.num_heads, config.head_dim).normal_(0, std))
        if self.cross or self.decoder_cross:
            self.beta = nn.Parameter(torch.ones(1, 1, 1, config.num_heads))

    def forward(self, q, k, v, **kwargs):
        cross = self.cross or self.decoder_cross
        if cross:
            # split the leading (b*s) dim into batch and subbatch: [b, s, h, l, d]
            q, k, v, emb = map(lambda x: x.view(-1, self.config.subbatch_size, *x.size()[1:]), (q, k, v, kwargs['emb']))
            v = v.transpose(1, 2).view(*v.size()[:2], -1, v.size(-1))
            # bias the logits by the similarity of the retrieval embeddings, scaled per head
            emb_attn = torch.einsum('bsd,btd->bst', emb, emb).unsqueeze(-1) * self.beta
            if self.cross:
                same_samples = kwargs['sample_idx'].unsqueeze(-2) == kwargs['sample_idx'].unsqueeze(-1)
                noncausals = kwargs['segment_idx'].unsqueeze(-2) <= kwargs['segment_idx'].unsqueeze(-1)
                mask = (same_samples * noncausals)  # bst
                nan_mask = mask.all(dim=-1)  # bs
                emb_attn += (mask.float() * (-1e9)).type_as(emb_attn).unsqueeze(1)
            qk = torch.einsum('bshqd,bthkd->bsthqk', q, k) + emb_attn.unsqueeze(-1).unsqueeze(-1)
        else:
            qk = torch.einsum('bhqd,bhkd->bhqk', q, k)
        # relative positional encoding
        if not self.config.abs:
            pos_enc = shift(torch.einsum('...hqd,khd->...hqk', q, self.R), causal=self.module_config.causal)  # bhqk/bshqk
            if cross:
                qk += pos_enc.unsqueeze(2)  # bsthqk
            else:
                qk += pos_enc
        if cross:
            # flatten (s, q) into the query dim and (t, k) into the key dim: [b, h, s*q, t*k]
            qk = qk.permute(0, 3, 1, 4, 2, 5)
            qk = qk.reshape(qk.size(0), qk.size(1), qk.size(2) * qk.size(3), -1)
        # causal masking
        if self.module_config.causal:
            mask = torch.ones(qk.size(-2), qk.size(-1), device=self.config.device).byte().triu_(1)
            qk += (mask.float() * (-1e9)).type_as(qk)
        qk *= self.config.head_dim ** -0.5
        sm_qk = F.softmax(qk, dim=-1)
        # deals with the rows that are all masked
        if self.cross and nan_mask.any():
            nan_mask = nan_mask.unsqueeze(-1).expand(-1, -1, self.config.window_size).view(nan_mask.size(0), -1)
            sm_qk = torch.where(nan_mask.unsqueeze(1).unsqueeze(-1), torch.zeros_like(sm_qk), sm_qk)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        o = torch.einsum('bhqk, bhkd -> bhqd', sm_qk, v)
        if cross:
            # undo the subbatch flattening back to [(b*s), h, q, d]
            o = o.contiguous().view(*o.size()[:2], self.config.subbatch_size, -1, o.size(-1)).transpose(1, 2)  # bshqd
            o = o.contiguous().view(-1, *o.size()[2:])
        return o, kwargs
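# Blocked local attention: the sequence is cut into chunks of window_size // 2 and
# each chunk attends to itself plus the preceding chunk (f() prepends the previous
# chunk's keys/values), with causal masking inside the window and the same
# relative-position term as above. config.multi_query shares keys/values across heads.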
class LocalAttention(nn.Module):
    def __init__(self, config, layer_idx, module_config=None):
        super(LocalAttention, self).__init__()
        self.config = config
        std = math.sqrt(1 / config.hidden_dim)
        self.window_size = config.window_size
        self.R = nn.Parameter(
            torch.zeros(self.window_size, config.num_heads, config.head_dim).normal_(0, std))

    def forward(self, q, k, v, **kwargs):
        b_q, h_q, t_q, dim_q = list(q.size())
        b_k, h_k, t_k, dim_k = list(k.size())
        window_size = self.config.window_size
        # chunk the sequence into blocks of half a window
        q = q.view(b_q, h_q, -1, window_size // 2, dim_q)
        k = k.view(b_k, h_k, -1, window_size // 2, dim_k)
        v = v.view(b_k, h_k, -1, window_size // 2, dim_k)

        def f(x):
            # prepend each block with the previous block (zeros for the first block)
            x_extra = F.pad(x[:, :, :-1, ...], pad=(0, 0, 0, 0, 1, 0))
            return torch.cat([x_extra, x], dim=3)

        k = f(k)
        v = f(v)
        if not self.config.multi_query:
            k_part = torch.einsum('bhcqd,bhckd->bhcqk', q, k)
        else:
            k_part = torch.einsum('bhcqd,bckd->bhcqk', q, k.squeeze(1))
        tmp = torch.einsum('bhcqd,khd->bhcqk', q, self.R)
        wr_part = shift(tmp)
        qk = k_part + wr_part
        qk *= dim_q ** -0.5
        pre_mask = torch.ones(window_size // 2, window_size, device=self.config.device).byte().triu_(
            window_size // 2 + 1)
        mask = (pre_mask.float() * (-1e9)).type_as(qk)
        qk += mask
        sm_qk = F.softmax(qk, dim=-1)
        sm_qk = F.dropout(sm_qk, p=self.config.dropout_prob, training=self.training)
        if not self.config.multi_query:
            o = torch.einsum('bhcqk,bhckd->bhcqd', sm_qk, v)
        else:
            o = torch.einsum('bhcqk,bckd->bhcqd', sm_qk, v.squeeze(1))
        o = o.contiguous().view(b_q, h_q, t_q, dim_q)
        return o, kwargs
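# Config fields this file assumes (inferred from usage above; the config object passed
# in is expected to carry them): hidden_dim, head_dim, num_heads, d_ff, vocab_size,
# cutoffs, adaptive, powernorm, abs, activation, dropout_prob, window_size, seqlen,
# subbatch_size, batch_size, multi_query, decoder_cross, emb_depth, loss_end, device,
# plus encoder_config / decoder_config objects each carrying depth and modules.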