Skip to content

Instantly share code, notes, and snippets.

@harpone
Last active December 3, 2020 13:22
Show Gist options
  • Save harpone/71faadbbcb9bbe20923de9306d4cf091 to your computer and use it in GitHub Desktop.
Save harpone/71faadbbcb9bbe20923de9306d4cf091 to your computer and use it in GitHub Desktop.
Testing/profiling webdataset data loading speed issue
from itertools import islice
from munch import Munch
import sys, os
from torch.utils.data import DataLoader
from torchvision import transforms
import time
import webdataset as wds
# Make modules that live in the current working directory importable.
sys.path.append(os.getcwd())

# Number of batches pulled from the loader during the timed run.
num_iters = 50

# Benchmark settings, bundled in a Munch so they can be read as attributes
# (e.g. ``args.batch_size``).
args = Munch(
    input_size=224,
    use_random_data=False,
    shuffle_buffer=1,
    batch_size=256,
    num_workers=8,
    tpu_cores=None,
)

# Shard pattern in brace notation, covering shards 000000 through 000554.
url = "http://storage.googleapis.com/nvdata-openimages/openimages-train-{000000..000554}.tar"
# Wrap the pattern in a "pipe:" command so each shard is fetched with curl;
# the trailing "|| true" makes the shell command exit 0 even when curl fails.
url = "pipe:curl -L -s " + url + " || true"
def identity(x):
    """Return ``x`` unchanged.

    Serves as the no-op transform passed to ``map_tuple`` for the tuple
    element that should not be preprocessed.
    """
    return x
# Per-channel mean/std normalization applied to the tensor produced by
# ToTensor (these are the values commonly quoted as ImageNet statistics).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Image preprocessing pipeline: random crop + flip on the decoded image,
# conversion to a tensor, then normalization.
preproc = transforms.Compose([
    # Take the crop size from the config instead of repeating the literal 224,
    # so changing ``args.input_size`` actually changes the pipeline.
    transforms.RandomResizedCrop(args.input_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Build the sample pipeline one stage at a time, starting from the shard URLs.
dataset = wds.Dataset(url)
# NOTE(review): ``args.shuffle_buffer`` is defined in the config, but the
# buffer size used here is the literal 100 -- confirm which one is intended.
dataset = dataset.shuffle(100)
# Decode image payloads into PIL images.
dataset = dataset.decode("pil")
# Keep an (image, metadata) pair per sample: the "jpg" or "png" entry and
# the "json" entry.
dataset = dataset.to_tuple("jpg;png", "json")
# Preprocess the image; pass the JSON metadata through untouched.
dataset = dataset.map_tuple(preproc, identity)
# Batching is done inside the dataset pipeline ...
dataset = dataset.batched(args.batch_size)

# ... so the DataLoader's own batching is switched off with batch_size=None.
dataloader = DataLoader(
    dataset,
    num_workers=args.num_workers,
    batch_size=None,
)
# Time how long it takes to pull ``num_iters`` batches from the loader.
# perf_counter() is monotonic and high-resolution, so the measured interval
# cannot be skewed by wall-clock adjustments the way time.time() can.
start_time = time.perf_counter()
for i, _batch in enumerate(islice(dataloader, num_iters)):
    # "\r" rewrites the same terminal line; flush explicitly because no
    # newline is emitted and stdout may otherwise hold the text back.
    print(f'\r{i}', end='', flush=True)
print(f'\nend time={time.perf_counter() - start_time}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment