# Measures the per-epoch cost of recreating a multi-worker DataLoader and
# shows how to hide it by preparing the next epoch's iterator ahead of time.
import torch
from torch.utils.data import Dataset, DataLoader

import time
#from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

class MyDataset(Dataset):
    def __init__(self):
        # set file paths
        pass

    def __getitem__(self, index):
        # load data (random tensors stand in for real samples)
        x = torch.randn(3, 224, 224)
        y = torch.randint(0, 10, (1,))

        # simulate an expensive loading workflow
        for _ in range(10000):
            torch.sqrt(torch.randn(100))
        return x, y

    def __len__(self):
        return 1000

def create_dataloader():
    dataset = MyDataset()
    loader = DataLoader(
        dataset,
        batch_size=10,
        num_workers=4,  # worker processes are spawned when an iterator is created
        pin_memory=True,
        shuffle=True
    )
    return loader
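
# Side note: recent PyTorch releases (1.7+) also offer a `persistent_workers`
# flag that keeps the worker processes alive across epochs, removing the same
# re-spawn cost this script measures. A minimal sketch, assuming a recent
# enough PyTorch version:
def create_persistent_dataloader():
    return DataLoader(
        MyDataset(),
        batch_size=10,
        num_workers=4,
        pin_memory=True,
        shuffle=True,
        persistent_workers=True  # workers survive between epochs
    )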

#pool = ThreadPoolExecutor(5)

# Current workflow: recreate the DataLoader after every epoch.
# NOTE: with num_workers > 0, scripts like this need an
# `if __name__ == '__main__':` guard on platforms that spawn workers.

# get first dataloader
training_data = create_dataloader()

# multiple epochs
for epoch in range(4):
    # prefetch next dataloader
    #next_data = pool.submit(create_dataloader, ...)
    t0 = time.time()
    for step, batch in enumerate(training_data):
        t1 = time.time()
        print('epoch {}, batch {}, data loading {:.5f}s'.format(
            epoch, step, (t1 - t0)))
        t0 = time.time()
    # constructing the loader itself is cheap; the expensive worker startup
    # happens lazily when its iterator is created at the top of the next epoch
    training_data = create_dataloader()
    print('creating loader in {:.5f}s'.format(time.time() - t0))

# Manual usage of iterators: create the *next* epoch's iterator ahead of time
# (which is when the workers are spawned), so each epoch can start immediately.
training_data = create_dataloader()
loader_iter = iter(training_data)
prefetched_iter = iter(create_dataloader())  #pool.submit(create_dataloader().__iter__)

# multiple epochs
epoch = 0
while True:
    t0 = time.time()
    for step, batch in enumerate(loader_iter):
        t1 = time.time()
        print('epoch {}, batch {}, data loading {:.5f}s'.format(
            epoch, step, (t1 - t0)))
        #data, target = batch[0].to('cuda'), batch[1].to('cuda')
        t0 = time.time()
    else:
        # the for-else branch runs once the iterator is exhausted
        print('enter in {:.5f}s'.format(time.time() - t0))
        # swap in the iterator that was prepared during the previous epoch
        loader_iter = prefetched_iter  #.result(timeout=None)
        print('prefetched iter in {:.5f}s'.format(time.time() - t0))
        # prepare the iterator for the following epoch (this prefetches one
        # loader too many after the final epoch, but keeps the loop simple)
        training_data = create_dataloader()
        prefetched_iter = iter(training_data)  #pool.submit(training_data.__iter__)
        print('prefetching done in {:.5f}s'.format(time.time() - t0))
    epoch += 1
    if epoch == 4:
        break
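
# The commented-out pool calls above hint at a threaded variant: build the next
# epoch's iterator in a background thread so the worker startup overlaps with
# training. A minimal sketch of that idea (the helper names are illustrative,
# and it assumes spawning DataLoader workers from a thread is acceptable on
# your platform):
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)

def make_iter():
    # creating the iterator is what actually spawns the worker processes
    return iter(create_dataloader())

loader_iter = make_iter()
future_iter = pool.submit(make_iter)  # prefetch for the next epoch

for epoch in range(4):
    for step, batch in enumerate(loader_iter):
        pass  # training step goes here
    # block (only if the prefetch is still running), then start the next one
    loader_iter = future_iter.result(timeout=None)
    future_iter = pool.submit(make_iter)

pool.shutdown()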