Wikihow T5 Tuner
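The class below assumes the following imports, plus two helper functions (assert_all_frozen and get_dataset) that are defined elsewhere in the accompanying notebook. The APIs used (lm_labels, the optimizer_step signature, get_tqdm_dict) match the mid-2020 transformers / pytorch-lightning releases this gist was written against.

# Imports assumed by T5FineTuner below. assert_all_frozen() and get_dataset() are
# helpers defined elsewhere in the accompanying notebook; torch_xla is only needed
# when training on TPU.
import time

import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from datasets import load_metric
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup,
)
# import torch_xla.core.xla_model as xm  # uncomment for TPU training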
class T5FineTuner(pl.LightningModule):
    def __init__(self, hparams):
        super(T5FineTuner, self).__init__()
        self.hparams = hparams

        self.model = T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path)
        self.tokenizer = T5Tokenizer.from_pretrained(hparams.tokenizer_name_or_path)
        self.rouge_metric = load_metric('rouge')

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            self.freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        n_observations_per_split = {
            "train": self.hparams.n_train,
            "validation": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
    def freeze_params(self, model):
        for par in model.parameters():
            par.requires_grad = False

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        try:
            self.freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                self.freeze_params(d.embed_positions)
                self.freeze_params(d.embed_tokens)
        except AttributeError:
            self.freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                self.freeze_params(d.embed_tokens)

    def lmap(self, f, x):
        """list(map(f, x))"""
        return list(map(f, x))
    def is_logger(self):
        return self.trainer.proc_rank <= 0

    def parse_score(self, result):
        return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}

    def forward(
        self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None
    ):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            lm_labels=lm_labels,
        )
    def _step(self, batch):
        lm_labels = batch["target_ids"]
        # Set pad tokens to -100 so they are ignored by the cross-entropy loss
        lm_labels[lm_labels[:, :] == self.tokenizer.pad_token_id] = -100

        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            lm_labels=lm_labels,
            decoder_attention_mask=batch['target_mask']
        )

        loss = outputs[0]
        return loss

    def ids_to_clean_text(self, generated_ids):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return self.lmap(str.strip, gen_text)
    def _generative_step(self, batch):
        t0 = time.time()

        generated_ids = self.model.generate(
            batch["source_ids"],
            attention_mask=batch["source_mask"],
            use_cache=True,
            decoder_attention_mask=batch['target_mask'],
            max_length=150,
            num_beams=2,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True
        )
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(batch["target_ids"])

        gen_time = (time.time() - t0) / batch["source_ids"].shape[0]

        loss = self._step(batch)
        base_metrics = {'val_loss': loss}
        # rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(self.lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target)
        self.rouge_metric.add_batch(predictions=preds, references=target)

        # rouge_results = self.rouge_metric.compute()
        # rouge_dict = self.parse_score(rouge_results)
        # base_metrics.update(rouge1=rouge_dict['rouge1'], rougeL=rouge_dict['rougeL'])
        return base_metrics
    def training_step(self, batch, batch_idx):
        loss = self._step(batch)

        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def training_epoch_end(self, outputs):
        avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
        tensorboard_logs = {"avg_train_loss": avg_train_loss}
        return {"avg_train_loss": avg_train_loss, "log": tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}

        rouge_results = self.rouge_metric.compute()
        rouge_dict = self.parse_score(rouge_results)
        tensorboard_logs.update(rouge1=rouge_dict['rouge1'], rougeL=rouge_dict['rougeL'])

        ## Clear out the lists for next epoch
        self.target_gen = []
        self.prediction_gen = []

        return {"avg_val_loss": avg_loss,
                "rouge1": rouge_results['rouge1'],
                "rougeL": rouge_results['rougeL'],
                "log": tensorboard_logs, 'progress_bar': tensorboard_logs}
    def configure_optimizers(self):
        "Prepare optimizer and schedule (linear warmup and decay)"
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        self.opt = optimizer
        return [optimizer]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None, using_native_amp=False):
        if self.trainer.use_tpu:
            xm.optimizer_step(optimizer)
        else:
            optimizer.step()
        optimizer.zero_grad()
        self.lr_scheduler.step()

    def get_tqdm_dict(self):
        tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
        return tqdm_dict
    def train_dataloader(self):
        n_samples = self.n_obs['train']
        train_dataset = get_dataset(tokenizer=self.tokenizer, type_path="train", num_samples=n_samples, args=self.hparams)
        dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size, drop_last=True, shuffle=True, num_workers=4)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        n_samples = self.n_obs['validation']
        validation_dataset = get_dataset(tokenizer=self.tokenizer, type_path="validation", num_samples=n_samples, args=self.hparams)
        return DataLoader(validation_dataset, batch_size=self.hparams.eval_batch_size, num_workers=4)

    def test_dataloader(self):
        n_samples = self.n_obs['test']
        test_dataset = get_dataset(tokenizer=self.tokenizer, type_path="test", num_samples=n_samples, args=self.hparams)
        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size, num_workers=4)
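
A minimal usage sketch under assumed hyperparameter values. The original notebook builds these with argparse, and get_dataset may read additional fields (e.g. data paths or max sequence lengths) not shown here; the values below are illustrative only.

# Hypothetical hyperparameters -- field names match what T5FineTuner reads above.
import argparse

args = argparse.Namespace(
    model_name_or_path="t5-base",
    tokenizer_name_or_path="t5-base",
    learning_rate=3e-4,
    weight_decay=0.0,
    adam_epsilon=1e-8,
    warmup_steps=0,
    train_batch_size=4,
    eval_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    n_gpu=1,
    n_train=-1,   # negative counts mean "use the full split"
    n_val=1000,
    n_test=-1,
    freeze_embeds=False,
    freeze_encoder=False,
)

model = T5FineTuner(args)
trainer = pl.Trainer(
    max_epochs=args.num_train_epochs,
    gpus=args.n_gpu,
    accumulate_grad_batches=args.gradient_accumulation_steps,
)
trainer.fit(model)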