Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AdityaSoni19031997/eb112580d5a11aa86d639498276e47f2 to your computer and use it in GitHub Desktop.
Save AdityaSoni19031997/eb112580d5a11aa86d639498276e47f2 to your computer and use it in GitHub Desktop.
Code adapted from the article "How to Build a Streaming DataLoader with PyTorch" (https://medium.com/speechmatics/how-to-build-a-streaming-dataloader-with-pytorch-a66dd891d9dd).
import random
from itertools import chain, cycle, islice
import torch.utils.data as data
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import time
import torch
import numpy as np
# added by CCJ;
def count(start=0, step=1):
    """Return an endless arithmetic progression ``start, start + step, ...``.

    This used to be a hand-written generator; it now delegates to the standard
    library's :func:`itertools.count`, which produces the same sequence. The
    module-level name is kept so existing callers keep working.

    >>> from itertools import islice
    >>> list(islice(count(10), 3))
    [10, 11, 12]
    >>> list(islice(count(2.5, 0.5), 3))
    [2.5, 3.0, 3.5]
    """
    import itertools

    return itertools.count(start, step)
#> see:https://gist.github.com/david-macleod/2b933d28fd3ac09766785728ee191f09
def plot_timings(loader, n_batches, model_time=0.2, max_time=2.5):
    """Plot a timeline of sample production versus batch consumption for ``loader``.

    Takes ``n_batches`` batches from ``loader`` and sleeps ``model_time``
    seconds after each one to stand in for a model step. Every sample in a
    batch must unpack to four values ``(data, worker_id, t1, t2)``, where
    ``t1``/``t2`` are timestamps comparable with ``time.time()``. Each sample
    is drawn as a box on its worker's row. Each batch is drawn as a red box on
    row -1, which is also where samples with ``worker_id == -1`` are drawn.

    Args:
        loader: Iterable of batches. A batch is either a sequence of 4-value
            records, or a sequence of tensors (one per field, as produced by
            DataLoader collation), which is transposed back into records.
        n_batches: Number of batches to take from ``loader``.
        model_time: Simulated processing time per batch, in seconds.
        max_time: Right limit of the x axis, in seconds.

    Returns:
        ``(fig, ax)``, the matplotlib figure and axes holding the plot, so the
        caller can show or save it. Previously nothing was returned.
    """
    fig, ax = plt.subplots()
    ax.set_axisbelow(True)
    ax.yaxis.grid(which="major", color='black', linewidth=1)

    zero_time = time.time()
    # Raw worker id -> plot row. Rows 0, 1, 2, ... are handed out in order of
    # first appearance, so the plot stays compact even when the raw ids are
    # large or non-consecutive.
    worker_rows = {}
    for result in islice(loader, n_batches):
        start = time.time()
        time.sleep(model_time)
        end = time.time()
        # A collated batch arrives as one tensor per field; zip turns it back
        # into per-sample records.
        if isinstance(result[0], torch.Tensor):
            result = zip(*result)
        batch = []
        for item in result:
            # Named "value" rather than "data" so the module-level
            # torch.utils.data alias is not shadowed.
            value, worker, t1, t2 = map(scalar, item)
            if worker != -1:
                # len() is evaluated before insertion, so a new worker gets the next free row.
                worker = worker_rows.setdefault(worker, len(worker_rows))
            plot_time_box(value, worker, t1 - zero_time, t2 - zero_time, ax)
            batch.append(value)
        batch_str = ",".join(map(str, batch))
        plot_time_box(batch_str, -1, start - zero_time, end - zero_time, ax, color='firebrick')

    max_worker = len(worker_rows) - 1
    ax.set_xlim(0, max_time)
    ax.set_ylim(-1.5, max_worker + 0.5)
    ax.set_xticks(np.arange(0, max_time, 0.2))
    ax.set_yticks(np.arange(-1, max_worker + 1, 1))
    # Keep one grid line per row, but make the y tick labels and marks invisible.
    ax.set_yticklabels([])
    ax.tick_params(axis='y', colors=(0, 0, 0, 0))
    fig.set_figwidth(16)
    # 0.5 per row: one row per worker plus the batch row (-1).
    fig.set_figheight((max_worker + 2) * 0.5)
    ax.xaxis.label.set_color('gray')
    ax.tick_params(axis='x', colors='gray')
    for spine in ax.spines.values():
        spine.set_edgecolor((0, 0, 0, 0))
    return fig, ax
def scalar(x):
    """Unwrap ``x`` into a plain Python value when it offers ``.item()``.

    Objects with an ``item`` method (e.g. torch tensors, NumPy scalars) are
    converted by calling it; everything else is handed back unchanged.
    """
    try:
        unwrap = x.item
    except AttributeError:
        return x
    return unwrap()
def plot_time_box(data, worker, t1, t2, ax, color='steelblue'):
    """Draw one labelled timing box on ``ax``.

    The box spans ``t1`` to ``t2`` horizontally, and vertically runs from
    ``worker - 0.25`` to ``worker + 0.35``. ``str(data)`` is written in bold
    white at its centre.

    NOTE(review): the box is not centred on the row's grid line (its centre
    is at ``worker + 0.05``); kept as is.
    """
    height = 0.6
    width = t2 - t1
    bottom = worker - 0.25
    box = Rectangle((t1, bottom), width, height,
                    linewidth=2, edgecolor='black', facecolor=color)
    ax.add_patch(box)
    centre_x = t1 + width * 0.5
    centre_y = bottom + height * 0.5
    ax.text(centre_x, centre_y, str(data),
            va='center', ha='center', color='white', weight='bold')
class MyIterableDatasetV5(data.IterableDataset):
    """Iterable dataset yielding ``batch_size`` samples per step from parallel streams.

    Each of the ``batch_size`` streams walks its own shuffled ordering of
    ``data_list`` forever (via ``itertools.cycle``) and expands every entry
    item by item. Each step yields a tuple holding one sample from every
    stream. A sample is ``(x, worker_id, start, end)``:

    * ``worker_id`` is the DataLoader worker id, or -1 outside a worker.
    * ``start``/``end`` are ``time.time()`` stamps taken around a simulated
      processing delay of ``process_time`` seconds.

    Args:
        data_list: List of iterables; each one is a sequence of samples.
        batch_size: Number of parallel streams, i.e. samples per step.
        process_time: Simulated per-sample processing time in seconds. The
            default of 0.1 is the value that used to be hard-coded.
    """

    def __init__(self, data_list, batch_size, process_time=0.1):
        self.data_list = data_list
        self.batch_size = batch_size
        self.process_time = process_time

    @property
    def shuffled_data_list(self):
        # Returns a fresh random permutation of data_list on every access.
        return random.sample(self.data_list, len(self.data_list))

    def process_data(self, mydata):
        # The generator body runs entirely in one process, so the worker
        # identity is looked up once rather than for every item.
        worker = torch.utils.data.get_worker_info()
        worker_id = worker.id if worker is not None else -1
        for x in mydata:
            start = time.time()
            time.sleep(self.process_time)
            end = time.time()
            yield x, worker_id, start, end

    def get_stream(self, data_list):
        # Endless stream: cycle over the entries and flatten each one's samples.
        return chain.from_iterable(map(self.process_data, cycle(data_list)))

    def get_streams(self):
        # One independently shuffled stream per batch slot, advanced in lockstep.
        return zip(*[self.get_stream(self.shuffled_data_list) for _ in range(self.batch_size)])

    def __iter__(self):
        return self.get_streams()
if False:
    # Disabled demo: a single DataLoader with several workers reading one
    # MyIterableDatasetV5. Change the condition to True to run it.
    # The table is rows [10..13], [20..23], ..., [90..93].
    my_data_list = [[10 * row + col for col in range(4)] for row in range(1, 10)]
    my_batch_size = 4
    my_num_workers = 2
    iterable_dataset = MyIterableDatasetV5(my_data_list, batch_size=my_batch_size)
    loader = data.DataLoader(
        iterable_dataset,
        batch_size=None,
        num_workers=my_num_workers,
    )
    plot_timings(loader, model_time=0.2, n_batches=my_batch_size)
class MyIterableDatasetV7(data.IterableDataset):
    """Iterable dataset yielding ``batch_size`` samples per step from parallel streams.

    Each of the ``batch_size`` streams walks its own shuffled ordering of
    ``data_list`` forever and expands every entry item by item. Each step
    yields a tuple with one sample ``(x, worker_id, start, end)`` per stream.
    ``start``/``end`` are ``time.time()`` stamps taken around a simulated
    processing delay of ``process_time`` seconds.

    Inside a DataLoader worker, samples are tagged with ``id(self)`` instead
    of ``worker.id``; outside a worker they are tagged -1. This is presumably
    so that datasets each running in their own single-worker loader (where
    ``worker.id`` is 0 for all of them) stay distinguishable.
    NOTE(review): ``id`` values are per-process object addresses and are not
    guaranteed unique across separate worker processes; confirm this is
    acceptable for plotting.

    Args:
        data_list: List of iterables; each one is a sequence of samples.
        batch_size: Number of parallel streams, i.e. samples per step.
        process_time: Simulated per-sample processing time in seconds. The
            default of 0.1 is the value that used to be hard-coded.
    """

    def __init__(self, data_list, batch_size, process_time=0.1):
        self.data_list = data_list
        self.batch_size = batch_size
        self.process_time = process_time

    @property
    def shuffled_data_list(self):
        # Returns a fresh random permutation of data_list on every access.
        return random.sample(self.data_list, len(self.data_list))

    def process_data(self, mydata):
        # The worker identity is fixed for the life of this generator, so look it up once.
        worker = torch.utils.data.get_worker_info()
        worker_id = id(self) if worker is not None else -1
        for x in mydata:
            start = time.time()
            time.sleep(self.process_time)
            end = time.time()
            yield x, worker_id, start, end

    def get_stream(self, data_list):
        # Endless stream: cycle over the entries and flatten each one's samples.
        return chain.from_iterable(map(self.process_data, cycle(data_list)))

    def get_streams(self):
        # One independently shuffled stream per batch slot, advanced in lockstep.
        return zip(*[self.get_stream(self.shuffled_data_list) for _ in range(self.batch_size)])

    def __iter__(self):
        return self.get_streams()

    @classmethod
    def split_datasets(cls, data_list, batch_size, max_workers, process_time=0.1):
        """Split one logical batch of ``batch_size`` across several datasets.

        Picks the largest ``num_workers`` in ``1..max_workers`` that divides
        ``batch_size`` evenly. It then returns ``num_workers`` datasets, each
        producing ``batch_size // num_workers`` samples per step, so one step
        from every dataset together makes up ``batch_size`` samples.

        Raises:
            ValueError: if ``max_workers`` is less than 1, or if no candidate
                divides ``batch_size``. Both cases used to surface as an
                ``UnboundLocalError``.
        """
        if max_workers < 1:
            raise ValueError(f"max_workers must be >= 1, got {max_workers}")
        num_workers = None
        for n in range(max_workers, 0, -1):
            if batch_size % n == 0:
                num_workers = n
                break
        if num_workers is None:
            raise ValueError(
                f"no worker count in 1..{max_workers} evenly divides batch_size={batch_size}"
            )
        split_size = batch_size // num_workers
        return [cls(data_list, batch_size=split_size, process_time=process_time)
                for _ in range(num_workers)]
class multiStreamDataLoader:
    """Merge several datasets, each read by its own single-worker DataLoader.

    Each step pulls one item from every dataset's loader in lockstep and
    flattens those items into a single list. Iteration stops as soon as any
    loader is exhausted (``zip`` semantics).
    """

    def __init__(self, datasets):
        self.datasets = datasets

    def get_stream_loaders(self):
        # batch_size=None: each dataset already yields whole per-step chunks.
        loaders = [
            torch.utils.data.DataLoader(ds, num_workers=1, batch_size=None)
            for ds in self.datasets
        ]
        return zip(*loaders)

    def __iter__(self):
        for parts in self.get_stream_loaders():
            merged = chain.from_iterable(parts)
            yield list(merged)
if __name__ == "__main__":
    # Demo: split a batch of 4 across up to 4 single-worker DataLoaders and
    # plot when each sample was produced versus when each batch was consumed.
    # The __main__ guard (instead of the former "if 1:") is required because
    # DataLoader workers started with the "spawn" method re-import this module.
    # An unguarded demo would then try to start workers again from inside a
    # worker, and it would also run on every plain import of this file.
    my_data_list = [
        [10, 11, 12, 13],
        [20, 21, 22, 23],
        [30, 31, 32, 33],
        [40, 41, 42, 43],
        [50, 51, 52, 53],
        [60, 61, 62, 63],
        [70, 71, 72, 73],
        [80, 81, 82, 83],
    ]
    my_batch_size = 4
    my_num_workers = 4
    print('[***] iterable dataset with %d workers' % my_num_workers)
    dataset_tmp = MyIterableDatasetV7.split_datasets(
        my_data_list, batch_size=my_batch_size, max_workers=my_num_workers
    )
    loader = multiStreamDataLoader(dataset_tmp)
    plot_timings(loader, model_time=0.2, n_batches=6)
    # plot_timings only draws into a figure; display it.
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment