# Gist by @guillefix, created May 17, 2021 01:48
import sys
import os
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.abspath(os.path.join(THIS_DIR, os.pardir))
sys.path.append(ROOT_DIR)
sys.path.append(THIS_DIR)
import torch
import torch.nn as nn
import torch.nn.functional as F
import uuid
import math  # required for math.log in the sinusoidal positional encoding below
import numpy as np
from functools import partial
from torch.nn import TransformerEncoder, TransformerEncoderLayer, Transformer
#from models.x_transformers import ContinuousTransformerWrapper, Decoder, Encoder, AutoregressiveWrapper
from x_transformers import ContinuousTransformerWrapper, Decoder, Encoder, AutoregressiveWrapper
def pick_and_pop(keys, d):
    # pop the given keys out of dict d and return them as a new dict
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))
def groupby_prefix_and_trim(prefix, d):
    # split d into (entries whose keys start with prefix, the rest), stripping the prefix from the first group
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
def group_dict_by_key(cond, d):
    # partition d into two dicts: keys satisfying cond, and keys that do not
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)
def string_begins_with(prefix, s):
    return s.startswith(prefix)
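# Illustrative sketch (not part of the original gist; the dict contents are made up):
# how the helpers above split a flat kwargs dict by prefix, as used by EncDecXTransformer below.
def _example_groupby_prefix_and_trim():
    kwargs = {'enc_dim': 512, 'enc_depth': 6, 'dec_dim': 256, 'tie_token_emb': False}
    enc_kwargs, rest = groupby_prefix_and_trim('enc_', kwargs)
    # enc_kwargs == {'dim': 512, 'depth': 6}
    # rest == {'dec_dim': 256, 'tie_token_emb': False}
    return enc_kwargs, rest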
class PositionalEncoding(nn.Module):  # learned positional encoding via an embedding table
def __init__(self, d_model, dropout=0.1, max_len=5000, device=None):
super(PositionalEncoding, self).__init__()
self.device = device
self.dropout = nn.Dropout(p=dropout)
self.lpe = nn.Embedding(max_len+1, d_model)
# self.weight = None
self.indices = torch.arange(max_len).unsqueeze(1) + 1
if device is not None:
self.indices = self.indices.to(self.device)
def init_weights(self):
initrange = 0.1
self.lpe.weight.data.uniform_(-initrange, initrange)
    def forward(self, x, indices=None):
        # np.save(str(uuid.uuid4()) + ".npy", self.lpe.weight.data.cpu().numpy())  # debug dump of the embedding weights
        if indices is None:
            indices = self.indices[:x.size(0), :]
        x = x + self.lpe(indices)
        return self.dropout(x)
class LearnedPositionalEncoding(nn.Module):  # note: despite the name, this is the fixed sinusoidal encoding, not a learned one
    def __init__(self, d_model, dropout=0.1, max_len=5000, device=None):
        super(LearnedPositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term1 = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
div_term2 = torch.exp(torch.arange(0, (d_model//2)*2, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term1)
pe[:, 1::2] = torch.cos(position * div_term2)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
# print(x.shape)
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
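# Illustrative sketch (not part of the original gist; shapes assumed from the forward
# methods above): both encodings take input of shape (seq_len, batch, d_model) and add
# a positional term before applying dropout.
def _example_positional_encoding():
    pe = PositionalEncoding(d_model=16, dropout=0.0, max_len=64)
    x = torch.randn(10, 2, 16)  # (seq_len, batch, d_model)
    return pe(x)  # same shape as x, with positional embeddings added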
class BasicTransformerModel(nn.Module):
    def __init__(self, dout, dinp, nhead, dhid, nlayers, dropout=0.5, device=None, use_pos_emb=False, input_length=0, use_x_transformers=False, opt=None):
super(BasicTransformerModel, self).__init__()
self.device = device
self.model_type = 'Transformer'
self.use_x_transformers = use_x_transformers
if not use_x_transformers:
self.encoder1 = nn.Linear(dinp, dhid)
#self.pos_encoder = PositionalEncoding(dhid, dropout, device=self.device)
encoder_layers = TransformerEncoderLayer(dhid, nhead, dhid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
# self.encoder = nn.Embedding(ntoken, dinp)
self.dinp = dinp
self.dhid = dhid
self.decoder = nn.Linear(dhid, dout)
self.use_pos_emb = use_pos_emb
if use_pos_emb:
assert input_length > 0
self.pos_emb = nn.Parameter((torch.zeros(input_length, input_length)))
# self.pos_emb = nn.Parameter((torch.eye(input_length, input_length)))
# self.pos_emb = nn.Parameter((torch.randn(input_length, input_length))/np.sqrt(dinp))
self.init_weights()
#self.pos_encoder.init_weights()
else:
self.model = ContinuousTransformerWrapper(
dim_in = dinp,
dim_out = dout,
max_seq_len = 1024,
use_pos_emb = use_pos_emb,
attn_layers = Encoder(
dim = dhid,
depth = nlayers,
heads = nhead,
rotary_pos_emb = opt.use_rotary_pos_emb,
#rel_pos_bias = True
)
)
    def generate_square_subsequent_mask(self, sz, prefix_length=1):
        # standard causal mask, except that the first `prefix_length` positions are visible to every timestep
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask[:, :prefix_length] = 1
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
def init_weights(self):
initrange = 0.1
# self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
def forward(self, src, src_mask=None):
if not self.use_x_transformers:
# import pdb;pdb.set_trace()
src = self.encoder1(src)
#src *= math.sqrt(self.dhid)
#src = self.pos_encoder(src)
#src /= math.sqrt(self.dhid)
# print(src)
# print(torch.mm(src[:,0,:],src[:,0,:].T))
if self.use_pos_emb:
#print(self.pos_emb)
#src_mask += self.pos_emb
if src_mask is not None:
output = self.transformer_encoder(src, src_mask + self.pos_emb)
else:
output = self.transformer_encoder(src, self.pos_emb)
#output = self.transformer_encoder(src, self.pos_emb)
else:
if src_mask is not None:
output = self.transformer_encoder(src, src_mask)
else:
output = self.transformer_encoder(src)
#output = self.transformer_encoder(src)
output = self.decoder(output)
return output
else:
            assert src_mask is None
            src = src.permute(1, 0, 2)
            mask = torch.ones(src.shape[0], src.shape[1], dtype=torch.bool, device=src.device)
output = self.model(src, mask = mask)
# output = self.model(src.permute(1,0,2))
return output.permute(1,0,2)
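# Illustrative usage sketch (not part of the original gist; hyperparameters are placeholders):
# the default (non x_transformers) path maps (seq_len, batch, dinp) to (seq_len, batch, dout).
def _example_basic_transformer():
    model = BasicTransformerModel(dout=8, dinp=16, nhead=4, dhid=32, nlayers=2, dropout=0.1)
    src = torch.randn(20, 2, 16)  # (seq_len, batch, dinp)
    mask = model.generate_square_subsequent_mask(20)  # causal mask with a visible prefix of length 1
    return model(src, src_mask=mask)  # (seq_len, batch, dout)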
class EncDecTransformerModel(nn.Module):
    def __init__(self, dout, src_d, tgt_d, nhead, dhid, nlayers, dropout=0.5, device=None, use_pos_emb=False, src_length=0, tgt_length=0, use_x_transformers=False, opt=None):
super(EncDecTransformerModel, self).__init__()
self.device = device
self.model_type = 'Transformer'
self.use_x_transformers = use_x_transformers
self.encoder1 = nn.Linear(src_d, dhid)
self.encoder2 = nn.Linear(tgt_d, dhid)
if not use_x_transformers:
self.transformer = Transformer(d_model=dhid, nhead=nhead, num_encoder_layers=nlayers, num_decoder_layers=nlayers, dropout=0, activation="relu")
#enc_layer = TransformerEncoderLayer(d_model=dhid, nhead=nhead, dropout=0, activation="relu")
#self.transformerEnc = TransformerEncoder(enc_layer, nlayers)
else:
            self.transformer = EncDecXTransformer(enc_dim_in=src_d, enc_dim_out=tgt_d, dec_dim_in=tgt_d, dec_dim_out=dout, enc_dim=dhid, dec_dim=dhid, enc_heads=nhead, dec_heads=nhead, enc_depth=nlayers, dec_depth=nlayers, enc_dropout=dropout, dec_dropout=dropout, enc_max_seq_len=1024, dec_max_seq_len=1024)
#xdecoder = Decoder(dim=dhid, depth=nlayers, heads=nhead, cross_attend=True)
#self.transformer = Transformer(d_model=dhid, nhead=nhead, num_encoder_layers=nlayers, num_decoder_layers=nlayers, dropout=dropout, activation="gelu", custom_decoder=xdecoder)
#self.transformer = Transformer(d_model=dhid, nhead=nhead, num_encoder_layers=nlayers, num_decoder_layers=nlayers)
# self.encoder = nn.Embedding(ntoken, dinp)
self.src_d = src_d
self.tgt_d = tgt_d
self.dhid = dhid
self.decoder = nn.Linear(dhid, dout)
self.use_pos_emb = use_pos_emb
if use_pos_emb:
assert src_length > 0
assert tgt_length > 0
self.src_pos_emb = nn.Parameter((torch.zeros(src_length, src_length)))
self.tgt_pos_emb = nn.Parameter((torch.zeros(tgt_length, tgt_length)))
if not use_x_transformers:
tgt_mask = self.generate_square_subsequent_mask(tgt_length)
else:
tgt_mask = self.generate_square_subsequent_mask_bool(tgt_length)
self.register_buffer("tgt_mask", tgt_mask)
#a = torch.randn(32,3,512)
#b = torch.randn(32,3,512)
#self.register_buffer('a', a)
#self.register_buffer('b', b)
self.init_weights()
def generate_square_subsequent_mask(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
def generate_square_subsequent_mask_bool(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1).bool()
return mask
def init_weights(self):
initrange = 0.1
# self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
def forward(self, src, tgt):
if not self.use_x_transformers:
# import pdb;pdb.set_trace()
src = self.encoder1(src)
tgt = self.encoder2(tgt)
tgt_mask = self.tgt_mask[:tgt.shape[0], :tgt.shape[0]]
if self.use_pos_emb:
tgt_pos_emb = self.tgt_pos_emb[:tgt.shape[0], :tgt.shape[0]]
# import pdb;pdb.set_trace()
output = self.transformer(src=src, tgt=tgt, src_mask=self.src_pos_emb, tgt_mask=tgt_pos_emb+tgt_mask)
#output = self.transformer(src=src, tgt=tgt)
else:
output = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
#output = self.transformer(src=self.a, tgt=self.b)
#output = self.transformerEnc(src=self.a)
output = self.decoder(output)
return output
else:
src = self.encoder1(src)
tgt = self.encoder2(tgt)
#tgt_mask = self.tgt_mask[:tgt.shape[0], :tgt.shape[0]]
# if self.use_pos_emb:
# tgt_pos_emb = self.tgt_pos_emb[:tgt.shape[0], :tgt.shape[0]]
# # import pdb;pdb.set_trace()
# output = self.transformer(src=src.permute(1,2,0), tgt=tgt.permute(1,2,0), src_mask=self.src_pos_emb, tgt_mask=tgt_pos_emb+tgt_mask)
# #output = self.transformer(src=src, tgt=tgt)
# else:
# output = self.transformer(src=src.permute(1,0,2), tgt=tgt.permute(1,0,2), tgt_mask=tgt_mask)
# output = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
output = self.transformer(src=src, tgt=tgt)
#output = self.transformer(src=self.a, tgt=self.b)
#output = self.transformer(src=tgt, tgt=tgt) #hmm thats an interesting way of residual attention
output = self.decoder(output)
return output
#return
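# Illustrative usage sketch (not part of the original gist; hyperparameters are placeholders):
# with use_pos_emb=True the source length must equal src_length, since src_pos_emb is passed
# directly as the encoder attention mask.
def _example_encdec_transformer():
    model = EncDecTransformerModel(dout=8, src_d=16, tgt_d=8, nhead=4, dhid=32, nlayers=2,
                                   use_pos_emb=True, src_length=12, tgt_length=10)
    src = torch.randn(12, 2, 16)  # (src_len, batch, src_d)
    tgt = torch.randn(10, 2, 8)   # (tgt_len, batch, tgt_d)
    return model(src, tgt)        # (tgt_len, batch, dout)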
class EncDecXTransformer(nn.Module):
def __init__(
self,
*,
# dim,
tie_token_emb = False,
**kwargs
):
super().__init__()
enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
# import pdb;pdb.set_trace()
# assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
# enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
enc_transformer_kwargs = pick_and_pop(['max_seq_len'], enc_kwargs)
# enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
# dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
dec_transformer_kwargs = pick_and_pop(['max_seq_len'], dec_kwargs)
self.encoder = ContinuousTransformerWrapper(
**enc_transformer_kwargs,
attn_layers = Encoder(**enc_kwargs)
)
self.decoder = ContinuousTransformerWrapper(
**dec_transformer_kwargs,
attn_layers = Decoder(cross_attend = True, **dec_kwargs)
)
if tie_token_emb:
self.decoder.token_emb = self.encoder.token_emb
        # self.decoder = AutoregressiveWrapper(self.decoder)
    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, src_mask=None):
        # note: this assumes self.decoder exposes a .generate method, i.e. that the
        # AutoregressiveWrapper line above is enabled
        encodings = self.encoder(seq_in, return_embeddings=True, mask=src_mask)
        return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask)
def forward(self, src, tgt, src_mask = None, tgt_mask = None):
enc = self.encoder(src, mask = src_mask, return_embeddings = True)
#out = self.decoder(tgt, context = enc, mask = tgt_mask, context_mask = src_mask)
out = self.decoder(tgt, context = enc)
return out