Created
October 11, 2016 06:46
-
-
Save cerisara/d43e9374a3d2eb9d44487606e0e29966 to your computer and use it in GitHub Desktop.
multigpu.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Data Parallelization with multi-GPU over TensorFlow | |
Jonathan Laserson <jonilaserson@gmail.com> | |
9 oct. (Il y a 2 jours) | |
À Keras-users Se désabonner | |
Here is how to take an existing model and do data parallelization across multiple GPUs. | |
The assumption here is that the model receives a single input (but it should be easily adjustable for more).
To use it, just take any model and set model = to_multi_gpu(model).
model.fit and model.predict should work without any change. | |
import tensorflow as tf | |
from keras import backend as K | |
from keras.models import Model | |
from keras.layers import Input, merge | |
from keras.layers.core import Lambda | |
def slice_batch(x, n_gpus, part):
    """Return the `part`-th of `n_gpus` contiguous slices of x along the batch axis.

    x: tensor whose first axis is the batch dimension.
    n_gpus: number of equal slices to divide the batch into.
    part: index in [0, n_gpus) selecting which slice to return.

    The last slice (`part == n_gpus - 1`) runs to the end of the batch, so
    it absorbs any remainder when the batch size is not divisible by n_gpus.
    """
    sh = K.shape(x)
    # Floor division: under Python 3, `/` performs true division and would
    # produce a float (or float tensor), which is invalid as a slice index.
    # `//` is behavior-identical to the original under Python 2 ints.
    L = sh[0] // n_gpus
    if part == n_gpus - 1:
        return x[part * L:]
    return x[part * L:(part + 1) * L]
def to_multi_gpu(model, n_gpus=2):
    """Wrap `model` for data parallelism across `n_gpus` GPUs.

    Builds a new model with a single CPU-placed input; each GPU hosts one
    replica of `model` that processes its own contiguous slice of the batch
    (via `slice_batch`), and the per-GPU outputs are concatenated back
    together on the CPU along the batch axis. `fit` and `predict` on the
    returned model work without any change to calling code.
    """
    # Keep the shared input tensor on the CPU so the batch can be split out.
    with tf.device('/cpu:0'):
        batch_input = Input(model.input_shape[1:], name=model.input_names[0])

    gpu_outputs = []
    for gpu_id in range(n_gpus):
        with tf.device('/gpu:' + str(gpu_id)):
            # Each replica sees only its own slice of the incoming batch.
            batch_slice = Lambda(
                slice_batch,
                lambda shape: shape,
                arguments={'n_gpus': n_gpus, 'part': gpu_id},
            )(batch_input)
            gpu_outputs.append(model(batch_slice))

    # Gather the per-GPU results back on the CPU along the batch axis.
    with tf.device('/cpu:0'):
        merged = merge(gpu_outputs, mode='concat', concat_axis=0)
    return Model(input=[batch_input], output=merged)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment