Skip to content

Instantly share code, notes, and snippets.

@Narsil
Created October 11, 2021 14:32
Show Gist options
  • Save Narsil/357519fd385d864bfec3caf5aa8df575 to your computer and use it in GitHub Desktop.
Save Narsil/357519fd385d864bfec3caf5aa8df575 to your computer and use it in GitHub Desktop.
from transformers import pipeline
from torch.utils.data import Dataset
import tqdm
# Build the text-classification pipeline once; every streaming run further
# down in this script reuses this same object.
# NOTE(review): no model is named here, so the library chooses one for the
# task, and device=0 presumably selects the first CUDA GPU — confirm both
# against the transformers `pipeline` documentation.
pipe = pipeline("text-classification", device=0)
class MyDataset(Dataset):
    """Synthetic dataset of 1000 strings with periodic long outliers.

    Items whose index is a multiple of 64 (0, 64, 128, ...) are the base
    sentence repeated 100 times; every other item is the base sentence once.
    NOTE(review): presumably the outliers exist to show how a few long
    inputs affect batched inference — confirm with the author.
    """

    def __len__(self) -> int:
        """Return the fixed number of items in the dataset (1000)."""
        return 1000

    def __getitem__(self, i: int) -> str:
        """Return the text for index ``i``.

        Args:
            i: Item index.

        Returns:
            ``"This is a test"`` repeated 100 times when ``i`` is a multiple
            of 64, otherwise the sentence once.
        """
        # Every 64th item is made 100x longer than the rest.
        n = 100 if i % 64 == 0 else 1
        return "This is a test" * n
dataset = MyDataset()

# Baseline run: stream the whole dataset through the pipeline without
# passing batch_size. The loop body is empty on purpose: iterating is what
# drives the pipeline, and tqdm's progress bar is the output of interest.
print("-" * 30)
print("Streaming no batching")
for _ in tqdm.tqdm(pipe(dataset), total=len(dataset)):
    pass

# Repeat the same pass at increasing batch sizes. total=len(dataset) is
# given explicitly so every progress bar is measured against the same
# number of items, whatever the batch size.
for batch_size in (8, 64, 256, 512, 1024):
    print("-" * 30)
    print(f"Streaming batch_size={batch_size}")
    for _ in tqdm.tqdm(pipe(dataset, batch_size=batch_size), total=len(dataset)):
        pass
~
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment