Train on multiple GPUs simultaneously. Saved from https://www.jianshu.com/p/db0ba022936f
import tensorflow as tf
import keras
import keras.backend as K
import keras.layers as KL
from keras.optimizers import Adam


class ParallelModel(keras.models.Model):
    """Subclasses the standard Keras Model and adds multi-GPU support.
    It works by creating a copy of the model on each GPU. Then it slices
    the inputs and sends a slice to each copy of the model, and then
    merges the outputs together and applies the loss on the combined
    outputs.
    """

    def __init__(self, keras_model, gpu_count):
        """Class constructor.
        keras_model: The Keras model to parallelize
        gpu_count: Number of GPUs. Must be > 1
        """
        super(ParallelModel, self).__init__()  # Thanks to @greatken999 for fixing bugs
        self.inner_model = keras_model
        self.gpu_count = gpu_count
        merged_outputs = self.make_parallel()
        super(ParallelModel, self).__init__(inputs=self.inner_model.inputs,
                                            outputs=merged_outputs)

    def __getattribute__(self, attrname):
        """Redirect loading and saving methods to the inner model. That's where
        the weights are stored."""
        if 'load' in attrname or 'save' in attrname:
            return getattr(self.inner_model, attrname)
        return super(ParallelModel, self).__getattribute__(attrname)

    def summary(self, *args, **kwargs):
        """Override summary() to display summaries of both the wrapper
        and the inner model."""
        super(ParallelModel, self).summary(*args, **kwargs)
        self.inner_model.summary(*args, **kwargs)

    def make_parallel(self):
        """Creates a new wrapper model that consists of multiple replicas of
        the original model placed on different GPUs.
        """
        # Slice inputs on the CPU to avoid sending a copy of the full
        # inputs to all GPUs. Saves on bandwidth and memory.
        input_slices = {name: tf.split(x, self.gpu_count)
                        for name, x in zip(self.inner_model.input_names,
                                           self.inner_model.inputs)}

        output_names = self.inner_model.output_names
        outputs_all = []
        for i in range(len(self.inner_model.outputs)):
            outputs_all.append([])

        # Run the model call() on each GPU to place the ops there
        for i in range(self.gpu_count):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i):
                    # Run a slice of inputs through this replica
                    zipped_inputs = zip(self.inner_model.input_names,
                                        self.inner_model.inputs)
                    inputs = [
                        # Bind name and i as default arguments so each Lambda
                        # captures the current loop values rather than the
                        # last ones if the function is called again later.
                        KL.Lambda(lambda s, name=name, i=i: input_slices[name][i],
                                  output_shape=lambda s: (None,) + s[1:])(tensor)
                        for name, tensor in zipped_inputs]
                    # Create the model replica and get the outputs
                    outputs = self.inner_model(inputs)
                    if not isinstance(outputs, list):
                        outputs = [outputs]
                    # Save the outputs for merging back together later
                    for l, o in enumerate(outputs):
                        outputs_all[l].append(o)

        # Merge outputs on CPU
        with tf.device('/cpu:0'):
            merged = []
            for outputs, name in zip(outputs_all, output_names):
                # If outputs are scalars without dimensions, add a batch dim.
                def add_dim(tensor):
                    """Add a dimension to tensors that don't have any."""
                    if K.int_shape(tensor) == ():
                        return KL.Lambda(lambda t: K.reshape(t, [1, 1]))(tensor)
                    return tensor
                outputs = list(map(add_dim, outputs))

                # Concatenate the per-GPU outputs along the batch axis
                merged.append(KL.Concatenate(axis=0, name=name)(outputs))
        return merged
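
# Sanity check for the slicing step above (a minimal sketch, not part of
# the original gist; assumes TF 1.x graph mode like the rest of this file):
# tf.split divides the batch axis evenly across replicas, so the global
# batch size must be a multiple of gpu_count, which is why the usage
# below multiplies batch_size by GPU_COUNT.
_slices = tf.split(tf.zeros([12, 224, 224, 3]), 3)
assert len(_slices) == 3
assert _slices[0].shape.as_list() == [4, 224, 224, 3]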
# Usage
GPU_COUNT = 3  # use 3 GPUs simultaneously
model = keras.applications.densenet.DenseNet201()  # e.g., DenseNet-201
model = ParallelModel(model, GPU_COUNT)
model.compile(optimizer=Adam(lr=1e-5), loss='binary_crossentropy',
              metrics=['accuracy'])
model.fit(X_train, y_train,
          batch_size=batch_size * GPU_COUNT,
          epochs=nb_epoch, verbose=0, shuffle=True,
          validation_data=(X_valid, y_valid))
model.save_weights('/path/to/save/model.h5')
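
# Because __getattribute__ redirects every attribute whose name contains
# 'load' or 'save' to the inner model, the file written above holds the
# plain single-GPU weights. A sketch (reusing the placeholder path from
# above): restore them into a non-parallel model for single-GPU inference.
single_gpu_model = keras.applications.densenet.DenseNet201()
single_gpu_model.load_weights('/path/to/save/model.h5')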