# Copyright (c) 2020 István Sárándi <sarandi@vision.rwth-aachen.de>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import ctypes
import itertools
import logging
import multiprocessing
import os
import queue
import signal
import threading

import numpy as np
import tensorflow as tf


def example_use():
    # [...]
    dataset = parallel_map_as_tf_dataset(
        load_fn, examples, shuffle_before_each_epoch=True,
        extra_args=('bla1', 'bla2'), n_workers=n_workers,
        n_completed_items=10, n_total_items=100,
        rng=np.random.RandomState(42))

    # Conceptual equivalent of the above:
    # data = []
    # for item in examples:
    #     data.append(load_fn(item, 'bla1', 'bla2'))

    is_training = True
    dataset = dataset.batch(batch_size, drop_remainder=is_training)
    dataset = dataset.apply(tf.data.experimental.prefetch_to_device('/gpu:0', 3))
    iterator = dataset.make_one_shot_iterator()
    batch_tensors = iterator.get_next()
    # [...]
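

# For illustration only (an assumption, not part of the original gist): a minimal sketch of what
# `load_fn` above could look like. The key point is the signature that `parallel_map_as_tf_dataset`
# expects: the function receives one element of `examples`, then the values from `extra_args`, and
# finally a np.random.RandomState (named `rng` and passed last) that it can use for random data
# augmentation. The actual loading logic here is made up just to keep the sketch self-contained.
def load_fn(example_path, arg1, arg2, rng):
    # A real implementation would read and decode `example_path`; here we fabricate a dummy image.
    image = rng.uniform(size=(256, 256, 3)).astype(np.float32)
    if rng.rand() < 0.5:
        image = image[:, ::-1]  # random horizontal flip, driven by the supplied `rng`
    label = np.int32(0)
    return image, label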


def parallel_map_as_tf_dataset(
        fun, iterable, *, output_types=None, output_shapes=None, shuffle_before_each_epoch=False,
        extra_args=None, n_workers=10, rng=None, max_unconsumed=256,
        n_completed_items=0, n_total_items=None):
    """Maps `fun` to each element of `iterable` and wraps the resulting sequence as a TF Dataset.

    Elements are processed by parallel workers using multiprocessing.

    Special consideration is given to randomness to keep things deterministic. The `rng` argument
    is the main starting numpy.random.RandomState, from which the shuffling is derived. `fun` can
    also be made random through its last argument: it is always called with a
    numpy.random.RandomState object, which it can use for data augmentation or similar processing.
    The RandomState objects given to each `fun` call are all different and are derived
    deterministically from the main `rng`.

    Args:
        fun: A function that takes an element of `iterable`, the `extra_args` and a RandomState,
            and returns some numpy arrays.
        iterable: An iterable holding the inputs.
        output_types: A list of types, describing each output numpy array from `fun`.
            If None, it is automatically determined by calling `fun` on the first element.
        output_shapes: A list of array shapes, describing each output numpy array from `fun`.
            If None, it is automatically determined by calling `fun` on the first element.
        shuffle_before_each_epoch: Shuffle the input elements before each epoch. Converts
            `iterable` to a list internally.
        extra_args: Extra arguments given to `fun` at each call, in addition to an element of
            `iterable`.
        n_workers: Number of worker processes to use.
        rng: RandomState for shuffling and for randomizing `fun` through its last argument.
        max_unconsumed: Maximum number of items that may be under processing or in the finished
            buffer at any time. Limiting this bounds the memory usage when `fun` finishes much
            faster than the results can be consumed.
        n_completed_items: Number of items to skip at the beginning. This is intended to help
            restore from a checkpoint and resume the deterministic training process seamlessly.
        n_total_items: The number of items to process in total (including the completed ones).

    Returns:
        tf.data.Dataset based on the arrays returned by `fun`.
    """
    extra_args = extra_args or []

    # Automatically determine the output tensor types and shapes by calling the function on
    # the first element.
    first_elem, iterable = peek(iterable)
    iterable = list(iterable)
    if output_types is None or output_shapes is None:
        sample_output = fun(first_elem, *extra_args, rng=np.random.RandomState(0))
        output_shapes, output_types = get_shapes_and_tf_dtypes(sample_output)

    items = iterate_repeatedly(iterable, shuffle_before_each_epoch, new_rng(rng))

    # If we are restoring from a checkpoint and have already completed some training steps
    # towards that checkpoint, then we need to advance the RNG accordingly, to make the
    # resumption seamless.
    iter_rng = new_rng(rng)
    advance_rng(iter_rng, n_completed_items)
    logging.debug(f'n_total_items: {n_total_items}, n_completed_items: {n_completed_items}')
    items = itertools.islice(items, n_completed_items, n_total_items)

    if n_workers == 0:
        def gen():
            for item in items:
                yield fun(item, *extra_args, new_rng(iter_rng))
            logging.debug('ended')
    else:
        pool = get_pool(n_workers)
        gen = parallel_map_as_generator(
            fun, items, extra_args, pool, rng=iter_rng, max_unconsumed=max_unconsumed)

    return tf.data.Dataset.from_generator(gen, output_types, output_shapes)
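

# Sketch (an assumption, not part of the original gist) of how `n_completed_items` can be used
# when resuming training from a checkpoint: deriving it from the saved global step and the batch
# size makes the shuffling order and the per-item RandomStates line up with where the interrupted
# run stopped. `load_fn` and `examples` are placeholders as in `example_use` above.
def example_resume_use(global_step, batch_size):
    dataset = parallel_map_as_tf_dataset(
        load_fn, examples, shuffle_before_each_epoch=True,
        n_completed_items=global_step * batch_size,
        rng=np.random.RandomState(42))
    # [...]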


def parallel_map_as_generator(
        fun, items, extra_args, pool, max_unconsumed=256, rng=None):
    # The semaphore bounds how many items may be in flight (submitted to the pool but not yet
    # consumed), which keeps memory usage limited when the workers outpace the consumer.
    semaphore = threading.Semaphore(max_unconsumed)
    q = queue.Queue()
    end_of_sequence_marker = object()

    def producer():
        # Submits the work items to the process pool and puts the resulting futures on the queue.
        for i_item, item in enumerate(items):
            semaphore.acquire()
            q.put(pool.apply_async(fun, (item, *extra_args, new_rng(rng))))
        q.put(end_of_sequence_marker)

    def consumer():
        # Yields the results in submission order, releasing the semaphore as items are consumed.
        while True:
            future_or_end = q.get()
            if future_or_end is end_of_sequence_marker:
                return
            else:
                value = tuple(future_or_end.get())
                semaphore.release()
                yield value

    producer_thread = threading.Thread(target=producer, daemon=True)
    producer_thread.start()
    return consumer


def peek(iterable):
    """Returns the first element of `iterable` and an equivalent iterator that still yields it."""
    iterator = iter(iterable)
    head = next(iterator)
    return head, itertools.chain([head], iterator)


def get_shapes_and_tf_dtypes(thing):
    if not isinstance(thing, (list, tuple)):
        thing = (thing,)
    arrays = [np.asanyarray(a) for a in thing]
    tf_types = [tf.as_dtype(a.dtype) for a in arrays]
    shapes = [tf.TensorShape(a.shape) for a in arrays]
    return tuple(shapes), tuple(tf_types)
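
# Example (illustrative, assuming the hypothetical `load_fn` sketch above): for a sample output
# of (image, label) with a (256, 256, 3) float32 image and a scalar int32 label, this returns
#     shapes == (TensorShape([256, 256, 3]), TensorShape([]))
#     types == (tf.float32, tf.int32)
# which is the format tf.data.Dataset.from_generator expects for output_shapes and output_types.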


def iterate_repeatedly(seq, shuffle_before_each_epoch=False, rng=None):
    """Iterates over and yields the elements of `seq` indefinitely, epoch after epoch.

    If `shuffle_before_each_epoch` is True, the elements are put in a list and shuffled before
    every pass over the data, including the first.
    """
    if rng is None:
        rng = np.random.RandomState()

    # Create a (shallow) copy so shuffling only applies to the copy.
    seq = list(seq)
    for i_epoch in itertools.count():
        if shuffle_before_each_epoch:
            rng.shuffle(seq)
        yield from seq


def new_rng(rng):
    """Derives a new, deterministically seeded RandomState from `rng` (or a fresh one if None)."""
    if rng is not None:
        return np.random.RandomState(rng.randint(2 ** 32))
    else:
        return np.random.RandomState()


def advance_rng(rng, n_generated_ints):
    """Advances the state of `rng` by drawing `n_generated_ints` integers from it."""
    for _ in range(n_generated_ints):
        rng.randint(2)


_pool = None


def get_pool(n_workers_if_uninitialized):
    global _pool
    if _pool is None:
        # It is important to use 'spawn' here: with 'fork', the whole memory of the parent would
        # be (lazily) shared, and due to copy-on-write semantics it would get duplicated as soon
        # as the parent changes anything.
        ctx = multiprocessing.get_context('spawn')
        _pool = ctx.Pool(n_workers_if_uninitialized, initializer=init_worker_process)
    return _pool


def init_worker_process():
    # Keep each worker single-threaded for OpenMP-based libraries and let the main process
    # handle Ctrl+C (workers ignore SIGINT).
    os.environ['OMP_NUM_THREADS'] = '1'
    terminate_on_parent_death()
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def terminate_on_parent_death():
    # Linux-specific: ask the kernel to send SIGTERM to this process when its parent dies,
    # so workers do not linger if the main process crashes.
    prctl = ctypes.CDLL("libc.so.6").prctl
    PR_SET_PDEATHSIG = 1
    prctl(PR_SET_PDEATHSIG, signal.SIGTERM)
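

# Note (an assumption about usage, not stated in the original gist): because the pool uses the
# 'spawn' start method, worker processes re-import the main module and need to pickle `fun`.
# The mapping function (e.g. `load_fn`) should therefore be a top-level function, and the entry
# point of the script should be protected by the usual guard, roughly:
#
# if __name__ == '__main__':
#     main()  # `main` stands in for the actual training entry point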