FakeOnion train script
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TrainingArguments, Trainer
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

MODEL_SOURCE = 'distilgpt2'

# Add a padding token to the tokenizer for training (reuses the eos token)
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_SOURCE)
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

model = GPT2LMHeadModel.from_pretrained(MODEL_SOURCE)
model = model.to(device)
# Define a text dataset
class GptTextDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=256):
        """
        A subclass of torch.utils.data.Dataset for text generation on raw text.

        data: sequence of strings
        tokenizer: your GPT2Tokenizer
        max_length: int, token length to pad to. Default 256.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _prep_string(self, s):
        """
        Adds the eos token and encodes the text, padded or truncated to max_length.
        """
        plus_eos = s + self.tokenizer.eos_token
        encoded = self.tokenizer.encode(plus_eos, return_tensors='pt',
                                        padding='max_length',
                                        max_length=self.max_length,
                                        truncation=True)
        # encode() returns a (1, max_length) tensor; drop the batch dimension so
        # the Trainer's default collator can stack items into (batch, max_length).
        return encoded[0]

    def _format(self, input_ids):
        """
        Formats the encoded text into a dict the training loop will accept.
        'labels' are the same as input_ids because they're required to compute the loss.
        """
        return {'input_ids': input_ids, 'labels': input_ids}

    def __getitem__(self, idx):
        text = self.data[idx]
        if text[-1] == '\n':
            text = text[:-1]
        input_ids = self._prep_string(text)
        return self._format(input_ids)
# Create a dataset with a *random* train/eval split.
# Save the split if training across multiple sessions, or metrics may be invalid.
with open('all_onion_headlines.txt', 'r', encoding='utf-8', newline='\n') as f:
    data = f.readlines()

train, eval_ = train_test_split(data, train_size=0.80)
train = GptTextDataset(train, tokenizer)
eval_ = GptTextDataset(eval_, tokenizer)

args = TrainingArguments(
    output_dir='trained_models_hf',
    overwrite_output_dir=False,
    num_train_epochs=3,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=8,
    learning_rate=1e-3,
    do_train=True,
    do_eval=True,
    do_predict=True,
    # Evaluate on the held-out set each epoch so eval_dataset is actually used.
    evaluation_strategy='epoch',
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train,
    eval_dataset=eval_,
)

torch.cuda.empty_cache()
trainer.train()
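
Not part of the original gist: a minimal sketch of sampling headlines from the fine-tuned model once training finishes. The prompt (a bare eos token) and the sampling parameters (max_length, top_k, top_p, num_return_sequences) are illustrative assumptions, not values from the gist.

model.eval()
# Seed generation with just the eos token so the model writes a headline from scratch.
prompt = torch.tensor([[tokenizer.eos_token_id]]).to(device)
outputs = model.generate(
    prompt,
    max_length=40,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    num_return_sequences=5,
    pad_token_id=tokenizer.eos_token_id,
)
for out in outputs:
    print(tokenizer.decode(out, skip_special_tokens=True))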