

@ptrblck
Created August 10, 2019 23:28
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import time
#from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
class MyDataset(Dataset):
    def __init__(self):
        # set file paths
        pass

    def __getitem__(self, index):
        # load data
        x = torch.randn(3, 224, 224)
        y = torch.randint(0, 10, (1,))
        # simulate workflow
        for _ in range(10000):
            torch.sqrt(torch.randn(100))
        return x, y

    def __len__(self):
        return 1000

def create_dataloader():
    dataset = MyDataset()
    loader = DataLoader(
        dataset,
        batch_size=10,
        num_workers=4,
        pin_memory=True,
        shuffle=True
    )
    return loader
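
# num_workers=4 loads batches in four worker processes; pin_memory=True
# returns tensors in page-locked memory for faster host-to-GPU copies.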

#pool = ThreadPoolExecutor(5)

# Current workflow
# get first dataloader
training_data = create_dataloader()

# multiple epochs
for epoch in range(4):
    # prefetch next dataloader
    #next_data = pool.submit(create_dataloader, ...)
    t0 = time.time()
    for step, batch in enumerate(training_data):
        t1 = time.time()
        print('epoch {}, batch {}, data loading {}s'.format(
            epoch, step, (t1 - t0)))
        t0 = time.time()

    training_data = create_dataloader()
    print('creating loader in {}'.format(time.time() - t0))
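
# Note: iterating a freshly created DataLoader first spawns the
# num_workers processes, so the first batch of every epoch above is
# typically much slower than the rest; the manual iterator handling
# below tries to hide that startup cost.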
# Manual usage of iterators
training_data = create_dataloader()
loader_iter = iter(training_data)
prefetched_iter = iter(create_dataloader())  #pool.submit(create_dataloader().__iter__)

# multiple epochs
epoch = 0
while True:
    # prefetch next dataloader
    #next_data = pool.submit(create_dataloader, ...)
    t0 = time.time()
    for step, batch in enumerate(loader_iter):
        t1 = time.time()
        print('epoch {}, batch {}, data loading {}s'.format(
            epoch, step, (t1 - t0)))
        #data, target = batch[0].to('cuda')
        t0 = time.time()
    else:
        # for/else: this block runs once the loop finishes without a break
        print('enter in {}'.format(time.time() - t0))
        loader_iter = prefetched_iter  #.result(timeout=None)
        print('prefetched iter in {}'.format(time.time() - t0))
        training_data = create_dataloader()
        prefetched_iter = iter(training_data)  #pool.submit(training_data.__iter__)
        print('prefetching done in {}'.format(time.time() - t0))
    epoch += 1
    if epoch == 4:
        break
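
# A minimal sketch of the ThreadPoolExecutor variant hinted at in the
# commented-out lines above: build the next epoch's iterator in a
# background thread so the worker startup overlaps with training.
# This is an illustrative assumption, not part of the original gist;
# `pool` and `future_iter` are hypothetical names, and it reuses the
# create_dataloader() defined above.
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)
num_epochs = 4

# build the first iterator in the background as well
future_iter = pool.submit(lambda: iter(create_dataloader()))

for epoch in range(num_epochs):
    t0 = time.time()
    loader_iter = future_iter.result()  # blocks only if not ready yet
    print('epoch {}, iterator ready in {}s'.format(epoch, time.time() - t0))
    if epoch < num_epochs - 1:
        # kick off construction of the next epoch's iterator in the background
        future_iter = pool.submit(lambda: iter(create_dataloader()))
    for step, batch in enumerate(loader_iter):
        pass  # training step would go here

pool.shutdown()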