Running TensorFlow Sessions inside sub-processes

In some applications of machine learning/TensorFlow, it is desirable to start multiple processes, and have separate training procedures running concurrently in each of those processes. A useful Python method for achieving this is the multiprocessing.pool.Pool.map() method (or the equivalent starmap() method when the target function takes multiple arguments; see the section "Process Pools" from the description of the multiprocessing module in the Python Library Reference).

The Pool.map() method takes a target function and a list (or, more generally, an iterable) of arguments, and returns a list of the results of the function evaluated on each member of the argument list. This is similar to the built-in Python function map(), except that the function evaluations start immediately and run concurrently in sub-processes (assuming there are multiple workers in the pool), and the call blocks until all of the results are ready. To give a concrete example, if f(x) returns x**2, and arg_list = [1, 2, 3], then pool.map(f, arg_list) would return [1, 4, 9], as shown in the following code snippet:

from multiprocessing.pool import Pool

def f(x):
    return x**2

if __name__ == "__main__":
    arg_list = [1, 2, 3]
    with Pool() as p:
        print(p.map(f, arg_list))

Output:

[1, 4, 9]

(Note that it is important to call p.map from inside an if __name__ == "__main__": block. When a sub-process has to import the __main__ module (for example with the spawn start method, which is the default on Windows), any code outside of such a block runs during this import, so an unguarded Pool() would make each sub-process try to start sub-processes of its own. See more information in the Python Library Reference description of Using a pool of workers).
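For target functions that take more than one argument, the starmap() method mentioned above works the same way, except that each element of the argument list is a tuple which is unpacked into positional arguments. The following is a minimal sketch; the function name power and the choice of two workers are hypothetical and only for illustration:

```python
from multiprocessing.pool import Pool

def power(base, exponent):
    # Hypothetical two-argument target function, used only for illustration
    return base ** exponent

if __name__ == "__main__":
    # Each tuple is unpacked into the positional arguments of power()
    arg_list = [(2, 1), (2, 2), (2, 3)]
    with Pool(processes=2) as p:
        print(p.starmap(power, arg_list))  # prints [2, 4, 8]
```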

While it is straightforward to pass ordinary class instances as arguments to the map() function, this fails when trying to pass a tf.Session, tf.Tensor, or tf.Operation (or a class instance containing such an object as a property), because arguments have to be pickled in order to be sent to a sub-process, and these objects cannot be pickled. Passing such an object as an argument to the map() function throws a TypeError: can't pickle _thread.RLock objects.
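The failure can be reproduced without TensorFlow. The sketch below uses a hypothetical stand-in class, HoldsLock, which owns a threading.RLock. That the TensorFlow objects fail for this reason is an assumption based on the error message naming _thread.RLock. The helper is_picklable is also hypothetical, and simply reports whether pickle.dumps() succeeds:

```python
import pickle
import threading

class HoldsLock:
    # Hypothetical stand-in for an object that owns a lock
    # (assumed here to resemble what makes a tf.Session unpicklable)
    def __init__(self):
        self.lock = threading.RLock()

def is_picklable(obj):
    # Return True if obj can be serialised with pickle, otherwise False
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False

if __name__ == "__main__":
    print(is_picklable(("ribbet", 4, 2)))  # plain arguments: True
    print(is_picklable(HoldsLock()))       # object holding an RLock: False
```

A check like this can be useful for testing an argument list before handing it to a pool.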

So how do you run TensorFlow training sessions concurrently in multiple sub-processes without passing any Session instances or class instances containing a tf.Tensor or tf.Operation as a property? The trick is to pass as arguments only the parameters which are needed to instantiate such an object, and then instantiate both the object and the tf.Session from inside the sub-process's target function, as shown below:

import tensorflow as tf
from multiprocessing.pool import Pool

class Frog():
    def __init__(self, noise, num_legs=4):
        self.noise = noise
        self.num_legs = num_legs
        self.some_var = tf.Variable(42)

    def make_noise(self, num_times):
        print(self.noise * num_times)


def check_legs(noise, num_legs, num_times):
    frog = Frog(noise, num_legs)
    if frog.num_legs < 4:
        print("Help, I only got {} legs".format(frog.num_legs))
    
    frog.make_noise(num_times)

    with tf.Session() as sess:
        sess.run(tf.initializers.global_variables())
        print(sess.run(frog.some_var))

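if __name__ == "__main__":
    num_legs_list = range(2, 6)
    num_times_list = range(1, 5)
    # Only picklable values (a string and two ints) are sent to the workers
    args = [(
        'ribbet', num_legs, num_times
    ) for num_legs, num_times in zip(num_legs_list, num_times_list)]
    print("Warming up the pool...")
    # No tf.Session is needed in the parent; each worker creates its own
    with Pool() as p:
        print("Assigning workers...")
        p.starmap(check_legs, args)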

Results (for one particular run):

Warming up the pool...
Assigning workers...
Help, I only got 2 legs
ribbet
Help, I only got 3 legs
ribbetribbet
ribbetribbetribbet
ribbetribbetribbetribbet
42
42
42
42

Note that args contains no tf.Session, tf.Tensor, or tf.Operation objects, nor any class instances containing such objects as properties; if it did, the Pool.starmap() call would throw a TypeError: can't pickle _thread.RLock objects.
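The same restriction applies to whatever the target function returns, since results are pickled on their way back to the parent process. The following TensorFlow-free sketch shows the overall pattern under stated assumptions. Model is a hypothetical stand-in for a class owning unpicklable state (simulated with a threading.RLock), its train() method is a toy gradient descent rather than real training, and train_in_subprocess is a hypothetical target function that returns a plain float:

```python
import threading
from multiprocessing.pool import Pool

class Model:
    # Hypothetical stand-in for a class that owns unpicklable TensorFlow state
    def __init__(self, learning_rate, num_steps):
        self.learning_rate = learning_rate
        self.num_steps = num_steps
        self._lock = threading.RLock()  # makes instances unpicklable

    def train(self):
        # Toy placeholder for a training loop: gradient descent on x**2
        x = 1.0
        for _ in range(self.num_steps):
            x -= self.learning_rate * 2 * x
        return x ** 2

def train_in_subprocess(learning_rate, num_steps):
    # Build the unpicklable object inside the worker, from picklable parameters
    model = Model(learning_rate, num_steps)
    # Return a plain float, which can be pickled and sent back to the parent
    return model.train()

if __name__ == "__main__":
    args = [(0.1, 10), (0.25, 10), (0.5, 10)]
    with Pool(processes=2) as p:
        losses = p.starmap(train_in_subprocess, args)
    print(losses)
```

In a real TensorFlow version of this, the value returned would be the result of sess.run() (a Python or NumPy value) rather than a tf.Tensor.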
