Last active
April 22, 2020 19:18
-
-
Save isarandi/a72b3e5c1b1d3e40eb857a01d91926f9 to your computer and use it in GitHub Desktop.
Parallel input pipeline as a TensorFlow Dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def parallel_map_as_tf_dataset(
        fun, iterable, *, output_types=None, output_shapes=None, shuffle_before_each_epoch=False,
        extra_args=None, n_workers=10, n_epochs=None):
    """Maps `fun` to each element of `iterable` and wraps the resulting sequence
    as a TF Dataset. Elements are processed by parallel workers using multiprocessing.

    Args:
        fun: A function that takes an element from `iterable` plus `extra_args` and returns a
            sequence of numpy arrays. Must be picklable, since it runs in worker processes
            via `pool.apply_async`.
        iterable: An iterable holding the input objects, which can be any Python objects,
            not just numpy arrays.
        output_types: A list of types, describing each output numpy array from `fun`.
            If None, then it is automatically determined by calling `fun` on the first element.
        output_shapes: A list of array shapes, describing each output numpy array from `fun`.
            If None, then it is automatically determined by calling `fun` on the first element.
        shuffle_before_each_epoch: Shuffle the input elements before each epoch. Converts
            `iterable` to a list internally.
        extra_args: extra arguments in addition to an element from `iterable`,
            given to `fun` at each call.
        n_workers: Number of worker processes for parallelism.
        n_epochs: Number of times to iterate over the `iterable`; None means iterate forever.

    Returns:
        tf.data.Dataset based on the arrays returned by `fun`.
    """
    extra_args = extra_args or []

    # Automatically determine the output tensor types and shapes by calling the function on
    # the first element. `peek` returns an equivalent iterable that still yields the first
    # element, so it is not consumed.
    first_elem, iterable = peek(iterable)
    if output_types is None or output_shapes is None:
        sample_output = fun(first_elem, *extra_args)
        output_shapes, output_types = get_shapes_and_tf_dtypes(sample_output)

    pool = get_pool(n_workers)
    # Backpressure: at most 256 results may be in flight (submitted but not yet consumed),
    # so the producer cannot grow the result queue without bound.
    semaphore = threading.Semaphore(256)
    q = queue.Queue()

    # n_epochs=None means loop forever; itertools.count() is an endless counter.
    if n_epochs is None:
        epoch_counter = itertools.count()
    else:
        epoch_counter = range(n_epochs)

    if shuffle_before_each_epoch:
        # random.shuffle needs a mutable sequence, so materialize the iterable once.
        iterable = list(iterable)

    def producer():
        # Runs in a daemon thread: submits one async task per element per epoch.
        # Results arrive on `q` via the pool's callback thread, so the order of
        # results is not guaranteed to match the input order.
        for _ in epoch_counter:
            if shuffle_before_each_epoch:
                random.shuffle(iterable)
            for item in iterable:
                semaphore.acquire()  # block while 256 results are already pending
                pool.apply_async(fun, (item, *extra_args), callback=q.put)
        # Sentinel telling the consumer that no further results will arrive.
        # NOTE(review): this is enqueued right after *submitting* the last tasks, not
        # after they complete, so for finite n_epochs the last in-flight results may be
        # dropped; also apply_async has no error_callback, so a raising `fun` silently
        # leaks a semaphore slot — confirm `fun` is expected to never raise.
        q.put(None)

    producer_thread = threading.Thread(target=producer, daemon=True)
    producer_thread.start()

    def consumer():
        # Generator feeding tf.data: yields results until the producer's sentinel.
        while True:
            result = q.get()
            if result is None:
                return
            else:
                semaphore.release()  # free one backpressure slot for the producer
                yield tuple(result)

    # Positional (generator, output_types, output_shapes) signature of
    # tf.data.Dataset.from_generator.
    return tf.data.Dataset.from_generator(consumer, output_types, output_shapes)
def example_use():
    """Sketch of how to consume the parallel dataset (TF 1.x-style iterator API).

    `load_fn`, `examples`, `batch_size`, `n_workers` and `n_epochs` are placeholders
    expected to be defined by the surrounding code (see the elided `# [...]` parts).
    """
    # [...]
    dataset = parallel_map_as_tf_dataset(
        load_fn, examples, shuffle_before_each_epoch=True,
        extra_args=('bla1', 'bla2'), n_workers=n_workers, n_epochs=n_epochs)

    # Conceptual equivalent of the above:
    # data = []
    # for item in examples:
    #     data.append(load_fn(item, 'bla1', 'bla2'))

    is_training = True
    dataset = dataset.batch(batch_size, drop_remainder=is_training)
    # Keep up to 3 prefetched batches resident on the GPU.
    dataset = dataset.apply(tf.data.experimental.prefetch_to_device('/gpu:0', 3))
    # NOTE(review): make_one_shot_iterator is TF 1.x API; TF 2 iterates datasets directly.
    iterator = dataset.make_one_shot_iterator()
    batch_tensors = iterator.get_next()
    # [...]
def peek(iterable):
    """Return the first item of `iterable` and an equivalent iterable.

    The returned iterable still yields every item, including the first one,
    so the caller loses nothing by peeking. Raises StopIteration if the
    iterable is empty.
    """
    it = iter(iterable)
    first = next(it)
    restored = itertools.chain((first,), it)
    return first, restored
def get_shapes_and_tf_dtypes(thing):
    """Return (shapes, dtypes) for `thing` as TF objects.

    `thing` may be a single array-like or a list/tuple of them; a single item
    is treated as a one-element tuple. Each item is coerced with np.asanyarray,
    and its shape/dtype are wrapped as tf.TensorShape / tf.DType.
    """
    if not isinstance(thing, (list, tuple)):
        thing = (thing,)
    arrays = tuple(np.asanyarray(item) for item in thing)
    shapes = tuple(tf.TensorShape(arr.shape) for arr in arrays)
    dtypes = tuple(tf.as_dtype(arr.dtype) for arr in arrays)
    return shapes, dtypes
# Lazily created, process-wide worker pool (singleton).
_pool = None


def get_pool(n_workers_if_uninitialized):
    """Return the shared worker pool, creating it on first use.

    `n_workers_if_uninitialized` only takes effect on the very first call;
    later calls return the existing pool regardless of the argument.

    The pool uses the 'spawn' start method on purpose: 'fork' would lazily
    copy the parent's entire memory, and copy-on-write semantics would then
    duplicate pages whenever the parent changes anything.
    """
    global _pool
    if _pool is not None:
        return _pool
    context = multiprocessing.get_context('spawn')
    _pool = context.Pool(n_workers_if_uninitialized, initializer=init_worker_process)
    return _pool
def init_worker_process():
    """Initializer run once in each freshly spawned worker process."""
    terminate_on_parent_death()
    # Let only the parent react to Ctrl+C; workers ignore SIGINT themselves.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    # Re-seed the RNGs per process so spawned workers do not all share the
    # same pseudo-random stream.
    worker_seed = generate_seed()
    np.random.seed(worker_seed)
    random.seed(worker_seed)
def generate_seed():
    """Derive a pseudo-random seed in [0, 104729) from the PID and wall-clock time.

    Mixes the current integer timestamp with the process id so that workers
    started at the same moment still get distinct seeds, then reduces modulo
    the prime 104729.
    """
    now = int(time.time())
    pid = os.getpid()
    mixed = (now * 181) * ((pid - 83) * 359)
    return abs(mixed % 104729)
def terminate_on_parent_death():
    """Ask the kernel to SIGTERM this process when its parent dies.

    Uses the prctl(PR_SET_PDEATHSIG, SIGTERM) syscall through ctypes, so this
    is Linux-only (requires libc.so.6 to be loadable).
    """
    PR_SET_PDEATHSIG = 1  # constant from <sys/prctl.h>
    libc = ctypes.CDLL("libc.so.6")
    libc.prctl(PR_SET_PDEATHSIG, signal.SIGTERM)
I published my more complicated version (https://gist.github.com/isarandi/fb65138c66fa61218e0bce827cb30127) that can handle deterministic shuffling and augmentation, plus resuming correctly from an arbitrary point in the sequence (helps to seamlessly resume from a training checkpoint).
I'm fine with the MIT License in this case, but please remember that it is not appropriate to arbitrarily add licenses to other people's code, always seek their approval first.
But anyway, I'm really glad that it's useful for you, perhaps you can check out if this new version works for you better.
Thank you very much for your updated version and for your kindness regarding my
licensing concerns. I really appreciate your help.
…On Wed, Apr 22, 2020, 3:10 PM István Sárándi ***@***.***> wrote:
***@***.**** commented on this gist.
------------------------------
I published my more complicated version (
https://gist.github.com/isarandi/fb65138c66fa61218e0bce827cb30127) that
can handle deterministic shuffling and augmentation, plus resuming
correctly from an arbitrary point in the sequence (helps to seamlessly
resume from a training checkpoint).
I'm fine with the MIT License in this case, but please remember that it is
not appropriate to arbitrarily add licenses to other people's code, always
seek their approval first.
But anyway, I'm really glad that it's useful for you, perhaps you can
check out if this new version works for you better.
—
You are receiving this because you commented.
Reply to this email directly, view it on GitHub
<https://gist.github.com/a72b3e5c1b1d3e40eb857a01d91926f9#gistcomment-3264926>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AALDTHQZMDTDLHQUX3ZUJJDRN46JNANCNFSM4MLERCNQ>
.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://gist.github.com/ljn917/e3422fc8803590691c718262acfef1dc
I uploaded my modifications to the gist above. I also added an MIT license to it. If you don't like the license, please feel free to let me know.