Skip to content

Instantly share code, notes, and snippets.

Created June 9, 2016 11:28
Show Gist options
  • Save gzuidhof/3ca6a5c9290560c08d79157a6bda34ad to your computer and use it in GitHub Desktop.
Save gzuidhof/3ca6a5c9290560c08d79157a6bda34ad to your computer and use it in GitHub Desktop.
Parallel Batch Iterator, for preparing batches in different threads or processes.
from __future__ import division
import math
from multiprocessing import Process, Queue, JoinableQueue, Value
from threading import Thread
from functools import partial
class ParallelBatchIterator(object):
    """Prepare batches in parallel using a producer-consumer model.

    Batches are built on the CPU in separate processes or threads (for
    instance while you are training on the GPU) and handed to the caller
    through a bounded queue.

    Constructor arguments:
        batch_generator: function called with one chunk of X; its return
            value is the batch handed to the consumer.
        X: input for the batch_generator (could be for instance filenames).
            Must support len() and slicing.
        batch_size: integer (default=1), amount of points in one batch
        ordered: boolean (default=False), whether the order of the batches
            matters
        multiprocess: boolean (default=True), multiprocess instead of
            multithread
        n_producers: integer (default=4), amount of producers (threads or
            processes)
        max_queue_size: integer (default=4*n_producers)
    """

    def __init__(self, batch_generator, X, batch_size=1, ordered=False,
                 multiprocess=True, n_producers=4, max_queue_size=None):
        self.generator = batch_generator
        self.ordered = ordered
        self.multiprocess = multiprocess
        self.n_producers = n_producers
        self.X = X
        self.batch_size = batch_size

        # Only fall back to the default when no explicit size was given;
        # an unconditional assignment would overwrite the default with None.
        if max_queue_size is None:
            self.max_queue_size = n_producers * 4
        else:
            self.max_queue_size = max_queue_size

    def __call__(self):
        """Return the iterator itself, so it can be used as a factory."""
        return self

    def __iter__(self):
        """Start the producers and yield one prepared batch per chunk of X."""
        queue = JoinableQueue(maxsize=self.max_queue_size)
        n_batches, _ = self._start_producers(queue)

        # Run as consumer (read items from queue, in current thread)
        for _ in range(n_batches):
            item = queue.get()
            # A JoinableQueue counts unfinished tasks; acknowledge every item
            # taken so that counter cannot overflow on long runs.
            queue.task_done()
            yield item  # Yield the item to the consumer (user)

    def __len__(self):
        """Return the number of batches (the last one may be smaller)."""
        # Integer ceiling division: __len__ must return an int, and this
        # does not depend on the division mode of the interpreter.
        return (len(self.X) + self.batch_size - 1) // self.batch_size

    def _start_producers(self, result_queue):
        """Queue all jobs and launch the workers.

        Returns a tuple (batch_count, jobs): the number of batches that
        will arrive on result_queue and the job queue the workers read.
        """
        jobs = Queue()
        n_workers = self.n_producers
        batch_count = 0

        # Flag used for keeping values in queue in order
        last_queued_job = Value('i', -1)

        # Add jobs to queue
        for job_index, X_batch in enumerate(_chunks(self.X, self.batch_size)):
            batch_count += 1
            jobs.put((job_index, X_batch))

        # Add poison pills to queue (to signal workers to stop); one per
        # worker, in the (-1, None) form _produce_helper checks for.
        for _ in range(n_workers):
            jobs.put((-1, None))

        # Define producer function; the worker id is supplied positionally
        # when the process/thread is created below.
        produce = partial(_produce_helper,
                          generator=self.generator,
                          jobs=jobs,
                          result_queue=result_queue,
                          last_queued_job=last_queued_job,
                          ordered=self.ordered)

        # Start worker processes or threads
        worker_type = Process if self.multiprocess else Thread
        for i in range(n_workers):
            name = "ParallelBatchIterator worker {0}".format(i)
            p = worker_type(target=produce, args=(i,), name=name)
            # Make the process daemon, so the main process can die without
            # these finishing
            p.daemon = True
            p.start()

        return batch_count, jobs
def _produce_helper(id, generator, jobs, result_queue, last_queued_job, ordered):
    """Run one worker: turn jobs into results until a poison pill arrives.

    Defined as a top level function as this is required for the Windows
    platform.

    Arguments:
        id: index of this worker (not used by the loop itself).
        generator: callable applied to each task to build the result.
        jobs: queue of (job_index, task) tuples; (-1, None) is the poison
            pill that stops the worker.
        result_queue: queue the results are put on.
        last_queued_job: shared counter (with .value and .get_lock())
            holding the index of the last job whose result was queued.
        ordered: when true, results are queued strictly in job_index order.
    """
    while True:
        job_index, task = jobs.get()

        # Kill the worker if there is no more work
        # (This is a poison pill)
        if job_index == -1 and task is None:
            break

        result = generator(task)

        # Put result onto the 'done'-queue
        while True:
            # My turn to add job result (to keep it in order)?
            if last_queued_job.value == job_index - 1 or not ordered:
                with last_queued_job.get_lock():
                    # Queue the result before bumping the counter, so the
                    # next job in line cannot overtake this one.
                    result_queue.put(result)
                    last_queued_job.value += 1
                break
def _chunks(l, n):
    """Yield successive n-sized chunks from l.

    l must support len() and slicing; the final chunk is shorter when
    len(l) is not a multiple of n.
    """
    # range (not xrange) so the helper runs on both Python 2 and Python 3.
    for i in range(0, len(l), n):
        yield l[i:i + n]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment