
@HandcartCactus
Created December 4, 2021 00:39
FakeOnion train script

from transformers import GPT2Tokenizer, GPT2LMHeadModel, TrainingArguments, Trainer
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

# use the GPU if one is available
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)
MODEL_SOURCE = 'distilgpt2'
# Add a padding token to the tokenizer for training
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_SOURCE)
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
model = GPT2LMHeadModel.from_pretrained(MODEL_SOURCE)
model = model.to(device)
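
# GPT-2 ships without a padding token; reusing eos (above) avoids growing the
# embedding matrix. A quick sanity check (illustrative, not in the original script):
assert tokenizer.pad_token_id == tokenizer.eos_token_id
ids = tokenizer.encode('hello world', padding='max_length', max_length=8)
print(ids)  # trailing entries are tokenizer.eos_token_id (50256 for GPT-2)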

# define a text dataset
class GptTextDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=256):
        """
        A subclass of torch.utils.data.Dataset for text generation on raw text.
        data: sequence of strings
        tokenizer: your GPT2Tokenizer
        max_length: int, token length to pad to. Default 256.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _prep_string(self, s):
        """
        Adds the eos token and encodes the text, padding up to max_length.
        """
        plus_eos = s + self.tokenizer.eos_token
        return self.tokenizer.encode(plus_eos, return_tensors='pt', padding='max_length',
                                     max_length=self.max_length, truncation=True)

    def _format(self, input_ids):
        """
        Formats the encoded text into a dict the training loop will accept.
        'labels' are the same as input_ids because they are required for loss.
        """
        return {'input_ids': input_ids, 'labels': input_ids}

    def __getitem__(self, idx):
        # strip the trailing newline left over from readlines()
        text = self.data[idx].rstrip('\n')
        input_ids = self._prep_string(text)
        return self._format(input_ids)
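
# Quick shape check (illustrative only, assumes `tokenizer` from above): each
# item is a dict of tensors shaped (1, max_length). The Trainer's default
# collator stacks these into (batch, 1, max_length), and GPT-2's forward pass
# reshapes input_ids to 2-D internally, so training works despite the extra dim.
_demo = GptTextDataset(['Area Man Reads Example Headline'], tokenizer)
print(_demo[0]['input_ids'].shape)  # torch.Size([1, 256])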

# create a dataset with a *random* train/eval split.
# Save the split if training across multiple sessions, or eval metrics may be invalid.
with open('all_onion_headlines.txt', 'r', encoding='utf-8', newline='\n') as f:
    data = f.readlines()

train, eval_ = train_test_split(data, train_size=0.80)
train = GptTextDataset(train, tokenizer)
eval_ = GptTextDataset(eval_, tokenizer)
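
# To keep eval metrics valid across sessions, the split above must be
# reproducible: either pass random_state=42 to train_test_split, or persist
# the raw split and reload it next session instead of re-splitting. A minimal
# sketch of the latter (the filename is arbitrary, not from the original):
import json
with open('split.json', 'w', encoding='utf-8') as f:
    json.dump({'train': train.data, 'eval': eval_.data}, f)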

args = TrainingArguments(
    output_dir='trained_models_hf',
    overwrite_output_dir=False,
    num_train_epochs=3,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=8,
    learning_rate=1e-3,
    do_train=True,
    do_eval=True,
    do_predict=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train,
    eval_dataset=eval_,
)
torch.cuda.empty_cache()
trainer.train()
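
# After training: evaluate on the held-out split, save the model, and sample a
# few fake headlines (a sketch; the save path and generation settings are
# illustrative, not from the original script).
print(trainer.evaluate())
trainer.save_model('trained_models_hf/final')

model.eval()
prompt_ids = tokenizer.encode(tokenizer.eos_token, return_tensors='pt').to(device)
outputs = model.generate(prompt_ids, max_length=40, do_sample=True, top_k=50,
                         num_return_sequences=3, pad_token_id=tokenizer.eos_token_id)
for out in outputs:
    print(tokenizer.decode(out, skip_special_tokens=True))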