Skip to content

Instantly share code, notes, and snippets.

Last active April 22, 2020 19:18
Show Gist options
  • Save isarandi/a72b3e5c1b1d3e40eb857a01d91926f9 to your computer and use it in GitHub Desktop.
Save isarandi/a72b3e5c1b1d3e40eb857a01d91926f9 to your computer and use it in GitHub Desktop.
Parallel input pipeline as a TensorFlow Dataset
def parallel_map_as_tf_dataset(
        fun, iterable, *, output_types=None, output_shapes=None, shuffle_before_each_epoch=False,
        extra_args=None, n_workers=10, n_epochs=None):
    """Map `fun` over `iterable` in worker processes and wrap the results as a TF Dataset.

    Each element of `iterable` is submitted, together with `extra_args`, to the shared
    multiprocessing pool returned by `get_pool`. Results are yielded in submission order.

    Args:
        fun: A function that takes an element from `iterable` plus `*extra_args` and returns
            a numpy array or a list/tuple of numpy arrays. The pool uses the 'spawn' start
            method, so `fun`, the elements and `extra_args` must be picklable.
        iterable: An iterable holding the input objects, which can be any Python objects,
            not just numpy arrays.
        output_types: A sequence of TF dtypes, one per output array of `fun`. If this or
            `output_shapes` is None, both are determined by calling `fun` once (in this
            process) on the first element.
        output_shapes: A sequence of shapes, one per output array of `fun`. See
            `output_types`.
        shuffle_before_each_epoch: Shuffle the input elements before each epoch.
        extra_args: Extra positional arguments passed to `fun` after the element.
        n_workers: Number of worker processes for parallelism. Only takes effect when the
            shared pool is created, i.e. on the first call to `get_pool` in this process.
        n_epochs: Number of passes over `iterable`; None means repeat indefinitely.

    Returns:
        A `tf.data.Dataset` (built with `tf.data.Dataset.from_generator`) whose elements
        are tuples of the arrays returned by `fun`.

    Raises:
        ValueError: If the output types/shapes must be inferred but `iterable` is empty.

    Notes:
        `iterable` is copied into a list when shuffling or when more than one pass may be
        needed, because a one-shot iterator would be exhausted after the first epoch.
        Submission to the pool starts as soon as this function is called. All iterations
        of the returned dataset share a single producer thread and result queue, so the
        dataset is meant to be iterated once.
    """
    import random

    extra_args = tuple(extra_args or ())

    if output_types is None or output_shapes is None:
        # Infer dtypes/shapes by running `fun` on the first element. `peek` returns an
        # iterator that still yields that first element.
        try:
            first_elem, iterable = peek(iterable)
        except StopIteration:
            raise ValueError(
                'Cannot infer output_types/output_shapes from an empty iterable; '
                'pass them explicitly.') from None
        sample_output = fun(first_elem, *extra_args)
        output_shapes, output_types = get_shapes_and_tf_dtypes(sample_output)

    # Shuffling needs a list, and a one-shot iterator cannot be replayed for later epochs.
    if shuffle_before_each_epoch or n_epochs != 1:
        iterable = list(iterable)

    if n_epochs is None:
        epoch_counter = itertools.count()
    else:
        epoch_counter = range(n_epochs)

    if isinstance(iterable, list) and not iterable:
        # Nothing to submit: avoid looping forever over empty epochs when n_epochs is None.
        epoch_counter = range(0)

    pool = get_pool(n_workers)
    # Caps how many items may be submitted but not yet consumed, so the producer cannot
    # run arbitrarily far ahead of the consumer.
    semaphore = threading.Semaphore(256)
    q = queue.Queue()

    def producer():
        for _ in epoch_counter:
            if shuffle_before_each_epoch:
                random.shuffle(iterable)  # shuffles our private list copy
            for item in iterable:
                semaphore.acquire()
                # Queue the AsyncResult handle, not the finished value. This keeps results
                # in submission order, re-raises worker exceptions in the consumer, and
                # ensures the sentinel below cannot overtake a pending result.
                q.put(pool.apply_async(fun, (item, *extra_args)))
        q.put(None)  # end-of-data sentinel

    producer_thread = threading.Thread(target=producer, daemon=True)
    producer_thread.start()

    def consumer():
        while True:
            async_result = q.get()
            if async_result is None:
                return
            try:
                result = async_result.get()
            finally:
                semaphore.release()
            # Mirror get_shapes_and_tf_dtypes: a bare (non-list/tuple) result is one output.
            if isinstance(result, (list, tuple)):
                yield tuple(result)
            else:
                yield (result,)

    return tf.data.Dataset.from_generator(consumer, output_types, output_shapes)
def example_use():
    """Illustrate a batched, GPU-prefetched input pipeline built on
    `parallel_map_as_tf_dataset`.

    Sketch only: `load_fn`, `examples`, `n_workers`, `n_epochs` and `batch_size` come from
    code elided at the `# [...]` markers. Uses TF1-style iterator APIs.
    """
    # [...]
    dataset = parallel_map_as_tf_dataset(
        load_fn, examples, shuffle_before_each_epoch=True,
        extra_args=('bla1', 'bla2'), n_workers=n_workers, n_epochs=n_epochs)

    # Conceptual equivalent of the above:
    # data = []
    # for item in examples:
    #     data.append(load_fn(item, 'bla1', 'bla2'))

    is_training = True
    # During training, drop the last partial batch so every batch has the same size.
    dataset = dataset.batch(batch_size, drop_remainder=is_training)
    # Stage up to 3 batches on the GPU ahead of use.
    # NOTE(review): reconstructed call; confirm prefetch_to_device is the intended transform.
    dataset = dataset.apply(tf.data.experimental.prefetch_to_device('/gpu:0', 3))
    iterator = dataset.make_one_shot_iterator()
    batch_tensors = iterator.get_next()
    # [...]
def peek(iterable):
    """Look at the first element of `iterable` without losing it.

    Returns:
        A pair `(first, rest)`. `first` is the first element. `rest` is an iterator that
        yields every element of `iterable`, starting again with `first`.

    Raises:
        StopIteration: If `iterable` is empty.
    """
    it = iter(iterable)
    first = next(it)
    rest = itertools.chain((first,), it)
    return first, rest
def get_shapes_and_tf_dtypes(thing):
    """Describe the arrays in `thing` as TF shapes and dtypes.

    Args:
        thing: A list/tuple of array-like values, or a single array-like value, which is
            treated as a one-element sequence.

    Returns:
        A pair `(shapes, dtypes)`: a tuple of `tf.TensorShape` and a tuple of `tf.DType`,
        one entry per value, in order.
    """
    items = thing if isinstance(thing, (list, tuple)) else (thing,)
    arrays = list(map(np.asanyarray, items))
    dtypes = tuple(map(tf.as_dtype, (arr.dtype for arr in arrays)))
    shapes = tuple(tf.TensorShape(arr.shape) for arr in arrays)
    return shapes, dtypes
_pool = None  # process-wide worker pool, created lazily by get_pool()


def get_pool(n_workers_if_uninitialized):
    """Return this module's shared multiprocessing pool, creating it on the first call.

    Args:
        n_workers_if_uninitialized: Number of worker processes to start. Only used when
            the pool does not exist yet; later calls return the existing pool unchanged.

    Returns:
        A pool whose workers are started with the 'spawn' method and run
        `init_worker_process` on startup.
    """
    global _pool
    if _pool is not None:
        return _pool
    # Use 'spawn' rather than 'fork'. A forked child shares the parent's memory
    # copy-on-write, so pages get physically duplicated once either side writes to them.
    # That can multiply memory use when the parent process is large.
    spawn_context = multiprocessing.get_context('spawn')
    _pool = spawn_context.Pool(n_workers_if_uninitialized, initializer=init_worker_process)
    return _pool
def init_worker_process():
    """Initialize a pool worker process; used as the pool's `initializer`.

    The worker ignores SIGINT, and its global `random` and `np.random` generators are
    seeded with a value from `generate_seed`.
    """
    import random

    # Ignore Ctrl+C in workers, presumably so a keyboard interrupt is handled by the
    # parent alone rather than raising KeyboardInterrupt in every worker.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    # Seed both global RNGs that worker functions may draw from.
    seed = generate_seed()
    random.seed(seed)
    np.random.seed(seed)
def generate_seed():
    """Return an integer seed in [0, 104729) mixed from the process id and the current
    wall-clock second.

    Processes with different ids usually get different seeds. The value is not
    reproducible across runs.
    """
    pid_term = (os.getpid() - 83) * 359
    time_term = int(time.time()) * 181
    return abs((time_term * pid_term) % 104729)
def terminate_on_parent_death():
    """Ask the Linux kernel to send SIGTERM to this process when its parent dies.

    Implemented with prctl(PR_SET_PDEATHSIG, SIGTERM). Does nothing on non-Linux
    platforms, where this prctl option does not exist.

    Raises:
        OSError: If the prctl call reports failure.
    """
    import sys

    if not sys.platform.startswith('linux'):
        return
    # CDLL("") resolves symbols from the running process image, which includes libc's prctl.
    libc = ctypes.CDLL("", use_errno=True)
    PR_SET_PDEATHSIG = 1  # value from <linux/prctl.h>
    if libc.prctl(PR_SET_PDEATHSIG, int(signal.SIGTERM)) != 0:
        err = ctypes.get_errno()
        raise OSError(err, os.strerror(err))
Copy link

ljn917 commented Apr 22, 2020 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment