"""
A minimal script to compare inference with batches padded dynamically to their own longest sequence
vs. batches padded to a fixed length long enough to handle all cases (the tokenizer's max_length).
Change `padding_style` below to switch between the two.
"""
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
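# SST-2 fine-tuned DistilBERT classifier and its matching tokenizer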
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
data = load_dataset("glue", "sst2", split="validation")
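# Two padding strategies can be compared:
#   - padding="max_length": every sequence is padded to the tokenizer's model max length
#     (512 tokens for this DistilBERT checkpoint), regardless of the actual text length;
#   - padding=True: dynamic padding, each batch is only padded to its own longest sequence.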
padding_style = "max_length"
#padding_style = True
def preprocess_function(examples):
    # Tokenize the texts, padding according to `padding_style`
    result = tokenizer(examples["sentence"], padding=padding_style, truncation=True, return_tensors="pt")
    return result
device = "cuda:0"
total_samples = len(data)
model.eval()
model = model.to(device)
batch_size = 8
n_batches = 10
used_samples = batch_size * n_batches
assert used_samples <= total_samples
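# Tokenize in chunks of `batch_size` examples; with dynamic padding each chunk is padded to its own
# longest sequence, so the unshuffled DataLoader below (same batch size) can stack every batch with
# the default collate function.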
inp = data.shuffle().select(range(used_samples)).map(preprocess_function, remove_columns=["sentence", "idx", "label"], batched=True, batch_size=batch_size)
inp.set_format("torch")
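# Padding statistics: fraction of pad tokens and average padded sequence length over the selected samples.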
n_pads = 0
n_elems = 0
sequence_length = 0
for single_input in inp["input_ids"]:
    n_pads += (single_input == tokenizer.pad_token_id).sum().item()
    n_elems += single_input.numel()
    sequence_length += len(single_input)
sequence_length = sequence_length / used_samples
print(f"Padding percentage: {n_pads / n_elems * 100:.2f} %")
print(f"Average sequence length: {sequence_length}")
print(f"Batch size: {batch_size}")
from torch.utils.data import DataLoader
dataloader = DataLoader(inp, shuffle=False, batch_size=batch_size)
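# Measure GPU time for `n_batches` full passes over the dataloader using CUDA events;
# `elapsed_time` is only valid once the recorded work has finished, hence the synchronize below.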
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(n_batches):
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            _ = model(**batch)
end_event.record()
torch.cuda.synchronize()
total_time = start_event.elapsed_time(end_event) / n_batches
print(f"Average time per pass over the dataloader ({n_batches} batches): {total_time:.2f} ms")