Last active: July 15, 2020 00:58
-
-
Save amogh112/deecd6d388c43cc58c33ed9c659b4b4d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
from collections import defaultdict | |
import os | |
import gc | |
import torch | |
from tqdm import trange, tqdm | |
import masker | |
import tests | |
import utils | |
from losses import Losses | |
from masker import Masker, TestMasker | |
amp = None | |
class TaskTuner:
    """Fine-tune and validate a vision-language model on a multiple-choice task.

    Every question is expanded into one (question, answer choice) row per choice; the model
    produces one logit per row and is trained with binary cross-entropy against a one-hot
    encoding of the correct choice.
    """

    def __init__(self, model, optimizer, train_loader, test_loader, args, epoch=-1, global_step=0, test_mode=False,
                 debug=False):
        """Store the training components.

        Args:
            model: the network. If it wraps another module (``model.module``, as DataParallel does),
                the wrapped module is exposed as ``self.module``; otherwise ``self.module is model``.
            optimizer: optimizer stepping the model parameters.
            train_loader: loader used for training epochs.
            test_loader: loader used for validation epochs.
            args: run options; read here: ``fp16`` and ``task``.
            epoch: last finished epoch; ``train()`` resumes at ``epoch + 1``.
            global_step: number of training batches already processed (used for logging steps).
            test_mode: accepted for interface compatibility; not used by this class.
            debug: stored as ``self.debug``.

        Raises:
            ImportError: if ``args.fp16`` is set and NVIDIA apex is not installed.
        """
        # ``global`` must come before any binding of ``amp`` in this function; the module-level
        # ``amp`` is what run_epoch/train (and the checkpoint helper) read.
        global amp
        if args.fp16:
            try:
                from apex import amp as apex_amp
            except ImportError as err:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
            amp = apex_amp
        self.model = model  # Includes the finetuning part
        self.args = args
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.epoch = epoch
        self.module = getattr(model, 'module', model)  # unwrap data-parallel wrappers
        self.global_step = global_step
        self.debug = debug
        self.task = args.task

    def train(self):
        """Train from ``self.epoch + 1`` to ``args.num_train_epochs``, validating after each epoch.

        On rank <= 0 (and when ``args.debug`` is false) a checkpoint is saved after every epoch,
        flagged as best when the validation score improved. Ctrl-C ends training gracefully.
        """
        best_eval = 0
        epoch = self.epoch  # keeps the interrupt message valid even before the first epoch starts
        try:
            for epoch in trange(self.epoch + 1, self.args.num_train_epochs, desc='Training model'):
                if self.args.local_rank != -1:
                    # Distributed samplers need the epoch to reshuffle differently each epoch.
                    self.train_loader.sampler.set_epoch(epoch)
                # Train
                self.run_epoch(epoch)
                print("Train epoch done")
                # Validate
                val_score = self.run_epoch(epoch, train=False)
                print("Val epoch done")
                # TODO - task specific - save the best model.
                is_best = val_score > best_eval
                best_eval = max(val_score, best_eval)
                if self.args.local_rank <= 0 and not self.args.debug:
                    print('Saving checkpoint, is best:', is_best)
                    utils.save_checkpoint(self.model, self.optimizer, self.train_loader.dataset.tokenizer, is_best,
                                          epoch, self.args.checkpoint_dir, amp=amp, global_step=self.global_step,
                                          args=self.args)
        except KeyboardInterrupt:
            if self.args.local_rank <= 0:
                print(f'You decided to finish the training at epoch {epoch}')

    @staticmethod
    def _expand_per_choice(tensor, num_choices):
        """Repeat every row of ``tensor`` ``num_choices`` times along dim 0 (one copy per choice)."""
        return torch.repeat_interleave(tensor, repeats=num_choices, dim=0)

    @staticmethod
    def _choice_targets(labels, num_choices):
        """Convert integer indices of the correct choice into flat float binary targets.

        ``labels`` holds one index per question (any shape with that many elements); the result
        has ``len(labels) * num_choices`` entries, 1.0 at the correct choice of each question.
        """
        one_hot = torch.nn.functional.one_hot(labels.reshape(-1), num_choices)
        return one_hot.flatten().float()

    @staticmethod
    def _build_q2a_text(questions, separator, answers, question_len, answer_len, max_txt_seq_len):
        """Build ``question + SEP + answer`` token rows truncated to ``max_txt_seq_len``.

        Returns:
            ``(text, text_len)``. ``text_len`` is question length + answer length + 1 (for SEP),
            clamped to ``max_txt_seq_len`` because tokens past that point are cut off.

        NOTE(review): the length assumes the question, SEP and answer tokens are contiguous. If
        ``questions`` is right-padded to a common width, the padding sits between the question
        and SEP and this length/mask is off — confirm against the dataset's collate function.
        """
        text = torch.cat([questions, separator, answers], dim=1)[:, :max_txt_seq_len]
        text_len = ((question_len.view(-1, 1) + answer_len).flatten() + 1).clamp(max=max_txt_seq_len)
        return text, text_len

    @staticmethod
    def _attention_mask(imgs_len, text_len, max_img_seq_len, max_txt_seq_len):
        """Boolean attention mask laid out as [first text token] + image slots + remaining text.

        Per the original author, the text starts with an [IMG] token that the model moves to the
        beginning of the input in its forward pass, hence the first text column goes first.
        """
        device = imgs_len.device
        img_mask = torch.arange(max_img_seq_len, device=device)[None, :] < imgs_len[:, None]
        txt_mask = torch.arange(max_txt_seq_len, device=device)[None, :] < text_len[:, None]
        return torch.cat((txt_mask[:, :1], img_mask, txt_mask[:, 1:]), dim=1)

    def _prepare_batch(self, data):
        """Move a batch to the GPU and expand it to one row per (question, answer choice) pair.

        Returns:
            ``(imgs, img_bboxes, imgs_len, text, text_len, attn_mask, answer_labels)``, each with
            ``batch_size * num_choices`` rows; ``answer_labels`` are float binary targets.

        Raises:
            NotImplementedError: for ``task == "QA2R"``.
            ValueError: for any task other than ``"Q2A"``.
        """
        if self.task == "QA2R":
            # QA2R needs the rationale choices and each question's correct answer assembled into
            # the text, and rationale labels as targets; none of that is built here yet.
            raise NotImplementedError("QA2R fine-tuning is not implemented yet")
        if self.task != "Q2A":
            raise ValueError(f"Unsupported task: {self.task!r}")

        answer_choices = data["answer_choices"].cuda()  # e.g. [4, 4, 64]: batch, choices, tokens
        num_choices = answer_choices.shape[1]

        # Images are shared by all choices of a question, so repeat them once per choice.
        imgs = self._expand_per_choice(data["imgs"].cuda(), num_choices)
        img_bboxes = self._expand_per_choice(data['img_bboxes'].cuda(), num_choices)
        imgs_len = self._expand_per_choice(data['imgs_len'].cuda(), num_choices)

        questions = self._expand_per_choice(data["question"].cuda(), num_choices)
        separator = self._expand_per_choice(data['sep_token_id'].cuda(), num_choices).view(-1, 1)
        answers = answer_choices.view(-1, answer_choices.shape[-1])
        text, text_len = self._build_q2a_text(questions, separator, answers, data["question_len"].cuda(),
                                              data["answer_len"].cuda(), self.args.max_txt_seq_len)

        answer_labels = self._choice_targets(data["answer_label"].cuda(), num_choices)
        attn_mask = self._attention_mask(imgs_len, text_len, self.args.max_img_seq_len, self.args.max_txt_seq_len)
        return imgs, img_bboxes, imgs_len, text, text_len, attn_mask, answer_labels

    def run_epoch(self, epoch, train=True):
        """Run one pass over the training (``train=True``) or test loader.

        Training uses gradient accumulation over ``args.gradient_accumulation_steps`` batches
        (in both fp16 and fp32) and clips gradients at ``args.max_grad_norm``. Validation tracks
        accuracy via ``tests.accuracy_classify``.

        Args:
            epoch: epoch index for the progress bar and TensorBoard; may be None during
                validation, in which case no validation scalars are written.
            train: train (True) or validate (False).

        Returns:
            The validation accuracy gathered with ``utils.gather_score`` when ``train`` is False,
            otherwise None.
        """
        torch.cuda.synchronize()
        # Initialize meters
        avg_batch_time = utils.AverageMeter()
        avg_data_time = utils.AverageMeter()
        avg_acc = utils.AverageMeter()
        list_losses = ['total', 'binary_classification']
        average_meters = defaultdict(utils.AverageMeter)
        accumulation_steps = self.args.gradient_accumulation_steps

        # Switch to train / eval mode
        if train:
            self.model.train()
        else:
            self.model.eval()

        epoch_label = f"epoch {epoch}" if epoch is not None else ""
        desc = f'Training epoch {epoch}' if train else f'Validating {epoch_label}'
        end = time.time()
        with torch.set_grad_enabled(train), \
                tqdm(self.train_loader if train else self.test_loader,
                     desc=desc, disable=self.args.local_rank > 0) as t:
            for batch_idx, data in enumerate(t):
                # Measure data loading time
                avg_data_time.update(time.time() - end)

                # -------------- Organize inputs ------------- #
                imgs, img_bboxes, imgs_len, text, text_len, attn_mask, answer_labels = self._prepare_batch(data)

                # -------------- Forward pass ---------------- #
                # One logit per (question, answer choice) row.
                outputs = self.model(imgs, text, img_bboxes, attention_mask=attn_mask, img_lens=imgs_len,
                                     txt_lens=text_len, img_locs=None, txt_locs=None)

                # -------------- Compute losses -------------- #
                loss_values = {'binary_classification': torch.nn.BCEWithLogitsLoss()(outputs, answer_labels)}
                loss = loss_values['binary_classification']
                if self.args.n_gpu > 1:
                    loss = loss.mean()
                loss_values['total'] = loss

                # Record the unscaled loss of every batch so the averages reflect the real loss
                # independent of the accumulation setting.
                for loss_name in list_losses:
                    average_meters[loss_name].update(loss_values[loss_name].detach().item(), imgs.size(0))

                # --------------- Update model -------------- #
                if train:
                    # Scale for gradient accumulation on both the fp16 and the fp32 path.
                    accumulated_loss = loss / accumulation_steps
                    if self.args.fp16:
                        with amp.scale_loss(accumulated_loss, self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        accumulated_loss.backward()
                    if (batch_idx + 1) % accumulation_steps == 0:
                        if self.args.fp16:
                            params = amp.master_params(self.optimizer)
                        else:
                            params = self.model.parameters()
                        torch.nn.utils.clip_grad_norm_(params, self.args.max_grad_norm)
                        self.optimizer.step()
                        self.model.zero_grad()

                # Measure elapsed time
                avg_batch_time.update(time.time() - end)
                end = time.time()

                # ------------- Show information ------------ #
                postfix_kwargs = {}
                if not train:
                    # TODO- Write generic code to have accuracy metrics at the start
                    results = tests.accuracy_classify(outputs, answer_labels)
                    avg_acc.update(*results['acc'])
                    postfix_kwargs['acc'] = avg_acc.avg
                for loss_name in list_losses:
                    postfix_kwargs[loss_name] = average_meters[loss_name].avg
                t.set_postfix(DataTime=avg_data_time.avg, BatchTime=avg_batch_time.avg, **postfix_kwargs)

                if train:
                    if self.global_step % self.args.print_freq == 0 and self.args.writer:
                        self.args.writer.add_scalars(
                            'Train/loss', dict(postfix_kwargs),
                            self.global_step * self.args.train_batch_size * self.args.step_n_gpus)
                    self.global_step += 1

        if train:
            return None

        cnt = average_meters['total'].count
        loss_scalars = {}
        if epoch is not None:
            loss_scalars = {name: utils.gather_score(average_meters[name].avg, cnt).item() for name in list_losses}
        val_acc = utils.gather_score(avg_acc.avg, cnt).item()
        if epoch is not None and self.args.writer:
            self.args.writer.add_scalars('val/loss', loss_scalars, epoch)
            self.args.writer.add_scalars('val/acc', {'acc': val_acc}, epoch)
        return val_acc
class VLBertClassifier(VLBert):
    """VLBert backbone with an MLP head mapping one token's hidden state to ``num_outputs`` logits."""

    def __init__(self, cfg, args, tok, num_layers, num_outputs, hidden_units=1024, dim_mlp=384):
        """Build the backbone and the classification head.

        Args:
            cfg, args, tok: forwarded unchanged to ``VLBert.__init__``.
            num_layers: depth of the head, 1 or 2; ints and their string forms (``"1"``, ``2``)
                are both accepted.
            num_outputs: number of logits produced per example.
            hidden_units: width of the hidden layer of the 2-layer head.
            dim_mlp: size of the hidden-state vector fed to the head.

        Raises:
            ValueError: if ``num_layers`` is not 1 or 2.
        """
        super().__init__(cfg, args, tok)
        # Normalise so both the int and the string spelling select the same head.
        depth = str(num_layers)
        if depth == "2":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(0.1, inplace=False),
                torch.nn.Linear(dim_mlp, hidden_units),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(0.1, inplace=False),
                torch.nn.Linear(hidden_units, num_outputs),
            )
        elif depth == "1":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(0.1, inplace=False),
                torch.nn.Linear(dim_mlp, num_outputs),
            )
        else:
            raise ValueError(f"num_layers must be 1 or 2, got {num_layers!r}")
        # Initialise the weights for MLP
        for m in self.final_mlp.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, imgs, text, img_bboxes, attention_mask, img_lens,
                txt_lens, img_locs, txt_locs):
        """Run the backbone and score every example with the MLP head.

        Returns:
            Logits with the trailing size-1 output dimension removed, i.e. shape ``(N,)`` when
            ``num_outputs == 1`` (even for ``N == 1``), otherwise ``(N, num_outputs)``.
        """
        # The backbone returns at least four items; only the hidden states are used here.
        _, _, _, hidden_states, *_ = super().forward(
            imgs, text, img_bboxes, attention_mask=attention_mask, img_lens=img_lens,
            txt_lens=txt_lens, img_locs=img_locs, txt_locs=txt_locs)
        # NOTE(review): sequence position 12 is hard-coded (the original comment reports a
        # [16, 384] result). Confirm which token sits at index 12 in the [IMG] + images + text
        # layout before changing image/text lengths.
        txt_token_embedding = hidden_states[:, 12]
        # squeeze(-1) removes only the output dimension, never the batch dimension.
        return self.final_mlp(txt_token_embedding).squeeze(-1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.