Script for fine-tuning a multilingual T5 model (mT5-base) for Finnish QG
import torch, os, json, argparse, sys
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
from rich.table import Column, Table
from rich import box
from rich.console import Console

console = Console(record=True)
model_params = {
    "MODEL": "google/mt5-base",     # model checkpoint: google/mt5-base or google/mt5-large
    "TRAIN_BATCH_SIZE": 8,          # training batch size
    "VALID_BATCH_SIZE": 8,          # validation batch size
    "TRAIN_EPOCHS": 32,             # number of training epochs
    "VAL_EPOCHS": 1,                # number of validation epochs
    "LEARNING_RATE": 1e-4,          # learning rate
    "MAX_SOURCE_TEXT_LENGTH": 512,  # max length of source text
    "MAX_TARGET_TEXT_LENGTH": 50,   # max length of target text
    "SEED": 42,                     # set seed for reproducibility
}
# training logger to log training progress
training_logger = Table(
    Column("Epoch", justify="center"),
    Column("Steps", justify="center"),
    Column("Loss", justify="center"),
    title="Training Status",
    pad_edge=False,
    box=box.ASCII,
)

# Setting up the device for GPU usage
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
class SQuADDataset(Dataset):
    def __init__(self, passage_list, question_list, answer_list, tokenizer, src_max_length, tgt_max_length):
        self.source_input_ids = []
        self.source_masks = []
        self.target_input_ids = []
        self.target_masks = []

        for passage, question, answer in zip(passage_list, question_list, answer_list):
            answer_text = answer['text']
            answer_start = answer['answer_start']

            source_text = f'Luo kysymys: {passage[:answer_start]}[HL]{answer_text}[HL]{passage[len(answer_text) + answer_start:]}'
            source_text_encodings_dict = tokenizer(source_text, truncation=True, max_length=src_max_length, padding="max_length")

            target_text = question
            target_text_encodings_dict = tokenizer(target_text, truncation=True, max_length=tgt_max_length, padding="max_length")

            source_ids = source_text_encodings_dict["input_ids"]
            source_mask = source_text_encodings_dict["attention_mask"]
            target_ids = target_text_encodings_dict["input_ids"]
            target_mask = target_text_encodings_dict["attention_mask"]

            self.source_input_ids.append(torch.tensor(source_ids))
            self.source_masks.append(torch.tensor(source_mask))
            self.target_input_ids.append(torch.tensor(target_ids))
            self.target_masks.append(torch.tensor(target_mask))

    def __len__(self):
        return len(self.source_input_ids)

    def __getitem__(self, idx):
        return {
            "source_ids": self.source_input_ids[idx],
            "source_mask": self.source_masks[idx],
            "target_ids": self.target_input_ids[idx],
            "target_ids_y": self.target_input_ids[idx],
        }
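# Illustration (hypothetical example, not part of the script): the dataset wraps the
# answer span in the added [HL] special tokens and prepends the Finnish task prefix,
# so one source/target pair would look roughly like this:
#
#   passage  = "Helsinki on Suomen pääkaupunki."
#   answer   = {"text": "Helsinki", "answer_start": 0}
#   question = "Mikä on Suomen pääkaupunki?"
#
#   source_text -> "Luo kysymys: [HL]Helsinki[HL] on Suomen pääkaupunki."
#   target_text -> "Mikä on Suomen pääkaupunki?"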
def select_answer(answers):
    '''
    Select one answer from the annotated alternatives using the following rules:
    1. majority vote
    2. if there is no majority, the shortest answer
    '''
    if len(answers) == 1:
        return answers[0]

    # Vote for the most popular answer
    start_pos: dict = defaultdict(list)
    votes: Counter = Counter()
    for ans_dict in answers:
        answer_text = ans_dict["text"]
        ans_char_start_pos = ans_dict["answer_start"]
        start_pos[answer_text].append(ans_char_start_pos)
        votes[answer_text] += 1

    # if we have agreement (i.e. # of votes != 1)
    ans, n_vote = votes.most_common(1)[0]
    if n_vote != 1:
        return {
            "text": ans,
            "answer_start": start_pos[ans][0]
        }

    # if all answers got one vote each, select the shortest one
    min_len = 9999
    idx = -1
    for i, ans_dict in enumerate(answers):
        len_ = len(ans_dict["text"])
        if len_ < min_len:
            idx = i
            min_len = len_

    ret = {
        "text": answers[idx]["text"],
        "answer_start": answers[idx]["answer_start"]
    }
    return ret
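# Illustration (hypothetical values, not part of the script): a majority answer wins,
# and ties fall back to the shortest string:
#
#   select_answer([{"text": "vuonna 1917", "answer_start": 10},
#                  {"text": "1917", "answer_start": 17},
#                  {"text": "1917", "answer_start": 17}])
#   -> {"text": "1917", "answer_start": 17}          # majority vote
#
#   select_answer([{"text": "vuonna 1917", "answer_start": 10},
#                  {"text": "1917", "answer_start": 17}])
#   -> {"text": "1917", "answer_start": 17}          # no majority, shortest wins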
def load_squad_dataset(path, tokenizer):
    with open(path, "rb") as f:
        squad_dict = json.load(f)

    contexts, questions, answers = [], [], []
    for group in squad_dict["data"]:
        for passage in group["paragraphs"]:
            context = passage["context"]
            for qa in passage["qas"]:
                if qa["is_impossible"]:
                    continue
                question = qa["question"]
                answer = select_answer(qa["answers"])
                contexts.append(context)
                questions.append(question)
                answers.append(answer)

    train_dataset = SQuADDataset(contexts, questions, answers, tokenizer, src_max_length=512, tgt_max_length=40)
    return train_dataset
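# The loader assumes SQuAD v2.0-style JSON, roughly:
#
# {
#   "data": [
#     {"paragraphs": [
#       {"context": "...",
#        "qas": [
#          {"question": "...",
#           "is_impossible": false,
#           "answers": [{"text": "...", "answer_start": 0}]}
#        ]}
#     ]}
#   ]
# }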
def train(epoch, tokenizer, model, device, loader, optimizer):
    """
    Function to be called for training with the parameters passed from main function
    """
    model.train()
    for _, data in enumerate(loader, 0):
        y = data["target_ids"].to(device, dtype=torch.long)
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100

        ids = data["source_ids"].to(device, dtype=torch.long)
        mask = data["source_mask"].to(device, dtype=torch.long)

        outputs = model(
            input_ids=ids,
            attention_mask=mask,
            decoder_input_ids=y_ids,
            labels=lm_labels,
        )
        loss = outputs[0]

        if _ % 10000 == 0:
            training_logger.add_row(str(epoch), str(_), str(loss))
            console.print(training_logger)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
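# Illustration (hypothetical token ids, not part of the script): the target sequence
# is shifted so the decoder sees tokens up to position t while predicting position t+1,
# and padding positions in the labels are set to -100 so the cross-entropy loss
# ignores them:
#
#   y         = [259, 1154, 287, 1, 0, 0]     # question ids + </s> + padding
#   y_ids     = [259, 1154, 287, 1, 0]        # decoder input  (y[:, :-1])
#   lm_labels = [1154, 287, 1, -100, -100]    # targets        (y[:, 1:], pad -> -100)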
def validate(epoch, tokenizer, model, device, loader):
    """
    Function to evaluate model for predictions
    """
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for _, data in enumerate(loader, 0):
            y = data['target_ids'].to(device, dtype=torch.long)
            ids = data['source_ids'].to(device, dtype=torch.long)
            mask = data['source_mask'].to(device, dtype=torch.long)

            generated_ids = model.generate(
                input_ids=ids,
                attention_mask=mask,
                min_length=5,
                max_length=150,
                num_beams=2,
                repetition_penalty=2.5,
                length_penalty=1.0,
                early_stopping=True
            )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]

            if _ % 500 == 0:
                console.print(f'Completed {_}')

            predictions.extend(preds)
            actuals.extend(target)
    return predictions, actuals
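# Illustration (hypothetical values, not part of the script): the two returned lists
# are aligned, so a quick sanity metric such as exact match can be computed directly:
#
#   predictions = ["Mikä on Suomen pääkaupunki?", "Milloin Suomi itsenäistyi?"]
#   actuals     = ["Mikä on Suomen pääkaupunki?", "Minä vuonna Suomi itsenäistyi?"]
#   exact_match = sum(p == a for p, a in zip(predictions, actuals)) / len(actuals)  # 0.5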
def T5Trainer(exp_name, model_params, output_dir="./outputs/"):
    # Set random seeds and deterministic pytorch for reproducibility
    torch.manual_seed(model_params["SEED"])  # pytorch random seed
    np.random.seed(model_params["SEED"])  # numpy random seed
    torch.backends.cudnn.deterministic = True

    # logging
    console.log(f"""[Model]: Loading {model_params["MODEL"]}...\n""")

    tokenizer = MT5Tokenizer.from_pretrained(model_params["MODEL"])
    tokenizer.add_special_tokens({'additional_special_tokens': ['[HL]']})

    model = MT5ForConditionalGeneration.from_pretrained(model_params["MODEL"])
    model.resize_token_embeddings(len(tokenizer))
    model = model.to(device)

    # logging
    console.log(f"[Data]: Reading data...\n")

    training_set = load_squad_dataset("../../../datasets/qg_train_split-64604.json", tokenizer)
    val_set = load_squad_dataset("../../../datasets/qg_dev_split-4902.json", tokenizer)

    # Defining the parameters for creation of dataloaders
    train_params = {
        "batch_size": model_params["TRAIN_BATCH_SIZE"],
        "shuffle": True,
        "num_workers": 0,
    }
    val_params = {
        "batch_size": model_params["VALID_BATCH_SIZE"],
        "shuffle": False,
        "num_workers": 0,
    }
    # Creation of DataLoaders for training and validation; used in the training and validation stages below.
    training_loader = DataLoader(training_set, **train_params)
    val_loader = DataLoader(val_set, **val_params)

    # Defining the optimizer that will be used to tune the weights of the network in the training session.
    optimizer = torch.optim.Adam(params=model.parameters(), lr=model_params["LEARNING_RATE"])

    # Training loop
    console.log(f"[Initiating Fine Tuning]...\n")
    for epoch in range(model_params["TRAIN_EPOCHS"]):
        train(epoch, tokenizer, model, device, training_loader, optimizer)
        # Save and validate after the first epoch and then after every 4th epoch
        if (epoch + 1) % 4 == 0 or epoch == 0:
            # Saving the model
            console.log(f"[Saving Model after epoch {epoch+1}]...\n")
            path = os.path.join(output_dir, f"epoch_{epoch+1}")
            model.save_pretrained(path)
            tokenizer.save_pretrained(path)

            # evaluating the validation set
            console.log(f"[Initiating Validation after epoch {epoch+1}]...\n")
            for val_epoch in range(model_params["VAL_EPOCHS"]):
                predictions, actuals = validate(val_epoch, tokenizer, model, device, val_loader)
                final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
                final_df.to_csv(os.path.join(output_dir, f"Epoch_{epoch+1}_predictions.csv"))

            console.save_text(os.path.join(output_dir, f"{exp_name}_logs.txt"))
            console.log(f"[Validation Completed.]\n")

    console.print(f"""[Logs] Logs saved @ {os.path.join(output_dir, f"{exp_name}_logs.txt")}\n""")
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Placeholder')
    parser.add_argument("--exp", required=True,
                        help="Name for experiment. Used in saving model as the output dir name.")

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    T5Trainer(exp_name=args.exp, model_params=model_params, output_dir=f'<PATH_HERE>/{args.exp}')
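A minimal sketch of how the script could be run and how a saved checkpoint might then be used for generation. The script filename, the run name, the chosen epoch directory, and the `<PATH_HERE>` placeholder are assumptions based on the code above, not fixed by it:

    # Fine-tune (after pointing the dataset paths and <PATH_HERE> at your own files):
    #   python finetune_mt5_qg.py --exp finnish-qg-run1
    from transformers import MT5Tokenizer, MT5ForConditionalGeneration

    ckpt = "<PATH_HERE>/finnish-qg-run1/epoch_4"  # hypothetical output directory
    tokenizer = MT5Tokenizer.from_pretrained(ckpt)
    model = MT5ForConditionalGeneration.from_pretrained(ckpt)

    # Same input format as training: task prefix + passage with the answer in [HL] markers
    text = "Luo kysymys: [HL]Helsinki[HL] on Suomen pääkaupunki."
    inputs = tokenizer(text, return_tensors="pt")
    output_ids = model.generate(**inputs, max_length=50, num_beams=2, early_stopping=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))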
I am trying to fine-tune the mt5-base checkpoint for a summarization task, but the script doesn't run successfully. My script works well for mt5-small but not for the mt5-base checkpoint. I've tried a single A100 GPU with 40 GB of VRAM. Could you let me know which GPU(s), and how many, you used for mt5-base fine-tuning? Thank you.
Script: https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/chapter7/section5_tf.ipynb