Last active
April 5, 2022 00:36
-
-
Save marksverdhei/0a13f67e65460b71c05fcf558a6a91ae to your computer and use it in GitHub Desktop.
Fine-tune T5 on word definition, using transformers Trainer API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Dict | |
import numpy as np | |
import pandas as pd | |
from datasets import load_dataset | |
from torch.utils.data import Dataset | |
from transformers import ( | |
set_seed, | |
Seq2SeqTrainer, | |
Seq2SeqTrainingArguments, | |
T5ForConditionalGeneration, | |
T5Tokenizer, | |
) | |
class TrainerDataset(Dataset):
    """
    Torch dataset for trainer.
    Needs to implement __getitem__ and __len__

    Tokenizes the "Prompt" column as model inputs and the "Definition"
    column as labels, padding both to the longest example up front.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: T5Tokenizer) -> None:
        self.len = len(df)
        prompts = list(df["Prompt"])
        # BatchEncoding with input_ids / attention_mask, padded + truncated.
        self.inputs = tokenizer(
            prompts, padding=True, truncation=True, return_tensors="pt"
        )
        definition = list(df["Definition"])
        label_ids = tokenizer(
            definition, padding=True, truncation=True, return_tensors="pt"
        ).input_ids
        # Fix: mask padding positions with -100 so CrossEntropyLoss ignores
        # them; otherwise the model is also trained to emit pad tokens.
        label_ids[label_ids == tokenizer.pad_token_id] = -100
        self.label_ids = label_ids

    def __getitem__(self, index: int) -> Dict[str, Any]:
        "Trainer expects sample inputs with the following keys"
        return {
            "input_ids": self.inputs["input_ids"][index],
            "attention_mask": self.inputs["attention_mask"][index],
            "label_ids": self.label_ids[index],
        }

    def __len__(self) -> int:
        return self.len
def make_prompt(row: pd.Series) -> str:
    """Format one dataset row as the T5 definition-task prompt."""
    word = row["Word"]
    example = row["Example"]
    return f'define "{word}": {example}'
def main() -> None:
    """Fine-tune t5-base on the WordNet definitions dataset via Seq2SeqTrainer."""
    set_seed(42)

    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    dataset = load_dataset("marksverdhei/wordnet-definitions-en-2021")
    df_train = dataset["train"].to_pandas()
    df_val = dataset["validation"].to_pandas()

    # Attach the task prompt T5 is conditioned on to each split.
    df_train["Prompt"] = df_train.apply(make_prompt, axis=1)
    df_val["Prompt"] = df_val.apply(make_prompt, axis=1)

    train_set = TrainerDataset(df_train, tokenizer)
    val_set = TrainerDataset(df_val, tokenizer)

    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    training_args = Seq2SeqTrainingArguments(
        output_dir="./results/t5-base-define",
        num_train_epochs=20,
        per_device_eval_batch_size=32,
        per_device_train_batch_size=4,
        warmup_steps=500,
        learning_rate=1e-4,
        weight_decay=0.01,
        logging_dir="./logs/t5-base-define-logs",
        logging_steps=250,
        save_strategy="epoch",
        evaluation_strategy="steps",
        eval_steps=500,
        save_total_limit=3,
    )
    trainer = Seq2SeqTrainer(
        model,
        training_args,
        train_dataset=train_set,
        eval_dataset=val_set,
        tokenizer=tokenizer,
    )

    # Evaluate once before training to get a pre-fine-tuning baseline.
    trainer.evaluate()
    trainer.train()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment