Skip to content

Instantly share code, notes, and snippets.

@YoelShoshan
Created September 27, 2017 06:29
Show Gist options
  • Save YoelShoshan/e8e3a7116e342d9daf6a4c324c015bbb to your computer and use it in GitHub Desktop.
Save YoelShoshan/e8e3a7116e342d9daf6a4c324c015bbb to your computer and use it in GitHub Desktop.
def make_parallel(fn, num_gpus, inputs_to_be_splitted, inputs_to_be_passed_as_is=None):
    '''
    Converts a function into a parallel version of it, using simple synchronized
    data parallelism across GPUs.

    :param fn: a tensorflow graph-building function; when num_gpus > 1 it must
        return a sequence of tensors (one accumulator list is made per output)
    :param num_gpus: number of gpus to split the work across
    :param inputs_to_be_splitted: a dict of tensorflow variables/placeholders that
        will be split along axis 0 before being passed to fn; each one's static
        axis-0 size must be evenly divisible by num_gpus
    :param inputs_to_be_passed_as_is: an optional dict of variables that will be
        passed unchanged to fn on every tower
    :return: for num_gpus == 1, fn's result as-is; otherwise a tuple of tensors,
        each the axis-0 concatenation of the corresponding per-tower outputs
    :raises Exception: if a split input's axis-0 size is not divisible by num_gpus
    '''
    if num_gpus == 1:
        # Single-GPU fast path: no splitting, just merge the two kwarg dicts.
        all_inputs = copy.copy(inputs_to_be_splitted)
        if inputs_to_be_passed_as_is is not None:
            # BUG FIX: dicts have no .extend(); use .update() to merge kwargs
            # (matches what the multi-GPU path below already does).
            all_inputs.update(inputs_to_be_passed_as_is)
        return fn(**all_inputs)

    # Split each input into num_gpus equal chunks along the batch (0) axis.
    in_splits = {}
    for k, v in inputs_to_be_splitted.items():
        curr_input_shape = v.shape.as_list()  # TODO: consider supporting dynamic shapes
        if curr_input_shape[0] % num_gpus != 0:
            raise Exception('param "{}" with 0 axis size = {} cannot be evenly split between the {} gpus.'.format(
                k, curr_input_shape[0], num_gpus))
        in_splits[k] = tf.split(v, num_gpus)

    out_split = None
    for i in range(num_gpus):
        with tf.device(tf.DeviceSpec(device_type="GPU", device_index=i)):
            # Reuse variables on every tower after the first so weights are shared.
            with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
                # Combine this tower's inputs: its split chunks plus the shared,
                # unsplit ones.
                all_inputs = {k: v[i] for k, v in in_splits.items()}
                if inputs_to_be_passed_as_is is not None:
                    all_inputs.update(inputs_to_be_passed_as_is)
                curr_outputs = fn(**all_inputs)
                if out_split is None:
                    # Lazily create one accumulator list per output of fn.
                    out_split = [[] for _ in range(len(curr_outputs))]
                for o_idx, out in enumerate(curr_outputs):
                    out_split[o_idx].append(out)

    # Stitch each output back together along the batch (0) axis.
    concatanated_outputs = [tf.concat(out, axis=0) for out in out_split]
    return tuple(concatanated_outputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment