Skip to content

Instantly share code, notes, and snippets.

@SinclairCoder
Last active March 30, 2022 14:56
Show Gist options
  • Save SinclairCoder/677637621989983ed3cbb19e88dd0487 to your computer and use it in GitHub Desktop.
Save SinclairCoder/677637621989983ed3cbb19e88dd0487 to your computer and use it in GitHub Desktop.
Uncomment the alternate `MyDataset(...)` line inside `train_dataloader` to test a different `max_input_len`.
from pytorch_lightning import seed_everything
from transformers import AdamW, T5ForConditionalGeneration, T5Tokenizer, AutoConfig, BartTokenizer, BartForConditionalGeneration
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import os
class MyDataset(Dataset):
    """Text-to-text dataset that eagerly pre-tokenizes input/target pairs.

    Every example is tokenized once in ``__init__`` (padded and truncated to a
    fixed length, returned as PyTorch tensors), so ``__getitem__`` only has to
    index and squeeze the cached tensors.

    Args:
        tokenizer: HF-style callable; called as
            ``tokenizer([text], max_length=..., padding="max_length",
            truncation=True, return_tensors="pt")`` and expected to return a
            mapping with ``input_ids`` and ``attention_mask`` tensors.
        raw_inputs: source strings.
        raw_targets: target strings, paired positionally with ``raw_inputs``
            (assumed to be the same length).
        max_input_len: padded/truncated length of each tokenized input.
        max_output_len: padded/truncated length of each tokenized target.
    """

    def __init__(self, tokenizer, raw_inputs, raw_targets, max_input_len=128, max_output_len=128):
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        self.tokenizer = tokenizer
        self.inputs = []
        self.targets = []
        self._build_examples(raw_inputs, raw_targets)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        # squeeze() drops the leading batch dimension of 1 that the tokenizer
        # adds because we tokenize one-element lists.
        source_ids = self.inputs[index]["input_ids"].squeeze()
        target_ids = self.targets[index]["input_ids"].squeeze()
        src_mask = self.inputs[index]["attention_mask"].squeeze()
        target_mask = self.targets[index]["attention_mask"].squeeze()
        return {"source_ids": source_ids, "source_mask": src_mask,
                "target_ids": target_ids, "target_mask": target_mask}

    def _build_examples(self, raw_inputs, raw_targets):
        """Tokenize each (input, target) pair and cache the results."""
        # zip pairs the two lists directly, avoiding the original
        # range(len(...)) indexing and its shadowing of the `input` builtin.
        # NOTE: zip stops at the shorter list; inputs/targets are assumed
        # to have equal length.
        for src_text, tgt_text in zip(raw_inputs, raw_targets):
            tokenized_input = self.tokenizer(
                [src_text], max_length=self.max_input_len, padding="max_length",
                truncation=True, return_tensors="pt"
            )
            tokenized_target = self.tokenizer(
                [tgt_text], max_length=self.max_output_len, padding="max_length",
                truncation=True, return_tensors="pt"
            )
            self.inputs.append(tokenized_input)
            self.targets.append(tokenized_target)
# Toy dataset: a single (review sentence -> aspect-sentiment quadruple string)
# pair; the commented entries are extra examples that can be enabled for a
# larger smoke test. Inputs and targets are paired positionally.
raw_inputs = ["can't wait wait for my next visit.",
              # "their sake list was extensive, but we were looking for purple haze, which wasn't listed but made for us upon request!", #
              # "the spicy tuna roll was unusually good and the rock shrimp tempura was awesome, great appetizer to share!",
              # "we love th pink pony."
              ]
raw_targets = ['restaurant general is great because it is NULL',
               # "drinks style options is great because sake list is extensive [SSEP] service general is great because it is NULL",
               # "food quality is great because spicy tuna roll is good [SSEP] food quality is great because rock shrimp tempura is awesome",
               # "restaurant general is great because pink pony is love"
               ]
class MyT5FineTuner(pl.LightningModule):
    """
    Fine tune a pre-trained T5 model with PyTorch Lightning.

    Wraps a HuggingFace seq2seq model; the training loss is the model's own
    LM loss computed from `labels` (pad positions masked to -100).
    """

    def __init__(self, tfm_model, tokenizer):
        super().__init__()
        self.model = tfm_model
        self.tokenizer = tokenizer

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None,
                decoder_attention_mask=None, labels=None):
        # decoder_input_ids / decoder_attention_mask are accepted for API
        # compatibility but intentionally not forwarded; the HF model derives
        # decoder inputs from `labels` when they are omitted.
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

    def _step(self, batch):
        """Compute the LM loss for one batch."""
        # BUGFIX: clone before masking. The original assigned the batch tensor
        # by reference and wrote -100 into it in place, mutating
        # batch["target_ids"] for any later consumer of the same batch.
        lm_labels = batch["target_ids"].clone()
        # -100 is the ignore index of the LM cross-entropy loss, so padding
        # positions do not contribute to the loss.
        lm_labels[lm_labels == self.tokenizer.pad_token_id] = -100
        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            labels=lm_labels,
        )
        loss = outputs[0]
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)
        print(loss)  # debug output; self.log("train_loss", loss) is the idiomatic alternative
        return loss

    def configure_optimizers(self):
        """ Prepare optimizer and schedule (linear warmup and decay) """
        # NOTE(review): transformers.AdamW is deprecated; torch.optim.AdamW is
        # the drop-in replacement. Kept as-is to avoid changing behavior/imports.
        optimizer = AdamW(self.model.parameters(), lr=3e-4, eps=1e-8)
        return optimizer

    def train_dataloader(self):
        # Relies on module-level `tokenizer`, `raw_inputs`, `raw_targets`.
        train_dataset = MyDataset(tokenizer, raw_inputs, raw_targets,
                                  max_input_len=300, max_output_len=128)
        # train_dataset = MyDataset(tokenizer, raw_inputs, raw_targets, max_input_len=200, max_output_len=128)
        dataloader = DataLoader(train_dataset, batch_size=1,
                                drop_last=False, shuffle=False, num_workers=1)
        return dataloader
if __name__ == '__main__':
    # Fix all RNG seeds (python, numpy, torch) for reproducibility.
    seed_everything(42)
    # Downloads/loads the pretrained t5-base tokenizer and model weights.
    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    tfm_model = T5ForConditionalGeneration.from_pretrained('t5-base')
    model = MyT5FineTuner(tfm_model, tokenizer)
    # Default Trainer settings; data comes from model.train_dataloader().
    trainer = pl.Trainer()
    trainer.fit(model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment