@MichelNivard
Created March 15, 2023 08:02
import torch
from torch.utils.data import random_split
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    TextDataset,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Set the path to the text file to fine-tune on
path_to_file = "path/to/text/file.txt"
# Load the tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
# Load the text dataset and collator. TextDataset tokenizes the file and
# slices it into contiguous blocks of block_size tokens.
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=path_to_file,
    block_size=128,
)

# mlm=False yields causal (next-token) language-modeling labels, as GPT-2 expects.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)
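
# NOTE: TextDataset is deprecated in recent transformers releases. Below is a
# minimal sketch of the modern replacement using the separate `datasets`
# library (an assumption: `datasets` is installed). Uncomment to swap it in
# for the TextDataset above.
#
# from datasets import load_dataset
#
# tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
# raw = load_dataset('text', data_files=path_to_file)['train']
# tokenized = raw.map(lambda batch: tokenizer(batch['text']),
#                     batched=True, remove_columns=['text'])
#
# def group_texts(batch, block_size=128):
#     # Concatenate all token ids, then slice into fixed-size blocks,
#     # mirroring what TextDataset does.
#     ids = sum(batch['input_ids'], [])
#     total = (len(ids) // block_size) * block_size
#     return {'input_ids': [ids[i:i + block_size]
#                           for i in range(0, total, block_size)]}
#
# dataset = tokenized.map(group_texts, batched=True,
#                         remove_columns=tokenized.column_names)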
# Hold out part of the blocks for evaluation. The original script set
# evaluation_strategy='steps' without passing an eval set, which makes the
# Trainer fail when evaluation is triggered; the 90/10 split is an assumption.
train_size = int(0.9 * len(dataset))
train_dataset, eval_dataset = random_split(dataset, [train_size, len(dataset) - train_size])

# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    logging_steps=5000,
    save_steps=10000,
    evaluation_strategy='steps',
    eval_steps=10000,
    save_total_limit=2,
    learning_rate=5e-5,
    warmup_steps=5000,
    fp16=torch.cuda.is_available(),  # mixed precision requires a CUDA GPU
)
# Define the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)
# Fine-tune the model
trainer.train()
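
# Persist the fine-tuned weights and run a quick generation sanity check.
# The output directory and prompt below are illustrative, not from the
# original gist.
trainer.save_model('./results/final')
tokenizer.save_pretrained('./results/final')

prompt = "Once upon a time"  # hypothetical prompt
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.9)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))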