Compare dynamic padding (variable sequence length per batch) vs. padding every batch to a fixed, very long sequence length
""" | |
A minimal script to compare inference with variable batch sizes vs a fixed batch size long enough to handle all cases. | |
Change `padding_style` to compare. | |
""" | |
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

data = load_dataset("glue", "sst2", split="validation")

# "max_length" pads every sequence to the model maximum (512);
# True pads each tokenization batch only to its longest sequence.
padding_style = "max_length"
# padding_style = True
def preprocess_function(examples):
    # Tokenize the texts
    result = tokenizer(examples["sentence"], padding=padding_style, truncation=True, return_tensors="pt")
    return result
device = "cuda:0" | |
total_samples = len(data) | |
model.eval() | |
model = model.to(device) | |
batch_size = 8 | |
n_batches = 10 | |
used_samples = batch_size * n_batches | |
assert used_samples <= total_samples | |
inp = data.shuffle().select(range(used_samples)).map(preprocess_function, remove_columns=["sentence", "idx", "label"], batched=True, batch_size=batch_size) | |
inp.set_format("torch") | |
# Measure how much of the input tensors is padding.
n_pads = 0
n_elems = 0
sequence_length = 0
for single_input in inp["input_ids"]:
    n_pads += (single_input == tokenizer.pad_token_id).sum().item()
    n_elems += single_input.numel()
    sequence_length += len(single_input)
sequence_length = sequence_length / used_samples

print(f"Padding percentage: {n_pads / n_elems * 100:.2f} %")
print(f"Average sequence length: {sequence_length}")
print(f"Batch size: {batch_size}")
# shuffle=False keeps the DataLoader batches aligned with the tokenization
# batches above, so every batch stacks tensors of a single padded length.
dataloader = DataLoader(inp, shuffle=False, batch_size=batch_size)
# Warm up before timing so CUDA context creation and kernel autotuning
# are not measured.
for batch in dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        _ = model(**batch)
torch.cuda.synchronize()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
for i in range(n_batches):
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            _ = model(**batch)
end_event.record()
torch.cuda.synchronize()

# The inner loop runs len(dataloader) batches, repeated n_batches times.
total_time = start_event.elapsed_time(end_event) / (n_batches * len(dataloader))
print(f"Average time per batch: {total_time:.2f} ms")