import copy

import tensorflow as tf


def make_parallel(fn, num_gpus, inputs_to_be_splitted, inputs_to_be_passed_as_is=None):
    '''
    Converts a function into a parallel version of itself, using simple synchronized data parallelism.
    :param fn: a tensorflow function; it must return a tuple/list of tensors
    :param num_gpus: number of gpus
    :param inputs_to_be_splitted: a dict of the tensorflow variables/placeholders that will be split along axis 0 before being passed to fn
    :param inputs_to_be_passed_as_is: a dict of the variables that will be passed to fn unchanged
    :return: a tuple with each of fn's outputs concatenated along axis 0 across all gpus
    '''
    if 1 == num_gpus:
        # single gpu: no splitting needed, just forward all inputs to fn
        all_inputs = copy.copy(inputs_to_be_splitted)
        if inputs_to_be_passed_as_is is not None:
            all_inputs.update(inputs_to_be_passed_as_is)
        return fn(**all_inputs)
    # split each input tensor evenly along axis 0, one slice per gpu
    in_splits = {}
    for k, v in inputs_to_be_splitted.items():
        curr_input_shape = v.shape.as_list()  # TODO: consider using the dynamic shape instead
        if curr_input_shape[0] % num_gpus != 0:
            raise Exception('param "{}" with 0 axis size = {} cannot be evenly split between the {} gpus.'.format(
                k, curr_input_shape[0], num_gpus))
        in_splits[k] = tf.split(v, num_gpus)
    out_split = None
    for i in range(num_gpus):
        with tf.device(tf.DeviceSpec(device_type="GPU", device_index=i)):
            # reuse the variables created by the first tower in every subsequent tower
            with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
                # combine this tower's inputs, both the split ones and the ones kept as is, into a single dict
                all_inputs = {k: v[i] for k, v in in_splits.items()}
                if inputs_to_be_passed_as_is is not None:
                    all_inputs.update(inputs_to_be_passed_as_is)
                curr_outputs = fn(**all_inputs)
                if out_split is None:
                    # one accumulator list per output of fn
                    out_split = [[] for _ in range(len(curr_outputs))]
                for o_idx, out in enumerate(curr_outputs):
                    out_split[o_idx].append(out)
    # stitch each output back together across the towers, restoring the full batch dimension
    concatenated_outputs = [tf.concat(out, axis=0) for out in out_split]
    return tuple(concatenated_outputs)
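
# --- Usage sketch (illustrative; everything below is hypothetical, not part of the original gist) ---
# A minimal example of how make_parallel might be called in TF 1.x graph mode. The model
# function, the placeholder names, and all shapes are assumptions chosen for demonstration.
if __name__ == '__main__':

    def my_model(x, y, keep_prob):
        # build variables with tf.get_variable so the reuse flag inside make_parallel
        # shares them across towers
        w = tf.get_variable('w', shape=[100, 10], dtype=tf.float32)
        logits = tf.nn.dropout(tf.matmul(x, w), keep_prob)
        # return rank >= 1 tensors so the tf.concat(axis=0) in make_parallel is valid
        loss = tf.reduce_mean(tf.squared_difference(logits, y))
        return logits, tf.expand_dims(loss, 0)

    x = tf.placeholder(tf.float32, shape=[64, 100])  # axis-0 size 64 splits evenly over 2 gpus
    y = tf.placeholder(tf.float32, shape=[64, 10])
    keep_prob = tf.placeholder(tf.float32, shape=[])  # passed as-is, identical on every gpu

    # logits comes back as [64, 10]; losses comes back as [2], one entry per tower
    logits, losses = make_parallel(
        my_model, num_gpus=2,
        inputs_to_be_splitted={'x': x, 'y': y},
        inputs_to_be_passed_as_is={'keep_prob': keep_prob})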