Skip to content

Instantly share code, notes, and snippets.

Last active April 22, 2020 19:18
Show Gist options
  • Save isarandi/a72b3e5c1b1d3e40eb857a01d91926f9 to your computer and use it in GitHub Desktop.
Save isarandi/a72b3e5c1b1d3e40eb857a01d91926f9 to your computer and use it in GitHub Desktop.
Parallel input pipeline as a TensorFlow Dataset
def parallel_map_as_tf_dataset(
        fun, iterable, *, output_types=None, output_shapes=None, shuffle_before_each_epoch=False,
        extra_args=None, n_workers=10, n_epochs=None):
    """Map `fun` over `iterable` in worker processes and wrap the results as a TF Dataset.

    Elements are processed by parallel workers using multiprocessing. Results are
    yielded in the order in which the elements were submitted.

    Args:
        fun: A function that takes an element from `iterable` plus `extra_args` and
            returns a sequence of numpy arrays (a single array is treated as a
            one-element sequence).
        iterable: An iterable holding the input objects, which can be any Python
            objects, not just numpy arrays. A one-shot iterator can only be traversed
            once, so it supports more than one epoch only together with
            `shuffle_before_each_epoch`.
        output_types: A list of types, describing each output numpy array from `fun`.
            If None, it is determined by calling `fun` on the first element.
        output_shapes: A list of array shapes, describing each output numpy array from
            `fun`. If None, it is determined by calling `fun` on the first element.
        shuffle_before_each_epoch: Shuffle the input elements before each epoch.
            Converts `iterable` to a list internally.
        extra_args: Extra arguments, in addition to an element from `iterable`,
            given to `fun` at each call.
        n_workers: Number of worker processes. Only takes effect on the call that
            creates the shared pool (see `get_pool`).
        n_epochs: Number of times to iterate over `iterable`; None means indefinitely.

    Returns:
        A `tf.data.Dataset` based on the arrays returned by `fun`.

    Raises:
        ValueError: If the output types/shapes have to be inferred but `iterable`
            is empty.
    """
    # Local import: `random` is not referenced anywhere else in the visible file.
    import random

    extra_args = tuple(extra_args) if extra_args else ()

    # `source` is what the producer iterates over in every epoch.
    source = iterable

    if output_types is None or output_shapes is None:
        # Automatically determine the missing output tensor types and/or shapes by
        # calling the function on the first element.
        try:
            first_elem, chained = peek(iterable)
        except StopIteration:
            raise ValueError(
                'Cannot infer output types/shapes from an empty iterable.') from None
        if iter(iterable) is iterable:
            # One-shot iterator: peeking consumed its head, so continue from the
            # chain that puts the head back. Re-iterable containers are left
            # untouched so that later epochs can traverse them again.
            source = chained
        sample_output = fun(first_elem, *extra_args)
        inferred_shapes, inferred_types = get_shapes_and_tf_dtypes(sample_output)
        # Only fill in what the caller did not specify.
        if output_types is None:
            output_types = inferred_types
        if output_shapes is None:
            output_shapes = inferred_shapes

    if shuffle_before_each_epoch:
        source = list(source)

    epoch_counter = itertools.count() if n_epochs is None else range(n_epochs)

    pool = get_pool(n_workers)
    # Caps the number of submitted-but-not-yet-consumed items, so that the producer
    # cannot run arbitrarily far ahead of the consumer.
    semaphore = threading.Semaphore(256)
    q = queue.Queue()

    def producer():
        try:
            for _ in epoch_counter:
                if shuffle_before_each_epoch:
                    random.shuffle(source)

                n_submitted = 0
                for item in source:
                    semaphore.acquire()
                    # Enqueue the pending result itself (rather than using a
                    # callback) so that ordering is deterministic and an exception
                    # raised in the worker reaches the consumer.
                    q.put(pool.apply_async(fun, (item, *extra_args)))
                    n_submitted += 1

                if n_submitted == 0:
                    # The source is exhausted (or empty); further epochs would only
                    # spin without producing anything.
                    break
        finally:
            # End-of-stream marker for the consumer.
            q.put(None)

    def consumer():
        while True:
            pending = q.get()
            if pending is None:
                return
            semaphore.release()
            result = pending.get()  # re-raises any exception from the worker
            if not isinstance(result, (list, tuple)):
                # Same convention as `get_shapes_and_tf_dtypes`.
                result = (result,)
            yield tuple(result)

    producer_thread = threading.Thread(target=producer, daemon=True)
    producer_thread.start()

    # NOTE(review): there is a single producer thread and queue per call, so the
    # returned dataset is meant to be iterated once.
    return tf.data.Dataset.from_generator(consumer, output_types, output_shapes)
def example_use():
    """Sketch of how to build and consume a dataset with `parallel_map_as_tf_dataset`.

    This is an illustrative fragment: `load_fn`, `examples`, `n_workers`,
    `n_epochs` and `batch_size` stand for values defined in the elided parts
    marked with `[...]`.
    """
    # [...]
    dataset = parallel_map_as_tf_dataset(
        load_fn, examples, shuffle_before_each_epoch=True,
        extra_args=('bla1', 'bla2'), n_workers=n_workers, n_epochs=n_epochs)

    # Conceptual equivalent of the above:
    # data = []
    # for item in examples:
    #     data.append(load_fn(item, 'bla1', 'bla2'))

    is_training = True
    dataset = dataset.batch(batch_size, drop_remainder=is_training)
    # Stage up to 3 batches on the GPU ahead of time.
    dataset = dataset.apply(tf.data.experimental.prefetch_to_device('/gpu:0', 3))
    iterator = dataset.make_one_shot_iterator()
    batch_tensors = iterator.get_next()
    # [...]
def peek(iterable):
    """Return the first element of `iterable` and an iterator over all its elements.

    The returned iterator yields the first element again, followed by the
    remaining ones, so nothing is lost by peeking. If `iterable` is itself a
    one-shot iterator, its first element is consumed and only the returned
    iterator should be used afterwards.

    Raises:
        StopIteration: If `iterable` is empty.
    """
    iterator = iter(iterable)
    head = next(iterator)
    return head, itertools.chain([head], iterator)
def get_shapes_and_tf_dtypes(thing):
    """Return the TF shapes and dtypes describing `thing`.

    Args:
        thing: A list or tuple of array-likes, or a single array-like, which is
            treated as a one-element tuple.

    Returns:
        A pair `(shapes, tf_types)`: a tuple of `tf.TensorShape` and a tuple of
        TF dtypes, with one entry per array in `thing`.
    """
    if not isinstance(thing, (list, tuple)):
        thing = (thing,)

    arrays = [np.asanyarray(a) for a in thing]
    tf_types = [tf.as_dtype(a.dtype) for a in arrays]
    shapes = [tf.TensorShape(a.shape) for a in arrays]
    return tuple(shapes), tuple(tf_types)
# Process pool shared by all callers; created lazily by `get_pool`.
_pool = None


def get_pool(n_workers_if_uninitialized):
    """Return the shared worker pool, creating it on the first call.

    Args:
        n_workers_if_uninitialized: Number of worker processes to start. Only used
            if the pool does not exist yet; later calls return the existing pool
            regardless of this value.

    Returns:
        The module-wide multiprocessing pool.
    """
    global _pool
    if _pool is None:
        # important to use 'spawn', because 'fork' would mean the whole memory is (lazily) copied
        # then due to copy-on-write semantics, it gets duplicated when the parent changes anything
        ctx = multiprocessing.get_context('spawn')
        _pool = ctx.Pool(n_workers_if_uninitialized, initializer=init_worker_process)
    return _pool
def init_worker_process():
    """Initializer run once in every worker process of the pool.

    Makes the worker ignore SIGINT and seeds the `random` and `numpy.random`
    generators with a per-process seed from `generate_seed`.
    """
    # Local import: `random` is not referenced anywhere else in the visible file.
    import random

    # NOTE(review): presumably so that Ctrl+C is handled by the parent process
    # only -- confirm.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    # The seed was previously computed but never applied to any generator.
    seed = generate_seed()
    random.seed(seed)
    np.random.seed(seed)
    # NOTE(review): `terminate_on_parent_death` is defined in this file but is not
    # called from here; confirm whether workers are meant to call it.
def generate_seed():
    """Return an integer seed derived from the process id and the current time.

    The result lies in `[0, 104729)`. It depends on the current time at
    one-second resolution and on the id of the calling process.
    """
    pid = os.getpid()
    s = int(time.time())
    # The modulus is positive, so the `%` result is already non-negative in
    # Python; `abs` is kept only as a safeguard.
    return abs(((s * 181) * ((pid - 83) * 359)) % 104729)
def terminate_on_parent_death():
    """Ask the kernel to send SIGTERM to this process when its parent dies.

    Linux-only: relies on `prctl(PR_SET_PDEATHSIG, ...)`.

    NOTE(review): only the `prctl` lookup was present in this function; the call
    below was reconstructed from the function name -- confirm against the
    original.
    """
    PR_SET_PDEATHSIG = 1  # option value from <linux/prctl.h>
    prctl = ctypes.CDLL("").prctl
    prctl(PR_SET_PDEATHSIG, int(signal.SIGTERM))
Copy link

isarandi commented Apr 22, 2020

I published my more complicated version, which can handle deterministic shuffling and augmentation, plus resuming correctly from an arbitrary point in the sequence (this helps to seamlessly resume from a training checkpoint).

I'm fine with the MIT License in this case, but please remember that it is not appropriate to arbitrarily add licenses to other people's code; always seek their approval first.

But anyway, I'm really glad that it's useful for you, perhaps you can check out if this new version works for you better.

Copy link

ljn917 commented Apr 22, 2020 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment