import time
from collections import defaultdict
import os
import gc
import torch
from tqdm import trange, tqdm
import masker
import tests
import utils
from losses import Losses
from masker import Masker, TestMasker
amp = None
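
# When args.fp16 is set, TaskTuner expects the model and optimizer to already be
# wrapped by NVIDIA apex amp in the surrounding training script, since it calls
# amp.scale_loss below. A minimal setup sketch (an assumption about that script,
# not part of this gist):
#
#   from apex import amp
#   model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
#   tuner = TaskTuner(model, optimizer, train_loader, test_loader, args)
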
class TaskTuner:
    """Class implementing the trainer for the project."""

    def __init__(self, model, optimizer, train_loader, test_loader, args,
                 epoch=-1, global_step=0, test_mode=False, debug=False):
        if args.fp16:
            try:
                global amp
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        self.model = model  # Includes the finetuning head
        self.args = args
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.epoch = epoch
        self.module = model.module if hasattr(model, 'module') else model  # unwrap for data parallel
        self.global_step = global_step
        self.debug = debug
        self.task = args.task
    def train(self):
        best_eval = 0
        try:
            for epoch in trange(self.epoch + 1, self.args.num_train_epochs, desc='Training model'):
                if self.args.local_rank != -1:
                    self.train_loader.sampler.set_epoch(epoch)
                # Train
                self.run_epoch(epoch)
                print("Train epoch done")
                # Validate
                val_score = self.run_epoch(epoch, train=False)
                print("Val epoch done")
                # TODO - task specific - save the best model.
                is_best = val_score > best_eval
                best_eval = max(val_score, best_eval)
                if self.args.local_rank <= 0 and not self.args.debug:
                    print('Saving checkpoint, is best:', is_best)
                    utils.save_checkpoint(self.model, self.optimizer, self.train_loader.dataset.tokenizer, is_best,
                                          epoch, self.args.checkpoint_dir, amp=amp, global_step=self.global_step,
                                          args=self.args)
        except KeyboardInterrupt:
            if self.args.local_rank <= 0:
                print(f'You decided to finish the training at epoch {epoch}')

    def run_epoch(self, epoch, train=True):
        torch.cuda.synchronize()

        # Initialize meters
        avg_batch_time = utils.AverageMeter()
        avg_data_time = utils.AverageMeter()
        list_losses = ['total', 'binary_classification']
        average_meters = defaultdict(lambda: utils.AverageMeter())
        if not train:
            avg_acc = utils.AverageMeter()

        # Switch between train and eval mode
        if train:
            self.model.train()
        else:
            self.model.eval()

        end = time.time()
        with torch.set_grad_enabled(train), \
                tqdm(self.train_loader if train else self.test_loader,
                     desc=f'Training epoch {epoch}' if train else f'Validating {f"epoch {epoch}" if epoch else ""}',
                     disable=self.args.local_rank > 0) as t:
            for batch_idx, data in enumerate(t):
                # Measure data loading time
                avg_data_time.update(time.time() - end)

                # -------------- Organize inputs ------------- #
                img_no_mask_locs = None
                text_no_mask_locs = None
                text_mask_locs = None
                max_txt_seq_len = self.args.max_txt_seq_len
                imgs = data["imgs"].cuda()  # [4, 12, 3, 112, 112]: batch size 4, 12 bboxes found
                img_bboxes = data['img_bboxes'].cuda()  # [4, 12, 4]: batch size 4, 12 boxes
                imgs_len = data['imgs_len'].cuda()  # number of images in each sample
                sep_token_id = data['sep_token_id'].cuda()
                # The images are passed once per answer choice, so repeat them 4 times.
                imgs = torch.repeat_interleave(imgs, repeats=4, dim=0)
                img_bboxes = torch.repeat_interleave(img_bboxes, repeats=4, dim=0)
                imgs_len = torch.repeat_interleave(imgs_len, repeats=4, dim=0)
                question = data["question"].cuda()  # [4, 14] where 4 is the batch size
                question_len = data["question_len"].cuda()
                answer_choices = data["answer_choices"].cuda()  # [4, 4, 64]: batch size 4, 4 options, 64 tokens
                answer_len = data["answer_len"].cuda()
                dim_answers = answer_choices.shape[-1]
                # Stack questions and answers
                questions = torch.repeat_interleave(question, repeats=4, dim=0)  # [16, 14] when 4 is the batch size
                answers = answer_choices.view(-1, dim_answers)  # torch.Size([16, 64])
                separator = torch.repeat_interleave(sep_token_id, repeats=4, dim=0).view(-1, 1)
                num_inputs = len(questions)
                answer_labels = data["answer_label"].cuda()
                answer_labels = torch.nn.functional.one_hot(answer_labels.squeeze(), 4)  # (4, 4): integer labels to one-hot
                answer_labels = answer_labels.flatten().float()  # (16,)
                if self.task == "Q2A":
                    text = torch.cat([questions, separator, answers], dim=1)
                    text = text[:, :max_txt_seq_len]
                    text_len = (question_len.view(-1, 1) + answer_len).flatten() + 1  # Adding 1 for the SEP.
                elif self.task == "QA2R":
                    # NOTE: this branch was left unfinished in the gist; the keys below are
                    # assumed to mirror the Q2A inputs and may need adjusting to the dataset.
                    rationale_choices = data["rationale_choices"].cuda()  # assumed key, [4, 4, 64]
                    rationale_len = data["rationale_len"].cuda()  # assumed key, [4, 4]
                    answer_idx = data["answer_label"].squeeze().cuda()  # index of the correct answer
                    batch_arange = torch.arange(answer_choices.size(0), device=answer_idx.device)
                    correct_answer = torch.repeat_interleave(answer_choices[batch_arange, answer_idx], repeats=4, dim=0)
                    correct_answer_len = answer_len[batch_arange, answer_idx]
                    rationale = rationale_choices.view(-1, rationale_choices.shape[-1])  # torch.Size([16, 64])
                    text = torch.cat([questions, separator, correct_answer, separator, rationale], dim=1)
                    text = text[:, :max_txt_seq_len]
                    text_len = (question_len.view(-1, 1) + correct_answer_len.view(-1, 1) + rationale_len).flatten() + 2  # two SEPs
                    rationale_labels = data["rationale_label"].cuda()  # would replace answer_labels in the loss for this task
                img_attn_mask = \
                    torch.arange(self.args.max_img_seq_len, device=imgs.device)[None, :] < imgs_len[:, None]
                text_attn_mask = \
                    torch.arange(self.args.max_txt_seq_len, device=imgs.device)[None, :] < text_len[:, None]
                attn_mask = torch.cat((text_attn_mask[:, :1], img_attn_mask, text_attn_mask[:, 1:]), dim=1)
                # text starts with an [IMG] token that gets moved to the beginning of the input in the forward pass
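                # Example layout: with max_img_seq_len=12 and max_txt_seq_len=80 (illustrative
                # values, not from this gist), attn_mask has shape [16, 1 + 12 + 79]: the [IMG]
                # slot, then one slot per image box, then the remaining text positions.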
                # -------------- Forward pass ---------------- #
                img_locs = None  # masking locations are unused for this task
                txt_locs = None
                # Get the logits from the classifier
                outputs = self.model(imgs, text, img_bboxes, attention_mask=attn_mask,
                                     img_lens=imgs_len, txt_lens=text_len,
                                     img_locs=img_locs, txt_locs=txt_locs)  # [16, 1]
                # -------------- Compute losses -------------- #
                loss_values = {}
                loss_values['binary_classification'] = torch.nn.BCEWithLogitsLoss()(outputs, answer_labels)
                # Combine losses if there is more than one
                loss = loss_values['binary_classification']
                if self.args.n_gpu > 1:
                    loss = loss.mean()
                loss_values['total'] = loss
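                # Each of the 16 (question, answer) pairs is scored independently as a
                # binary "is this the correct answer" problem, matching the flattened 0/1
                # labels above, rather than as a 4-way softmax over the choices.
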
                # --------------- Update model -------------- #
                if train:
                    if self.args.fp16:
                        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        (loss / self.args.gradient_accumulation_steps).backward()
                if (batch_idx + 1) % self.args.gradient_accumulation_steps == 0:
                    for loss_name in list_losses:  # Record losses
                        average_meters[loss_name].update(loss_values[loss_name].detach().item() /
                                                         self.args.gradient_accumulation_steps, imgs.size(0))
                    if train:
                        if self.args.fp16:
                            torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                        else:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
                        self.optimizer.step()
                        self.model.zero_grad()
                # Measure elapsed time
                avg_batch_time.update(time.time() - end)
                end = time.time()

                # ------------- Show information ------------ #
                postfix_kwargs = {}
                if not train:
                    # TODO: write generic code to set up accuracy metrics at the start
                    results = tests.accuracy_classify(outputs, answer_labels)
                    avg_acc.update(*(results['acc']))
                    postfix_kwargs['acc'] = avg_acc.avg
                for loss_name in list_losses:
                    postfix_kwargs[loss_name] = average_meters[loss_name].avg
                t.set_postfix(
                    DataTime=avg_data_time.avg,
                    BatchTime=avg_batch_time.avg,
                    **postfix_kwargs
                )
                if train:
                    if self.global_step % self.args.print_freq == 0 and self.args.writer:
                        self.args.writer.add_scalars('Train/loss', {**postfix_kwargs},
                                                     self.global_step * self.args.train_batch_size * self.args.step_n_gpus)
                    self.global_step += 1
                # if batch_idx % 50 == 0:
                #     gc.collect()
                #     del loss_values, loss, outputs, data, imgs
        if not train:
            cnt = average_meters['total'].count
            if epoch is not None:
                loss_scalars = {}
                for loss_name in list_losses:
                    loss_scalars[loss_name] = utils.gather_score(average_meters[loss_name].avg, cnt).item()
                acc_scalars = {
                    'acc': utils.gather_score(avg_acc.avg, cnt).item()
                }
                if self.args.writer:
                    self.args.writer.add_scalars('val/loss', loss_scalars, epoch)
                    self.args.writer.add_scalars('val/acc', acc_scalars, epoch)
            return utils.gather_score(avg_acc.avg, cnt).item()
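
# For reference, a minimal meter matching the interface run_epoch relies on
# (update(value, n), .avg, .count). This is only a sketch of what the project's
# utils.AverageMeter presumably provides, not its actual implementation.
class _AverageMeterSketch:
    """Tracks a running average of a scalar, weighted by the number of samples."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count
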
# NOTE: VLBert, the base vision-language model, is defined elsewhere in the
# project; the import path below is an assumption.
# from model import VLBert

class VLBertClassifier(VLBert):
    def __init__(self, cfg, args, tok, num_layers, num_outputs, hidden_units=1024, dim_mlp=384):
        super(VLBertClassifier, self).__init__(cfg, args, tok)
        if num_layers == 2:
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(0.1, inplace=False),
                torch.nn.Linear(dim_mlp, hidden_units),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(0.1, inplace=False),
                torch.nn.Linear(hidden_units, num_outputs),
            )
        elif num_layers == 1:
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(0.1, inplace=False),
                torch.nn.Linear(dim_mlp, num_outputs),
            )
        # Initialise the MLP weights
        for m in self.final_mlp.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, imgs, text, img_bboxes, attention_mask, img_lens,
                txt_lens, img_locs, txt_locs):
        lm_preds, vm_preds, input_pointing_pred, hidden_states, *_ = \
            super(VLBertClassifier, self).forward(imgs, text, img_bboxes, attention_mask=attention_mask,
                                                  img_lens=img_lens, txt_lens=txt_lens,
                                                  img_locs=img_locs, txt_locs=txt_locs)
        txt_token_embedding = hidden_states[:, 12]  # embedding at sequence position 12 (torch.Size([16, 384]))
        output = self.final_mlp(txt_token_embedding).squeeze()
        return output
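
# A rough end-to-end usage sketch (loader construction and arg fields are
# assumptions, not shown in this gist):
#
#   model = VLBertClassifier(cfg, args, tokenizer, num_layers=2, num_outputs=1).cuda()
#   optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
#   tuner = TaskTuner(model, optimizer, train_loader, test_loader, args)
#   tuner.train()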