import torch
from torch import nn
from .transformer import BasicTransformerModel
from models import BaseModel
from .util.generation import autoregressive_generation_multimodal

class TransformerModel(BaseModel):
    def __init__(self, opt):
        super().__init__(opt)
        opt = self.opt
        input_mods = self.input_mods
        output_mods = self.output_mods
        dins = self.dins
        douts = self.douts
        input_lengths = self.input_lengths
        self.input_mod_nets = []
        self.output_mod_nets = []
        self.module_names = []
        # One 2-layer transformer per input modality, projecting dins[i] -> dhid.
        for i, mod in enumerate(input_mods):
            net = BasicTransformerModel(opt.dhid, dins[i], opt.nhead, opt.dhid, 2, opt.dropout, self.device, use_pos_emb=True, input_length=input_lengths[i], use_x_transformers=opt.use_x_transformers, opt=opt)
            name = "_input_" + mod
            setattr(self, "net" + name, net)
            self.input_mod_nets.append(net)
            self.module_names.append(name)
        # One nlayers-deep transformer per output modality, reading the concatenated input latents.
        for i, mod in enumerate(output_mods):
            net = BasicTransformerModel(douts[i], opt.dhid, opt.nhead, opt.dhid, opt.nlayers, opt.dropout, self.device, use_pos_emb=opt.use_pos_emb_output, input_length=sum(input_lengths), use_x_transformers=opt.use_x_transformers, opt=opt)
            # net = BasicTransformerModel(douts[i], opt.dhid, opt.nhead, opt.dhid, opt.nlayers, opt.dropout, self.device, use_pos_emb=True, input_length=sum(input_lengths))
            name = "_output_" + mod
            setattr(self, "net" + name, net)
            self.output_mod_nets.append(net)
            self.module_names.append(name)

        # This is feature creep. Will remove soon
        # if self.opt.generate_attention_masks:
        self.generate_full_masks()
        self.inputs = []
        self.targets = []
        self.criterion = nn.MSELoss()

    def name(self):
        return "Transformer"

    @staticmethod
    def modify_commandline_options(parser, opt):
        parser.add_argument('--dhid', type=int, default=512)
        parser.add_argument('--nlayers', type=int, default=6)
        parser.add_argument('--nhead', type=int, default=8)
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--use_pos_emb_output', action='store_true', help="whether to use positional embeddings for output modality transformers")
        parser.add_argument('--use_rotary_pos_emb', action='store_true', help="whether to use rotary position embeddings")
        parser.add_argument('--use_x_transformers', action='store_true', help="whether to use the x-transformers implementation of the transformer layers")
        # parser.add_argument('--generate_attention_masks', action='store_true', help="whether to generate the masks (but right now they are full masks, so it's not necessary)")
        return parser
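
    # Hedged example (not from the original gist): these flags would typically be passed
    # through the repo's training entry point. The script name "train.py" and the
    # "--model transformer" flag below are assumptions; only the flag names and defaults
    # come from modify_commandline_options above.
    #
    #   python train.py --model transformer --dhid 512 --nlayers 6 --nhead 8 \
    #       --dropout 0.1 --use_pos_emb_output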

    def generate_full_masks(self):
        input_mods = self.input_mods
        output_mods = self.output_mods
        input_lengths = self.input_lengths
        self.src_masks = []
        for i, mod in enumerate(input_mods):
            # All-zeros masks: attention is left completely unrestricted.
            mask = torch.zeros(input_lengths[i], input_lengths[i])
            self.register_buffer('src_mask_' + str(i), mask)
            self.src_masks.append(mask)
        self.output_masks = []
        for i, mod in enumerate(output_mods):
            mask = torch.zeros(sum(input_lengths), sum(input_lengths))
            self.register_buffer('out_mask_' + str(i), mask)
            self.output_masks.append(mask)
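
    # Hedged aside (not part of the original gist): in the standard PyTorch convention
    # these all-zeros masks are *additive* attention masks, so they leave attention fully
    # unmasked. A minimal standalone check of that convention with torch.nn.MultiheadAttention:
    #
    #   import torch
    #   attn = torch.nn.MultiheadAttention(embed_dim=16, num_heads=2)
    #   x = torch.randn(5, 1, 16)                     # (seq, batch, embed)
    #   out_masked, _ = attn(x, x, x, attn_mask=torch.zeros(5, 5))
    #   out_plain, _ = attn(x, x, x)
    #   assert torch.allclose(out_masked, out_plain)  # zero mask == no masking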

    def forward(self, data):
        # in lightning, forward defines the prediction/inference actions
        latents = []
        for i, mod in enumerate(self.input_mods):
            latents.append(self.input_mod_nets[i].forward(data[i]))
        # Concatenate the per-modality latent sequences along the time dimension.
        latent = torch.cat(latents)
        outputs = []
        for i, mod in enumerate(self.output_mods):
            # Keep only the first output_lengths[i] time steps of each output stream.
            output = self.output_mod_nets[i].forward(latent)[:self.output_lengths[i]]
            outputs.append(output)
        return outputs
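
    # Hedged shape note (an assumption, not stated in the gist): if each input tensor
    # data[i] is (input_lengths[i], batch, dins[i]) in the usual sequence-first layout,
    # the input nets map it to (input_lengths[i], batch, dhid), torch.cat stacks these
    # to (sum(input_lengths), batch, dhid), and each output is sliced back down to
    # (output_lengths[i], batch, douts[i]).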

    def training_step(self, batch, batch_idx):
        self.set_inputs(batch)
        # Same encode-concatenate-decode pass as forward(), but on self.inputs.
        latents = []
        for i, mod in enumerate(self.input_mods):
            latents.append(self.input_mod_nets[i].forward(self.inputs[i]))
        latent = torch.cat(latents)
        loss_mse = 0
        for i, mod in enumerate(self.output_mods):
            output = self.output_mod_nets[i].forward(latent)[:self.output_lengths[i]]
            loss_mse += self.criterion(output, self.targets[i])
        #if self.opt.precision == 16:
        #    loss_mse *= 100 # loss scaling
        self.log('mse_loss', loss_mse)
        return loss_mse

    #def configure_optimizers(self):
    #    optimizer = torch.optim.Adam(self.parameters(), lr=self.opt.learning_rate)
    #    return [optimizer]

    #def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
    #                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    #    optimizer.zero_grad()
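
# Hedged usage sketch (not part of the original gist): how this Lightning-style model
# might be driven. `opt` must be built by the surrounding repo's option parser (it
# supplies input_mods, dins, douts, input_lengths, etc. via BaseModel), and the
# dataloader is an assumption; only the class and its training_step come from above.
#
#   import pytorch_lightning as pl
#   model = TransformerModel(opt)
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(model, train_dataloader)   # optimizes the summed MSE logged as 'mse_loss'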