Created
June 16, 2017 04:28
-
-
Save ypwhs/08b483156228e8f7d7ef186e978a21e3 to your computer and use it in GitHub Desktop.
Keras 多 GPU 同步训练
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers.merge import Concatenate
from keras.layers.core import Lambda
from keras.models import Model
import tensorflow as tf
def make_parallel(model, gpu_count):
    """Replicate ``model`` across ``gpu_count`` GPUs with data parallelism.

    Each input of ``model`` is split along the batch axis (axis 0) into
    ``gpu_count`` slices. Slice ``i`` is fed to a call of ``model`` placed on
    ``/gpu:i`` inside name scope ``tower_i``. The per-tower outputs are then
    concatenated back together along the batch axis on ``/cpu:0``. Every
    tower calls the same ``model`` object, so the towers share its weights.

    To persist the result, save the original ``model`` (not the returned
    parallel model). After loading it, call ``make_parallel`` again.

    Args:
        model: A Keras ``Model`` with one or more inputs/outputs.
        gpu_count: Number of GPUs (towers) to split each batch across.
            Must be >= 1.

    Returns:
        A new Keras ``Model`` with the same inputs as ``model``. Its outputs
        are the tower outputs merged back into full batches.

    Raises:
        ValueError: If ``gpu_count`` is less than 1.
    """
    if gpu_count < 1:
        raise ValueError('gpu_count must be >= 1, got %r' % (gpu_count,))

    def get_slice(data, idx, parts):
        """Return the ``idx``-th of ``parts`` batch-axis slices of ``data``.

        The last slice also takes the remainder rows, so no samples are
        dropped when the batch size is not divisible by ``parts``.
        """
        # Imported locally on purpose: this function is wrapped in a Keras
        # Lambda. When a saved model is reloaded, the module-level ``tf`` is
        # not visible to it (NameError: global name 'tf' is not defined).
        import tensorflow as tf

        shape = tf.shape(data)
        batch = shape[:1]
        step = batch // parts
        # Only the batch axis is offset; every other axis starts at 0.
        start = tf.concat([step * idx, shape[1:] * 0], axis=0)
        if idx == parts - 1:
            # Last tower: take everything that is left of the batch.
            size = tf.concat([batch - step * idx, shape[1:]], axis=0)
        else:
            size = tf.concat([step, shape[1:]], axis=0)
        return tf.slice(data, start, size)

    # outputs_all[k] collects output k of the model from every tower.
    outputs_all = [[] for _ in model.outputs]

    # Place a copy of the model on each GPU, each getting a slice of the batch.
    for i in range(gpu_count):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                # Slice each input into the piece processed on this GPU.
                inputs = []
                for x in model.inputs:
                    input_shape = tuple(x.get_shape().as_list())[1:]
                    slice_n = Lambda(get_slice,
                                     output_shape=input_shape,
                                     arguments={'idx': i, 'parts': gpu_count})(x)
                    inputs.append(slice_n)

                outputs = model(inputs)
                if not isinstance(outputs, list):
                    outputs = [outputs]

                # Save all the outputs for merging back together later.
                for collected, output in zip(outputs_all, outputs):
                    collected.append(output)

    # Merge the tower outputs back into full batches on the CPU.
    with tf.device('/cpu:0'):
        merged = []
        for outputs in outputs_all:
            if len(outputs) == 1:
                # Keras' Concatenate needs at least two inputs; with a single
                # tower there is nothing to merge.
                merged.append(outputs[0])
            else:
                merged.append(Concatenate(axis=0)(outputs))

    return Model(model.inputs, merged)
保存模型时需要保存原始的单个模型（而不是 make_parallel 返回的并行模型），载入之后再调用 make_parallel，这样就不会有问题了。@ChaoXiWhite
举个例子：

训练：
```python
model = Model(...)
model_parallel = make_parallel(model, 4)
model_parallel.fit(...)
model.save('model.h5')  # 保存单个模型，而不是 model_parallel
```

载入 + 预测：
```python
model = load_model('model.h5')
model_parallel = make_parallel(model, 4)
model_parallel.predict(...)
```
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
There is an error, `NameError: global name 'tf' is not defined`, on line 9 when I load a model from an .h5 file. How can I solve this problem? I changed the loading call to `model = keras.models.load_model('model.hdf5', custom_objects={"tf": tf})` and it worked, but it only used one GPU. Thank you very much.