Training script for custom MLM with 🤗 Transformers on a single GPU
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
'''
python 3.8.10
pytorch 1.9.0
transformers 4.10.0.dev0
'''
import torch
import transformers

# Custom 'titles.txt' data set with one title per line.
from datasets import load_dataset
raw_datasets = load_dataset('text', data_files={'train': ['./titles.txt']})

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
# Drop the raw text and the token type ids; only input ids and attention
# masks are passed to the model (labels are added during masking).
tokenized_datasets = tokenized_datasets.remove_columns(["text", "token_type_ids"])
tokenized_datasets.set_format("torch")

# Batch size 16 takes 19-20 GB on GPU.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=16)
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")

from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)

from transformers import get_scheduler
num_epochs = 2
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=num_training_steps,
)

device = "cuda"
model.to(device)

from tqdm.auto import tqdm
progress_bar = tqdm(range(num_training_steps))
def apply_masking(batch, tokenizer):
    # Mask 15% of the tokens at random (in place). Note that this simple
    # scheme can also mask special tokens such as [CLS] and [SEP].
    batch['labels'] = batch['input_ids'].detach().clone()
    mask_idx = torch.where(torch.rand(batch['labels'].shape) < .15)
    batch['input_ids'][mask_idx] = tokenizer.mask_token_id
    # Undo masking on padding positions (the attention mask is 0 there).
    batch['input_ids'][batch['attention_mask'] == 0] = tokenizer.pad_token_id
model.train()
# Note: the loss is computed on all the tokens, not just the masked ones.
for epoch in range(num_epochs):
    for batch in train_dataloader:
        apply_masking(batch, tokenizer)
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

model.save_pretrained("./")
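After training, the weights and config are written to the current directory. A minimal sketch of loading them back for masked-token prediction follows; the example sentence is a placeholder, and the tokenizer is reloaded from the original bert-base-cased checkpoint since the script above does not save it.

# Inference sketch: assumes the script above has finished and saved the
# model to the current directory ("./").
from transformers import pipeline

fill_mask = pipeline(
    "fill-mask",
    model="./",                   # directory written by model.save_pretrained("./")
    tokenizer="bert-base-cased",  # tokenizer was not saved, reload the original
)
# Placeholder sentence; the pipeline returns the top candidates for [MASK].
print(fill_mask("A study of [MASK] learning for protein structure."))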