@Dref360
Last active June 6, 2017 20:25
Ordered multiprocess executor to be used in Keras (similar to PyTorch's DataLoader).
import os
import time
from concurrent.futures import ProcessPoolExecutor
from itertools import cycle
from queue import Queue
from threading import Thread, Event
from keras.engine.training import GeneratorEnqueuer
class Dataset():
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class ExampleDataset(Dataset):
    def __getitem__(self, index):
        # Simulate an expensive sample: one second of work per item.
        time.sleep(1)
        return os.getpid(), index

    def __len__(self):
        return 100
class MultiProcessExecutor():
    def __init__(self, dataset, workers=1, max_q_size=5):
        self.workers = workers
        self.executor = ProcessPoolExecutor(self.workers)
        self.dataset = dataset
        self.queue = Queue(max_q_size)
        self.run_thread = None
        self.stop_signal = Event()

    def is_running(self):
        return not self.stop_signal.is_set()

    def start(self):
        self.run_thread = Thread(target=self.run)
        self.run_thread.daemon = True
        self.run_thread.start()

    def run(self):
        """Submit tasks to the pool in index order."""
        indexes = cycle(range(len(self.dataset)))
        for i in indexes:
            if self.stop_signal.is_set():
                return
            # The bounded queue blocks here, so at most max_q_size futures are pending.
            self.queue.put(self.executor.submit(self.dataset.__getitem__, i), block=True)

    def get_item(self):
        try:
            while True:
                # Futures are consumed in submission order, so results are yielded in
                # order even when workers finish out of order.
                yield self.queue.get(block=True).result()
        except Exception as e:
            self.stop()
            print('MultiProcessExecutor has stopped because of:', type(e).__name__, str(e), flush=True)
            return

    def stop(self):
        self.executor.shutdown()
        self.stop_signal.set()
        with self.queue.mutex:
            self.queue.queue.clear()
            self.queue.unfinished_tasks = 0
            self.queue.not_full.notify()
        self.run_thread.join()
dataset = ExampleDataset()
executor = MultiProcessExecutor(dataset)
executor.start()
getter = executor.get_item()

start = time.time()
for i in range(100):
    result = next(getter)
print("Took executor", time.time() - start)

"""
Comparing to Keras
"""


def keras_gen():
    while True:
        time.sleep(1)
        yield os.getpid()


qu = GeneratorEnqueuer(keras_gen(), pickle_safe=True)
qu.start(5, 10)
start = time.time()
for i in range(100):
    while not qu.queue.qsize():
        time.sleep(0.5)
    result = qu.queue.get()
print("Took Keras", time.time() - start)
Dref360 commented Jun 6, 2017

This is as expensive as Keras' GeneratorEnqueuer while preserving order.
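
One way to sanity-check the ordering claim, as a hedged sketch reusing the Dataset and MultiProcessExecutor classes from the gist: give each item a random duration and confirm the indices still come back in submission order. RandomDelayDataset and the worker count below are illustrative assumptions.

import random
import time


class RandomDelayDataset(Dataset):
    """Items finish out of order inside the pool, but should be yielded in order."""

    def __getitem__(self, index):
        time.sleep(random.uniform(0, 0.1))
        return index

    def __len__(self):
        return 50


executor = MultiProcessExecutor(RandomDelayDataset(), workers=4)
executor.start()
getter = executor.get_item()
indices = [next(getter) for _ in range(50)]
assert indices == list(range(50))  # order preserved despite uneven task durations
executor.stop()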
