Skip to content

Instantly share code, notes, and snippets.

@skywalkerisnull
Created July 1, 2019 04:47
Show Gist options
  • Save skywalkerisnull/cebc1fc2b00fa76da92173d2baa21714 to your computer and use it in GitHub Desktop.
Save skywalkerisnull/cebc1fc2b00fa76da92173d2baa21714 to your computer and use it in GitHub Desktop.
Enables multi-GPU training with Keras 2.2.4.
"""
Mask R-CNN
Multi-GPU Support for Keras.
Copyright (c) 2017 Matterport, Inc.
Licensed under the MIT License (see LICENSE for details)
Written by Waleed Abdulla
Ideas and a small code snippets from these sources:
https://github.com/fchollet/keras/issues/2436
https://medium.com/@kuza55/transparent-multi-gpu-training-on-tensorflow-with-keras-8b0016fd9012
https://github.com/avolkov1/keras_experiments/blob/master/keras_exp/multigpu/
https://github.com/fchollet/keras/blob/master/keras/utils/training_utils.py
"""
import tensorflow as tf
import keras.backend as K
import keras.layers as KL
import keras.models as KM
class ParallelModel(KM.Model):
    """Subclasses the standard Keras Model and adds multi-GPU support.

    It works by creating a copy of the model on each GPU. Then it slices
    the inputs and sends a slice to each copy of the model, and then
    merges the outputs together and applies the loss on the combined
    outputs.

    Note: inputs are divided with ``tf.split(x, gpu_count)``, which requires
    the batch dimension of every batch fed to this model (training and
    validation) to be evenly divisible by ``gpu_count``.
    """

    def __init__(self, keras_model, gpu_count):
        """Class constructor.

        keras_model: The Keras model to parallelize.
        gpu_count: Number of GPUs. Must be > 1.

        Raises:
            ValueError: if gpu_count is less than 2.
        """
        if gpu_count < 2:
            # With a single replica the Concatenate merge below cannot be
            # built, so fail early with a clear message instead.
            raise ValueError(
                "ParallelModel requires gpu_count > 1, got %r" % (gpu_count,))
        # Keras 2.2.x tracks Layer/Model-valued attributes in
        # Network.__setattr__ and raises RuntimeError ("It looks like you are
        # subclassing `Model` and you forgot to call
        # `super(YourClass, self).__init__()`") if one is assigned before the
        # base initializer ran. Run the base initializer first so that
        # `self.inner_model = ...` below is allowed; the second __init__ call
        # then builds the actual graph network from inputs/outputs.
        super(ParallelModel, self).__init__()
        self.inner_model = keras_model
        self.gpu_count = gpu_count
        merged_outputs = self.make_parallel()
        super(ParallelModel, self).__init__(inputs=self.inner_model.inputs,
                                            outputs=merged_outputs)

    def __getattribute__(self, attrname):
        """Redirect loading and saving methods to the inner model. That's where
        the weights are stored.

        Any attribute whose name contains the substring 'load' or 'save'
        (e.g. ``save_weights``, ``load_weights``) is looked up on
        ``self.inner_model``; everything else resolves normally.
        """
        if 'load' in attrname or 'save' in attrname:
            return getattr(self.inner_model, attrname)
        return super(ParallelModel, self).__getattribute__(attrname)

    def summary(self, *args, **kwargs):
        """Override summary() to display summaries of both, the wrapper
        and inner models. Arguments are forwarded unchanged to both calls."""
        super(ParallelModel, self).summary(*args, **kwargs)
        self.inner_model.summary(*args, **kwargs)

    def make_parallel(self):
        """Creates a new wrapper model that consists of multiple replicas of
        the original model placed on different GPUs.

        Returns:
            A list with one merged tensor per output of the inner model, in
            the order of ``self.inner_model.outputs``. Outputs with a batch
            dimension are concatenated along axis 0; scalar outputs (likely
            losses or metrics) are averaged across replicas.
        """
        inner = self.inner_model

        # Slice inputs. Slice inputs on the CPU to avoid sending a copy
        # of the full inputs to all GPUs. Saves on bandwidth and memory.
        # Slices are stored by input position, so this does not depend on
        # the model exposing `input_names`.
        with tf.device('/cpu:0'):
            input_slices = [tf.split(x, self.gpu_count) for x in inner.inputs]

        # outputs_all[k] collects output k of every replica.
        outputs_all = [[] for _ in inner.outputs]

        # Run the model call() on each GPU to place the ops there
        for i in range(self.gpu_count):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i):
                    # Run a slice of inputs through this replica. The Lambda
                    # ignores its argument and returns the precomputed slice.
                    # Passing the original input tensor presumably keeps the
                    # Keras layer graph connected to the model inputs.
                    # `slices` and `i` are bound as default arguments so the
                    # function never depends on loop variables that are
                    # rebound later (late-binding closure pitfall).
                    inputs = [
                        KL.Lambda(lambda s, slices=slices, i=i: slices[i],
                                  output_shape=lambda s: (None,) + s[1:])(tensor)
                        for slices, tensor in zip(input_slices, inner.inputs)]
                    # Create the model replica and get the outputs
                    outputs = inner(inputs)
                    if not isinstance(outputs, list):
                        outputs = [outputs]
                    # Save the outputs for merging back together later
                    for collected, output in zip(outputs_all, outputs):
                        collected.append(output)

        # Merge outputs on CPU
        merged = []
        with tf.device('/cpu:0'):
            for replica_outputs, name in zip(outputs_all, inner.output_names):
                # Concatenate or average outputs?
                # Outputs usually have a batch dimension and we concatenate
                # across it. If they don't, then the output is likely a loss
                # or a metric value that gets averaged across the batch.
                # Keras expects losses and metrics to be scalars.
                if K.int_shape(replica_outputs[0]) == ():
                    # Average. `o` is the list of per-replica tensors handed
                    # to the Lambda, so len(o) == gpu_count without closing
                    # over a loop variable.
                    m = KL.Lambda(lambda o: tf.add_n(o) / len(o),
                                  name=name)(replica_outputs)
                else:
                    # Concatenate along the batch axis
                    m = KL.Concatenate(axis=0, name=name)(replica_outputs)
                merged.append(m)
        return merged
if __name__ == "__main__":
    # Testing code below. It creates a simple model to train on MNIST and
    # tries to run it on 2 GPUs. It saves the graph so it can be viewed
    # in TensorBoard. Run it as:
    #
    # python3 parallel_model.py

    import os
    import numpy as np
    # Import callbacks explicitly instead of relying on the `keras` package
    # __init__ having imported the submodule as a side effect.
    import keras.callbacks
    import keras.optimizers
    from keras.datasets import mnist
    from keras.preprocessing.image import ImageDataGenerator

    GPU_COUNT = 2

    # ParallelModel splits every batch with tf.split, so each batch the
    # model sees must be evenly divisible by GPU_COUNT.
    BATCH_SIZE = 64

    # Root directory of the project (resolved relative to the current
    # working directory, not to this file).
    ROOT_DIR = os.path.abspath("../")

    # Directory to save logs and trained model
    MODEL_DIR = os.path.join(ROOT_DIR, "logs")

    def build_model(x_train, num_classes):
        """Build a small CNN classifier for images shaped like x_train[0].

        x_train: training images; only x_train.shape[1:] is used.
        num_classes: number of softmax outputs.
        Returns an (uncompiled) Keras Model named "digit_classifier_model".
        """
        # Reset default graph. Keras leaves old ops in the graph,
        # which are ignored for execution but clutter graph
        # visualization in TensorBoard.
        tf.reset_default_graph()

        inputs = KL.Input(shape=x_train.shape[1:], name="input_image")
        x = KL.Conv2D(32, (3, 3), activation='relu', padding="same",
                      name="conv1")(inputs)
        x = KL.Conv2D(64, (3, 3), activation='relu', padding="same",
                      name="conv2")(x)
        x = KL.MaxPooling2D(pool_size=(2, 2), name="pool1")(x)
        x = KL.Flatten(name="flat1")(x)
        x = KL.Dense(128, activation='relu', name="dense1")(x)
        x = KL.Dense(num_classes, activation='softmax', name="dense2")(x)

        return KM.Model(inputs, x, "digit_classifier_model")

    # Load MNIST Data: add a trailing channel axis and scale to [0, 1]
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = np.expand_dims(x_train, -1).astype('float32') / 255
    x_test = np.expand_dims(x_test, -1).astype('float32') / 255

    print('x_train shape:', x_train.shape)
    print('x_test shape:', x_test.shape)

    # Build data generator and model
    datagen = ImageDataGenerator()
    model = build_model(x_train, 10)

    # Add multi-GPU support.
    model = ParallelModel(model, GPU_COUNT)

    optimizer = keras.optimizers.SGD(lr=0.01, momentum=0.9, clipnorm=5.0)

    # Integer labels, hence the sparse variant of categorical crossentropy.
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer, metrics=['accuracy'])

    model.summary()

    # Train
    model.fit_generator(
        datagen.flow(x_train, y_train, batch_size=BATCH_SIZE),
        steps_per_epoch=50, epochs=10, verbose=1,
        validation_data=(x_test, y_test),
        callbacks=[keras.callbacks.TensorBoard(log_dir=MODEL_DIR,
                                               write_graph=True)]
    )
@BenoCharlo
Copy link

Hello, I've implemented this function to use multi-GPU with Keras 2.2.4, but I still run into trouble when I try to train the model. I got an error:

MaybeEncodingError: Error sending result: '([array([[[[-123.7, -116.8, -103.9],

I don't really know how to fix this. Do you have any idea? Thanks.

@skywalkerisnull
Copy link
Author

Hello, I've implemented this function to use multi-gpu with Keras 2.2.4 But there is still a trouble when i tried to train the model. I got an error:

MaybeEncodingError: Error sending result: '([array([[[[-123.7, -116.8, -103.9],

I don't really know how to fix this. Have you an idea? Thks

Hey Beno,

Sorry, I haven't seen that error before and I'm not sure where to look to troubleshoot it.

@BenoCharlo
Copy link

What is the main change you made to parallel_model compared to the one in the Mask R-CNN implementation?
Has that code worked for you with Keras 2.2.4?
If so, what are the settings in your config.py and model.py?
More than one GPU? workers > 1? use_multiprocessing=True?
Thank you

@BenoCharlo
Copy link

What is the main change you bring to the parallel_model compared to the one we have in MaskRCNN implementation?
Has that code work for you with Keras 2.2.4?
If so, what are the specs in your config.py and model.py?
GPU >1 ? workers>1, use_multiprocessing =True?
Thank you

The end of the error is Reason: 'error("'i' format requires -2147483648 <= number <= 2147483647",)'

Also, add `super(ParallelModel, self).__init__()` at the start of `def __init__`.

This seems odd to me. I'm using Python 3.5.2.

@skywalkerisnull
Copy link
Author

From memory, I am using Python 3.6
When I am back in the office tomorrow I will do a dump out of conda and pip for the versions of the packages that I am using.

@BenoCharlo
Copy link

Ok cool.

I've upgraded to Python 3.6, but I get the same errors. I'm wondering whether they come from model.py or config.py.
Can you confirm the code works with Keras 2.2.4?
Anyway, thank you for your help.

@BenoCharlo
Copy link

Great, it's working now. I upgraded to Python 3.6 with Keras 2.2.4 and the tensorflow and tensorflow-gpu packages. I also increased the size of the cluster I'm using (that was the key element).
Thank you for your time.

@skywalkerisnull
Copy link
Author

No worries,

For anyone else who comes across this issue in the future: I was using Python 3.6.8, and this is the conda and pip dump:

Conda:

_tflow_select             2.1.0                       gpu  
absl-py                   0.7.1                    py36_0  
affine                    2.2.2                    py36_0  
astor                     0.7.1                    py36_0  
attrs                     19.1.0                   py36_1  
blas                      1.0                         mkl  
bzip2                     1.0.8                he774522_0  
ca-certificates           2019.8.28                     0  
certifi                   2019.9.11                py36_0  
chardet                   3.0.4                    pypi_0    pypi
click                     7.0                      pypi_0    pypi
click-plugins             1.1.1                      py_0  
cligj                     0.5.0                    py36_0  
cloudpickle               1.2.1                      py_0    conda-forge  
cudatoolkit               10.0.130                      0  
cudnn                     7.6.0                cuda10.0_0  
curl                      7.65.2               h2a8f88b_0  
cycler                    0.10.0                     py_1    conda-forge  
cytoolz                   0.9.0.1         py36hfa6e2cd_1001    conda-forge
dask-core                 1.2.2                      py_0    conda-forge  
decorator                 4.4.0                      py_0    conda-forge  
expat                     2.2.5                he025d50_0  
flask                     1.0.3                    pypi_0    pypi
flask-cors                3.0.7                    pypi_0    pypi
freetype                  2.10.0               h5db478b_0    conda-forge  
freexl                    1.0.5                hfa6e2cd_0  
gast                      0.2.2                    py36_0  
geojson                   2.5.0                    pypi_0    pypi
geos                      3.7.1             he025d50_1000    conda-forge  
grpcio                    1.16.1           py36h351948d_1  
h5py                      2.9.0            py36h5e291fa_0  
hdf4                      4.2.13               h712560f_2  
hdf5                      1.10.4               h7ebc959_0  
icc_rt                    2019.0.0             h0cc432a_1  
icu                       58.2                 ha66f8fd_1  
idna                      2.8                      pypi_0    pypi
imageio                   2.5.0                    py36_0    conda-forge  
imantics                  0.1.10                   pypi_0    pypi
imgaug                    0.2.9                      py_0    conda-forge  
intel-openmp              2019.4                      245  
intel-tensorflow          0.0.1                    pypi_0    pypi
itsdangerous              1.1.0                    pypi_0    pypi
jinja2                    2.10.1                   pypi_0    pypi
joblib                    0.13.2                   pypi_0    pypi
jpeg                      9c                hfa6e2cd_1001    conda-forge  
kealib                    1.4.7                h07cbb95_6  
keras-applications        1.0.7                      py_0  
keras-base                2.2.4                    py36_0  
keras-gpu                 2.2.4                         0  
keras-preprocessing       1.0.9                      py_0  
keras2onnx                1.5.0                    pypi_0    pypi
kiwisolver                1.1.0            py36he980bc4_0    conda-forge  
krb5                      1.16.1               hc04afaa_7  
libblas                   3.8.0                     8_mkl    conda-forge  
libboost                  1.67.0               hd9e427e_4  
libcblas                  3.8.0                     8_mkl    conda-forge  
libcurl                   7.65.2               h2a8f88b_0
libgdal                   2.3.3                h10f50ba_0
libiconv                  1.15                 h1df5818_7
libkml                    1.3.0                he5f2a48_4
liblapack                 3.8.0                     8_mkl    conda-forge
liblapacke                3.8.0                     8_mkl    conda-forge
libnetcdf                 4.6.1                h411e497_2
libpng                    1.6.37               h7602738_0    conda-forge
libpq                     11.2                 h3235a2c_0
libprotobuf               3.7.1                h7bd577a_0
libspatialite             4.3.0a              hc36aec2_19
libssh2                   1.8.2                h7a1dbc1_0
libtiff                   4.0.10            h6512ee2_1003    conda-forge
libwebp                   1.0.2                hfa6e2cd_2    conda-forge
libxml2                   2.9.9                h464c3ec_0
lxml                      4.4.1                    pypi_0    pypi
lz4-c                     1.8.3             he025d50_1001    conda-forge
markdown                  3.1                      py36_0
markupsafe                1.1.1                    pypi_0    pypi
matplotlib                3.1.0                    py36_1    conda-forge
matplotlib-base           3.1.0            py36h2852a4a_1    conda-forge
mkl                       2019.4                      245
mkl_fft                   1.0.12           py36h14836fe_0
mkl_random                1.0.2            py36h343c172_0
mock                      3.0.5                    py36_0
networkx                  2.3                        py_0    conda-forge
numpy                     1.16.4           py36h19fb1c0_0
numpy-base                1.16.4           py36hc3f5095_0
olefile                   0.46                       py_0    conda-forge
onnx                      1.5.0                    pypi_0    pypi
onnxconverter-common      1.5.0                    pypi_0    pypi
onnxmltools               1.4.1                    pypi_0    pypi
opencv                    4.1.0            py36hb4945ee_5    conda-forge
opencv-python             4.1.0.25                 pypi_0    pypi
openssl                   1.1.1d               he774522_0
pandas                    0.25.0                   pypi_0    pypi
pcre                      8.43                 ha925a31_0
pillow                    6.0.0                    pypi_0    pypi
pip                       19.1.1                   py36_0
proj4                     5.2.0                ha925a31_1
protobuf                  3.7.1            py36h33f27b4_0
pyparsing                 2.4.0                      py_0    conda-forge
pyqt                      5.9.2            py36h6538335_0    conda-forge
pyreadline                2.1                      py36_1
python                    3.6.8                h9f7ef89_7
python-dateutil           2.8.0                      py_0    conda-forge
pytz                      2019.2                   pypi_0    pypi
pywavelets                1.0.3            py36h452e1ab_1    conda-forge
pyyaml                    5.1              py36he774522_0
qt                        5.9.7                hc6833c9_1    conda-forge
rasterio                  1.0.21           py36h6bd7d87_0
requests                  2.22.0                   pypi_0    pypi
rope                      0.14.0                   pypi_0    pypi
scikit-image              0.15.0                   pypi_0    pypi
scikit-learn              0.21.2                   pypi_0    pypi
scipy                     1.2.1            py36h29ff71c_0
setuptools                41.0.1                   py36_0
shapely                   1.6.4           py36h8921fb9_1004    conda-forge
sip                       4.19.8          py36h6538335_1000    conda-forge
six                       1.12.0                   py36_0
skl2onnx                  1.4.9                    pypi_0    pypi
snuggs                    1.4.6                      py_0
sqlite                    3.28.0               he774522_0
tenacity                  5.0.4                    pypi_0    pypi
tensorboard               1.13.1           py36h33f27b4_0
tensorflow                1.13.1          gpu_py36h9006a92_0
tensorflow-base           1.13.1          gpu_py36h871c8ca_0
tensorflow-estimator      1.13.0                     py_0
tensorflow-gpu            1.13.1               h0d30ee6_0
tensorflow-serving-api    1.13.0                   pypi_0    pypi
termcolor                 1.1.0                    py36_1
tk                        8.6.9             hfa6e2cd_1002    conda-forge
toolz                     0.9.0                      py_1    conda-forge
tornado                   6.0.2            py36hfa6e2cd_0    conda-forge
typing                    3.6.6                    pypi_0    pypi
typing-extensions         3.7.2                    pypi_0    pypi
urllib3                   1.25.3                   pypi_0    pypi
vc                        14.1                 h0510ff6_4
vs2015_runtime            14.16.27012          hf0eaf9b_0
werkzeug                  0.15.2                     py_0
wheel                     0.33.4                   py36_0
wincertstore              0.2              py36h7fe50ca_0
xerces-c                  3.2.2                ha925a31_0
xz                        5.2.4             h2fa13f4_1001    conda-forge
yaml                      0.1.7                hc54c509_2
zlib                      1.2.11               h62dcd97_3
zstd                      1.4.0                hd8a0e53_0    conda-forge

pip: 

Package                Version    
---------------------- -----------
absl-py                0.7.1
affine                 2.2.2
astor                  0.7.1      
attrs                  19.1.0
certifi                2019.9.11
chardet                3.0.4
Click                  7.0
click-plugins          1.1.1
cligj                  0.5.0
cloudpickle            1.2.1
cycler                 0.10.0
cytoolz                0.9.0.1
dask                   1.2.2
decorator              4.4.0
Flask                  1.0.3
Flask-Cors             3.0.7
gast                   0.2.2
geojson                2.5.0
grpcio                 1.16.1
h5py                   2.9.0
idna                   2.8
imageio                2.5.0
imantics               0.1.10
imgaug                 0.2.9
intel-tensorflow       0.0.1
itsdangerous           1.1.0
Jinja2                 2.10.1
joblib                 0.13.2
Keras                  2.2.4
Keras-Applications     1.0.7
Keras-Preprocessing    1.0.9
keras2onnx             1.5.0
kiwisolver             1.1.0
lxml                   4.4.1
Markdown               3.1
MarkupSafe             1.1.1
matplotlib             3.1.0
mkl-fft                1.0.12
mkl-random             1.0.2
mock                   3.0.5
networkx               2.3
numpy                  1.16.4
olefile                0.46
onnx                   1.5.0
onnxconverter-common   1.5.0
onnxmltools            1.4.1
opencv-python          4.1.0.25
pandas                 0.25.0
Pillow                 6.0.0
pip                    19.1.1
protobuf               3.7.1
pyparsing              2.4.0
pyreadline             2.1
python-dateutil        2.8.0
pytz                   2019.2
PyWavelets             1.0.3
PyYAML                 5.1
rasterio               1.0.21
requests               2.22.0
rope                   0.14.0
scikit-image           0.15.0
scikit-learn           0.21.2
scipy                  1.2.1
setuptools             41.0.1
Shapely                1.6.4.post2
six                    1.12.0
skl2onnx               1.4.9
snuggs                 1.4.6
tenacity               5.0.4
tensorboard            1.13.1
tensorflow             1.13.1
tensorflow-estimator   1.13.0
tensorflow-serving-api 1.13.0
termcolor              1.1.0
toolz                  0.9.0
tornado                6.0.2
typing                 3.6.6
typing-extensions      3.7.2
urllib3                1.25.3
Werkzeug               0.15.2
wheel                  0.33.4
wincertstore           0.2

@zcunyi
Copy link

zcunyi commented Apr 19, 2021

Hello, I've implemented this function to use multi-GPU with Keras 2.2.4, but I still run into trouble when I try to train the model. I got an error:

AttributeError: 'Model' object has no attribute 'input_names'

I don't really know how to fix this. Do you have any idea? Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment