@isarandi
Last active April 22, 2020 19:18
Parallel input pipeline as a TensorFlow Dataset
import ctypes
import itertools
import multiprocessing
import os
import queue
import random
import signal
import threading
import time

import numpy as np
import tensorflow as tf


def parallel_map_as_tf_dataset(
        fun, iterable, *, output_types=None, output_shapes=None, shuffle_before_each_epoch=False,
        extra_args=None, n_workers=10, n_epochs=None):
    """Maps `fun` to each element of `iterable` and wraps the resulting sequence
    as a TF Dataset. Elements are processed by parallel workers using multiprocessing.

    Args:
        fun: A function that takes an element from `iterable` plus `extra_args` and returns
            a sequence of numpy arrays.
        iterable: An iterable holding the input objects, which can be any Python objects,
            not just numpy arrays.
        output_types: A list of types, describing each output numpy array from `fun`.
            If None, it is automatically determined by calling `fun` on the first element.
        output_shapes: A list of array shapes, describing each output numpy array from `fun`.
            If None, it is automatically determined by calling `fun` on the first element.
        shuffle_before_each_epoch: Shuffle the input elements before each epoch. Converts
            `iterable` to a list internally.
        extra_args: Extra arguments, in addition to an element from `iterable`, passed to
            `fun` at each call.
        n_workers: Number of worker processes used for parallelism.
        n_epochs: Number of times to iterate over `iterable`. If None, iterate indefinitely.

    Returns:
        tf.data.Dataset based on the arrays returned by `fun`.
    """
    extra_args = extra_args or []

    # Automatically determine the output tensor types and shapes by calling the function on
    # the first element.
    first_elem, iterable = peek(iterable)
    if output_types is None or output_shapes is None:
        sample_output = fun(first_elem, *extra_args)
        output_shapes, output_types = get_shapes_and_tf_dtypes(sample_output)

    pool = get_pool(n_workers)
    # The semaphore bounds the number of items in flight, so the producer cannot run
    # arbitrarily far ahead of the consumer.
    semaphore = threading.Semaphore(256)
    q = queue.Queue()

    if n_epochs is None:
        epoch_counter = itertools.count()
    else:
        epoch_counter = range(n_epochs)

    if shuffle_before_each_epoch:
        iterable = list(iterable)

    def producer():
        for _ in epoch_counter:
            if shuffle_before_each_epoch:
                random.shuffle(iterable)
            for item in iterable:
                semaphore.acquire()
                pool.apply_async(fun, (item, *extra_args), callback=q.put)
        q.put(None)

    producer_thread = threading.Thread(target=producer, daemon=True)
    producer_thread.start()

    def consumer():
        while True:
            result = q.get()
            if result is None:
                return
            else:
                semaphore.release()
                yield tuple(result)

    return tf.data.Dataset.from_generator(consumer, output_types, output_shapes)


def example_use():
    # [...]
    dataset = parallel_map_as_tf_dataset(
        load_fn, examples, shuffle_before_each_epoch=True,
        extra_args=('bla1', 'bla2'), n_workers=n_workers, n_epochs=n_epochs)

    # Conceptual equivalent of the above:
    # data = []
    # for item in examples:
    #     data.append(load_fn(item, 'bla1', 'bla2'))

    is_training = True
    dataset = dataset.batch(batch_size, drop_remainder=is_training)
    dataset = dataset.apply(tf.data.experimental.prefetch_to_device('/gpu:0', 3))
    iterator = dataset.make_one_shot_iterator()
    batch_tensors = iterator.get_next()
    # [...]


def peek(iterable):
    iterator = iter(iterable)
    head = next(iterator)
    return head, itertools.chain([head], iterator)


def get_shapes_and_tf_dtypes(thing):
    if not isinstance(thing, (list, tuple)):
        thing = (thing,)
    arrays = [np.asanyarray(a) for a in thing]
    tf_types = [tf.as_dtype(a.dtype) for a in arrays]
    shapes = [tf.TensorShape(a.shape) for a in arrays]
    return tuple(shapes), tuple(tf_types)


_pool = None


def get_pool(n_workers_if_uninitialized):
    global _pool
    if _pool is None:
        ctx = multiprocessing.get_context('spawn')
        # It is important to use 'spawn' here: 'fork' would mean the whole parent memory is
        # (lazily) inherited, and due to copy-on-write semantics it gets duplicated as soon as
        # the parent changes anything.
        _pool = ctx.Pool(n_workers_if_uninitialized, initializer=init_worker_process)
    return _pool


def init_worker_process():
    terminate_on_parent_death()
    # Ignore Ctrl+C in the workers; the parent process handles interruption.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    seed = generate_seed()
    np.random.seed(seed)
    random.seed(seed)


def generate_seed():
    # Derive a per-process seed from the worker's PID and the current time.
    pid = os.getpid()
    s = int(time.time())
    return abs(((s * 181) * ((pid - 83) * 359)) % 104729)


def terminate_on_parent_death():
    # Ask the Linux kernel to send SIGTERM to this worker if the parent process dies.
    prctl = ctypes.CDLL("libc.so.6").prctl
    PR_SET_PDEATHSIG = 1
    prctl(PR_SET_PDEATHSIG, signal.SIGTERM)

ljn917 commented Apr 18, 2020

Hi, first of all thank you very much for your code. When I tried to run it, I found that the size of the returned dataset is usually less than the size of the input dataset. This is because when the producer puts None on the queue at the end, it is very likely that some workers have not finished yet. The easiest fix is to wait until semaphore._value == 256 (requires Python >= 3.6) before sending None.
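
For illustration, a sketch of that fix (not part of the original gist) could look like the producer below, reusing the names defined inside parallel_map_as_tf_dataset above; note that semaphore._value is a CPython implementation detail.

    def producer():
        for _ in epoch_counter:
            if shuffle_before_each_epoch:
                random.shuffle(iterable)
            for item in iterable:
                semaphore.acquire()
                pool.apply_async(fun, (item, *extra_args), callback=q.put)
        # Wait until the consumer has released the semaphore back to its initial value (256),
        # i.e. every scheduled result has been taken off the queue, before ending the stream.
        while semaphore._value != 256:  # relies on a CPython implementation detail
            time.sleep(0.1)
        q.put(None)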

In addition, could you please add a license to your code so that I can comfortably integrate it into my project? An MIT/BSD-type license would be much appreciated.

Thank you for your work again.

isarandi (Author) commented

Yes, I use a newer version of this code with several enhancements. I'll upload it soon.


ljn917 commented Apr 22, 2020

https://gist.github.com/ljn917/e3422fc8803590691c718262acfef1dc

I uploaded my modifications to the gist above. I also added an MIT license to it. If you don't like the license, please feel free to let me know.

isarandi (Author) commented Apr 22, 2020

I published my more complicated version (https://gist.github.com/isarandi/fb65138c66fa61218e0bce827cb30127) that can handle deterministic shuffling and augmentation, plus resuming correctly from an arbitrary point in the sequence (helps to seamlessly resume from a training checkpoint).
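
For illustration only (this is not the code from that gist, and all names below are hypothetical): deterministic, resumable shuffling can be sketched by seeding each epoch's shuffle with a fixed base seed plus the epoch index, and fast-forwarding past the already-consumed examples when resuming.

    import random

    def deterministic_epoch_order(examples, epoch, base_seed=42):
        """Return the example order for a given epoch, reproducibly (hypothetical sketch)."""
        order = list(examples)
        random.Random(base_seed + epoch).shuffle(order)
        return order

    def resume_iteration(examples, global_step, base_seed=42):
        """Yield examples starting from an arbitrary global step (hypothetical sketch)."""
        n = len(examples)
        epoch, index_in_epoch = divmod(global_step, n)
        while True:
            order = deterministic_epoch_order(examples, epoch, base_seed)
            for item in order[index_in_epoch:]:
                yield item
            epoch += 1
            index_in_epoch = 0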

I'm fine with the MIT License in this case, but please remember that it is not appropriate to arbitrarily add licenses to other people's code; always seek their approval first.

But anyway, I'm really glad that it's useful for you; perhaps you can check whether this new version works better for you.


ljn917 commented Apr 22, 2020 via email
