@marksverdhei · Last active April 5, 2022
Fine-tune T5 on word definitions using the transformers Trainer API
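The script loads the marksverdhei/wordnet-definitions-en-2021 dataset from the Hugging Face Hub, turns each row into a prompt of the form define "<word>": <example sentence>, and fine-tunes t5-base to generate the matching definition. A small torch Dataset wrapper pre-tokenizes prompts and definitions so they can be fed directly to Seq2SeqTrainer.
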
from typing import Any, Dict

import pandas as pd
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import (
    set_seed,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
    T5Tokenizer,
)


class TrainerDataset(Dataset):
    """
    Torch dataset for the Trainer.
    Needs to implement __getitem__ and __len__.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: T5Tokenizer) -> None:
        self.len = len(df)
        # Pre-tokenize every prompt; pad/truncate so all rows share one shape
        prompts = list(df["Prompt"])
        self.inputs = tokenizer(
            prompts, padding=True, truncation=True, return_tensors="pt"
        )
        # The target definitions become the label token ids
        definitions = list(df["Definition"])
        self.label_ids = tokenizer(
            definitions, padding=True, truncation=True, return_tensors="pt"
        ).input_ids

    def __getitem__(self, index) -> Dict[str, Any]:
        "Trainer expects sample inputs with the following keys"
        # The Trainer's default data collator renames "label_ids" to
        # "labels" before the batch reaches the model's forward pass
        return {
            "input_ids": self.inputs["input_ids"][index],
            "attention_mask": self.inputs["attention_mask"][index],
            "label_ids": self.label_ids[index],
        }

    def __len__(self) -> int:
        return self.len


def make_prompt(row: pd.Series) -> str:
    return f"define \"{row['Word']}\": {row['Example']}"


def main() -> None:
    set_seed(42)
    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    dataset = load_dataset("marksverdhei/wordnet-definitions-en-2021")
    df_train = dataset["train"].to_pandas()
    df_val = dataset["validation"].to_pandas()

    # We generate input task prompts for T5
    for df in df_train, df_val:
        df["Prompt"] = df.apply(make_prompt, axis=1)

    train_set = TrainerDataset(df_train, tokenizer)
    val_set = TrainerDataset(df_val, tokenizer)

    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    training_args = Seq2SeqTrainingArguments(
        output_dir="./results/t5-base-define",
        num_train_epochs=20,
        per_device_eval_batch_size=32,
        per_device_train_batch_size=4,
        warmup_steps=500,
        learning_rate=1e-4,
        weight_decay=0.01,
        logging_dir="./logs/t5-base-define-logs",
        logging_steps=250,
        save_strategy="epoch",
        # Renamed to eval_strategy in newer transformers releases
        evaluation_strategy="steps",
        eval_steps=500,
        save_total_limit=3,
    )

    trainer = Seq2SeqTrainer(
        model,
        training_args,
        train_dataset=train_set,
        eval_dataset=val_set,
        tokenizer=tokenizer,
    )
    # Evaluate once before any training for a baseline, then fine-tune
    trainer.evaluate()
    trainer.train()


if __name__ == "__main__":
    main()
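

# Usage sketch: generating a definition with a fine-tuned model.
# The checkpoint path is hypothetical; substitute one of the epoch
# checkpoints the Trainer writes under ./results/t5-base-define.
#
#   model = T5ForConditionalGeneration.from_pretrained(
#       "./results/t5-base-define/checkpoint-1000"
#   )
#   tokenizer = T5Tokenizer.from_pretrained("t5-base")
#   inputs = tokenizer(
#       'define "bank": She sat on the bank of the river',
#       return_tensors="pt",
#   )
#   output_ids = model.generate(**inputs, max_new_tokens=64)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))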