@Damacustas
Created June 22, 2018 11:44
Multi GPU keras model
# Taken from https://github.com/keras-team/keras/issues/2436#issuecomment-354882296
# but saved into this gist to make it more easily findable.
from keras import Model
from keras.utils import multi_gpu_model


class ModelMGPU(Model):
    def __init__(self, ser_model, gpus):
        pmodel = multi_gpu_model(ser_model, gpus)
        self.__dict__.update(pmodel.__dict__)
        self._smodel = ser_model

    def __getattribute__(self, attrname):
        '''Override load and save methods to be used from the serial-model. The
        serial-model holds references to the weights in the multi-gpu model.
        '''
        # return Model.__getattribute__(self, attrname)
        if 'load' in attrname or 'save' in attrname:
            return getattr(self._smodel, attrname)

        return super(ModelMGPU, self).__getattribute__(attrname)
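
The attribute-forwarding trick in `ModelMGPU` can be tried without Keras or GPUs. The following is only a minimal sketch of the same pattern in plain Python: `SerialModel`, `ParallelModel`, and `DelegatingModel` are hypothetical stand-ins invented for illustration, not Keras classes, and the real models behave very differently beyond this one mechanism.

```python
# Keras-free sketch of the delegation pattern used by ModelMGPU.
# All three classes are hypothetical stand-ins, not part of Keras.


class SerialModel:
    """Stand-in for the single-device model that owns the weights."""

    def __init__(self):
        self.weights = [1.0, 2.0]

    def save_weights(self):
        return ("serial", list(self.weights))

    def predict(self):
        return "serial predict"


class ParallelModel:
    """Stand-in for the multi-device wrapper built around a serial model."""

    def __init__(self, serial):
        self.serial = serial

    def save_weights(self):
        return ("parallel", list(self.serial.weights))

    def predict(self):
        return "parallel predict"


class DelegatingModel(ParallelModel):
    def __init__(self, serial):
        parallel = ParallelModel(serial)
        # Adopt the parallel model's instance state, as ModelMGPU does.
        self.__dict__.update(parallel.__dict__)
        self._smodel = serial

    def __getattribute__(self, attrname):
        # Any attribute whose name contains 'load' or 'save' is looked up
        # on the serial model; '_smodel' itself matches neither substring,
        # so this lookup does not recurse.
        if 'load' in attrname or 'save' in attrname:
            return getattr(self._smodel, attrname)
        return super().__getattribute__(attrname)


model = DelegatingModel(SerialModel())
predict_result = model.predict()      # resolved on the parallel class
save_result = model.save_weights()    # forwarded to the serial model
```

Note that the match is a substring test, so in this sketch any attribute name containing `load` or `save` is forwarded, not only the usual save and load methods.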