@ilmarikyl
Last active September 1, 2022 18:52
Script for fine-tuning a multilingual T5 model (mT5-base) for Finnish question generation (QG)
import torch, os, json, argparse, sys
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
from torch.utils.data import Dataset, DataLoader
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
from rich.table import Column, Table
from rich import box
from rich.console import Console

console = Console(record=True)
model_params = {
    "MODEL": "google/mt5-base",     # model checkpoint to fine-tune
    "TRAIN_BATCH_SIZE": 8,          # training batch size
    "VALID_BATCH_SIZE": 8,          # validation batch size
    "TRAIN_EPOCHS": 32,             # number of training epochs
    "VAL_EPOCHS": 1,                # number of validation epochs
    "LEARNING_RATE": 1e-4,          # learning rate
    "MAX_SOURCE_TEXT_LENGTH": 512,  # max length of source text
    "MAX_TARGET_TEXT_LENGTH": 50,   # max length of target text
    "SEED": 42,                     # random seed for reproducibility
}
# Training logger to log training progress
training_logger = Table(
    Column("Epoch", justify="center"),
    Column("Steps", justify="center"),
    Column("Loss", justify="center"),
    title="Training Status",
    pad_edge=False,
    box=box.ASCII,
)

# Setting up the device for GPU usage
from torch import cuda
device = "cuda" if cuda.is_available() else "cpu"
class SQuADDataset(Dataset):
    """Tokenizes (passage, question, answer) triples into source/target encodings."""

    def __init__(self, passage_list, question_list, answer_list, tokenizer, src_max_length, tgt_max_length):
        self.source_input_ids = []
        self.source_masks = []
        self.target_input_ids = []
        self.target_masks = []
        for passage, question, answer in zip(passage_list, question_list, answer_list):
            answer_text = answer["text"]
            answer_start = answer["answer_start"]
            # Wrap the answer span in [HL] highlight tokens and prepend the
            # Finnish task prefix "Luo kysymys:" ("Create a question:")
            source_text = f"Luo kysymys: {passage[:answer_start]}[HL]{answer_text}[HL]{passage[answer_start + len(answer_text):]}"
            source_text_encodings_dict = tokenizer(source_text, truncation=True, max_length=src_max_length, padding="max_length")
            target_text = question
            target_text_encodings_dict = tokenizer(target_text, truncation=True, max_length=tgt_max_length, padding="max_length")
            source_ids = source_text_encodings_dict["input_ids"]
            source_mask = source_text_encodings_dict["attention_mask"]
            target_ids = target_text_encodings_dict["input_ids"]
            target_mask = target_text_encodings_dict["attention_mask"]
            self.source_input_ids.append(torch.tensor(source_ids))
            self.source_masks.append(torch.tensor(source_mask))
            self.target_input_ids.append(torch.tensor(target_ids))
            self.target_masks.append(torch.tensor(target_mask))

    def __len__(self):
        return len(self.source_input_ids)

    def __getitem__(self, idx):
        return {
            "source_ids": self.source_input_ids[idx],
            "source_mask": self.source_masks[idx],
            "target_ids": self.target_input_ids[idx],
            "target_ids_y": self.target_input_ids[idx],
        }
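
# Example item (hypothetical values) as returned by __getitem__:
#   {"source_ids":   token ids of "Luo kysymys: ... [HL]<answer>[HL] ...",
#    "source_mask":  attention mask for the padded source,
#    "target_ids":   token ids of the target question,
#    "target_ids_y": duplicate of target_ids, kept for the training loop's key names}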
def select_answer(answers):
    """
    Select one answer per question using the following rules:
    1. majority vote;
    2. if every answer has one vote, the shortest one.
    """
    if len(answers) == 1:
        return answers[0]
    # Vote for the most popular answer text
    start_pos: dict = defaultdict(list)
    votes: Counter = Counter()
    for ans_dict in answers:
        answer_text = ans_dict["text"]
        ans_char_start_pos = ans_dict["answer_start"]
        start_pos[answer_text].append(ans_char_start_pos)
        votes[answer_text] += 1
    # If we have agreement (i.e. the top answer has more than one vote)
    ans, n_vote = votes.most_common(1)[0]
    if n_vote != 1:
        return {
            "text": ans,
            "answer_start": start_pos[ans][0]
        }
    # If votes are tied, select the shortest answer
    min_len = float("inf")
    idx = -1
    for i, ans_dict in enumerate(answers):
        len_ = len(ans_dict["text"])
        if len_ < min_len:
            idx = i
            min_len = len_
    return {
        "text": answers[idx]["text"],
        "answer_start": answers[idx]["answer_start"]
    }
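
# Sketch on made-up annotations: two identical spans outvote a longer variant,
#   select_answer([{"text": "Helsingissä", "answer_start": 7},
#                  {"text": "Helsingissä", "answer_start": 7},
#                  {"text": "Suomen pääkaupungissa", "answer_start": 0}])
#   -> {"text": "Helsingissä", "answer_start": 7}
# With one vote each, the shortest answer text would be returned instead.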
def load_squad_dataset(path, tokenizer):
    with open(path, "r", encoding="utf-8") as f:
        squad_dict = json.load(f)
    contexts, questions, answers = [], [], []
    for group in squad_dict["data"]:
        for passage in group["paragraphs"]:
            context = passage["context"]
            for qa in passage["qas"]:
                # Skip SQuAD v2.0-style unanswerable questions
                if qa["is_impossible"]:
                    continue
                question = qa["question"]
                answer = select_answer(qa["answers"])
                contexts.append(context)
                questions.append(question)
                answers.append(answer)
    # Note: these lengths are hardcoded; MAX_SOURCE/TARGET_TEXT_LENGTH in
    # model_params are not used here (and tgt_max_length differs: 40 vs. 50)
    dataset = SQuADDataset(contexts, questions, answers, tokenizer, src_max_length=512, tgt_max_length=40)
    return dataset
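
# Expected input shape (sketch of the standard SQuAD v2.0 JSON layout):
# {"data": [{"paragraphs": [{"context": "...",
#     "qas": [{"question": "...", "is_impossible": false,
#              "answers": [{"text": "...", "answer_start": 0}]}]}]}]}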
def train(epoch, tokenizer, model, device, loader, optimizer):
    """
    Run one training epoch with the parameters passed from the main function.
    """
    model.train()
    for step, data in enumerate(loader):
        y = data["target_ids"].to(device, dtype=torch.long)
        # Teacher forcing: decoder inputs are the target shifted right; labels are
        # the target shifted left, with pad positions set to -100 so the loss ignores them
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
        ids = data["source_ids"].to(device, dtype=torch.long)
        mask = data["source_mask"].to(device, dtype=torch.long)
        outputs = model(
            input_ids=ids,
            attention_mask=mask,
            decoder_input_ids=y_ids,
            labels=lm_labels,
        )
        loss = outputs[0]
        if step % 10000 == 0:
            training_logger.add_row(str(epoch), str(step), str(loss.item()))
            console.print(training_logger)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def validate(epoch, tokenizer, model, device, loader):
    """
    Generate predictions for the validation set and return them with the references.
    """
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for step, data in enumerate(loader):
            y = data["target_ids"].to(device, dtype=torch.long)
            ids = data["source_ids"].to(device, dtype=torch.long)
            mask = data["source_mask"].to(device, dtype=torch.long)
            generated_ids = model.generate(
                input_ids=ids,
                attention_mask=mask,
                min_length=5,
                max_length=150,
                num_beams=2,
                repetition_penalty=2.5,
                length_penalty=1.0,
                early_stopping=True,
            )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
            if step % 500 == 0:
                console.print(f"Completed {step}")
            predictions.extend(preds)
            actuals.extend(target)
    return predictions, actuals
def T5Trainer(exp_name, model_params, output_dir="./outputs/"):
    # Set random seeds and deterministic pytorch for reproducibility
    torch.manual_seed(model_params["SEED"])  # pytorch random seed
    np.random.seed(model_params["SEED"])  # numpy random seed
    torch.backends.cudnn.deterministic = True

    # Logging
    console.log(f"""[Model]: Loading {model_params["MODEL"]}...\n""")

    tokenizer = MT5Tokenizer.from_pretrained(model_params["MODEL"])
    # Register the [HL] highlight marker and resize the embeddings to match
    tokenizer.add_special_tokens({"additional_special_tokens": ["[HL]"]})
    model = MT5ForConditionalGeneration.from_pretrained(model_params["MODEL"])
    model.resize_token_embeddings(len(tokenizer))
    model = model.to(device)

    # Logging
    console.log("[Data]: Reading data...\n")
    training_set = load_squad_dataset("../../../datasets/qg_train_split-64604.json", tokenizer)
    val_set = load_squad_dataset("../../../datasets/qg_dev_split-4902.json", tokenizer)

    # Parameters for the dataloaders
    train_params = {
        "batch_size": model_params["TRAIN_BATCH_SIZE"],
        "shuffle": True,
        "num_workers": 0,
    }
    val_params = {
        "batch_size": model_params["VALID_BATCH_SIZE"],
        "shuffle": False,
        "num_workers": 0,
    }

    # Create the DataLoaders used in the training and validation stages
    training_loader = DataLoader(training_set, **train_params)
    val_loader = DataLoader(val_set, **val_params)

    # Optimizer used to tune the weights of the network during training
    optimizer = torch.optim.Adam(params=model.parameters(), lr=model_params["LEARNING_RATE"])

    # Ensure the output directory exists before writing CSVs and logs
    os.makedirs(output_dir, exist_ok=True)

    # Training loop
    console.log("[Initiating Fine Tuning]...\n")
    for epoch in range(model_params["TRAIN_EPOCHS"]):
        train(epoch, tokenizer, model, device, training_loader, optimizer)
        # Save and validate after the first epoch and then every fourth epoch
        if (epoch + 1) % 4 == 0 or epoch == 0:
            console.log(f"[Saving Model after epoch {epoch+1}]...\n")
            path = os.path.join(output_dir, f"epoch_{epoch+1}")
            model.save_pretrained(path)
            tokenizer.save_pretrained(path)

            # Evaluate on the validation set
            console.log(f"[Initiating Validation after epoch {epoch+1}]...\n")
            for val_epoch in range(model_params["VAL_EPOCHS"]):
                predictions, actuals = validate(val_epoch, tokenizer, model, device, val_loader)
                final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
                final_df.to_csv(os.path.join(output_dir, f"Epoch_{epoch+1}_predictions.csv"))
            console.save_text(os.path.join(output_dir, f"{exp_name}_logs.txt"))
            console.log("[Validation Completed.]\n")
            console.print(f"""[Logs] Logs saved @ {os.path.join(output_dir, f"{exp_name}_logs.txt")}\n""")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune mT5 for Finnish question generation")
    parser.add_argument("--exp", required=True,
                        help="Name for the experiment; used as the output directory name when saving the model.")
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()
    T5Trainer(exp_name=args.exp, model_params=model_params, output_dir=f"<PATH_HERE>/{args.exp}")
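
# Usage sketch (the script filename here is illustrative, not from the gist):
#   python finetune_mt5_qg.py --exp mt5_base_fi_qg
# Checkpoints are written to <PATH_HERE>/mt5_base_fi_qg/epoch_N, with per-epoch
# prediction CSVs and a log file alongside them.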

pedramyamini commented Sep 1, 2022

I am trying to fine-tune the mt5-base checkpoint for a summarization task, but the script doesn't run successfully. It works well for mt5-small but not for mt5-base. I've tried a single A100 GPU with 40 GB of VRAM. Let me know which GPU(s), and how many, you used for mt5-base fine-tuning. Thank you.

Script: https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/chapter7/section5_tf.ipynb
