Skip to content

Instantly share code, notes, and snippets.

@priancho
Last active March 22, 2019 10:25
Show Gist options
  • Save priancho/c1b5facdd479d282284c513f8b946d51 to your computer and use it in GitHub Desktop.
Save priancho/c1b5facdd479d282284c513f8b946d51 to your computer and use it in GitHub Desktop.
Keras ModelCheckpoint callback for multi_gpu_model
class MultiGPUModelCheckpoint(ModelCheckpoint):
    """Checkpoint callback for models wrapped by ``multi_gpu_model``.

    Saves the template model (the layer named ``template_layer_name``
    inside the multi-GPU wrapper) rather than the wrapper itself, so the
    checkpoint can be reloaded without the multi-GPU scaffolding.

    In addition to the regular per-epoch checkpoints, the model is saved
    once when training begins (the freshly initialised model) and once
    when training ends.

    > https://github.com/keras-team/keras/issues/8463
    """

    # Name of the layer in ``self.model`` that holds the template model.
    # Override on a subclass or an instance when the template model was
    # given a different name (for example ``'sequential_1'``).
    template_layer_name = 'model_1'

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``ModelCheckpoint``."""
        super(MultiGPUModelCheckpoint, self).__init__(*args, **kwargs)
        # Index of the epoch most recently started (see on_epoch_begin).
        self.current_epoch = 0
        # Logs of the most recently finished epoch; reused by on_train_end,
        # whose own ``logs`` usually carry no metrics.
        self._last_epoch_logs = {}

    def detachmodel(self):
        """Return the template model wrapped inside ``self.model``.

        Returns:
            The layer of ``self.model`` whose name equals
            ``template_layer_name``; ``self.model`` itself when no such
            layer exists (e.g. the model is not a multi-GPU wrapper).
        """
        for layer in self.model.layers:
            if layer.name == self.template_layer_name:
                return layer
        return self.model

    def _save_model(self, filepath):
        """Write the detached model (or only its weights) to ``filepath``."""
        model = self.detachmodel()
        if self.save_weights_only:
            model.save_weights(filepath, overwrite=True)
        else:
            model.save(filepath, overwrite=True)

    def _format_filepath(self, epoch, logs):
        """Fill the ``filepath`` template, tolerating missing metrics.

        At the start and end of training the logs normally contain no
        metrics, so a template such as ``'{epoch:02d}-{val_loss:.2f}'``
        would raise ``KeyError``. Missing fields are rendered as ``nan``
        instead of aborting the training run.
        """
        try:
            return self.filepath.format(epoch=epoch, **logs)
        except KeyError:
            import string
            from collections import defaultdict

            values = defaultdict(lambda: float('nan'), logs)
            values['epoch'] = epoch
            # vformat looks fields up in the mapping itself, so the
            # defaultdict fallback is honoured for missing metrics.
            return string.Formatter().vformat(self.filepath, (), values)

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the index of the epoch that is starting."""
        self.current_epoch = epoch

    def on_train_begin(self, logs=None):
        """Save the initialized model."""
        logs = logs or {}
        self._last_epoch_logs = {}
        epoch = self.current_epoch
        filepath = self._format_filepath(epoch, logs)
        if self.verbose > 0:
            # Report the same epoch number that is used in the file name.
            print('\nEpoch %05d: saving model to %s' % (epoch, filepath))
        self._save_model(filepath)

    def on_train_end(self, logs=None):
        """Save the last model.

        NOTE: this save is unconditional. With a ``filepath`` that has no
        format fields it overwrites the per-epoch checkpoint, including
        the best model kept by ``save_best_only``.
        """
        # Prefer the metrics of the last finished epoch; explicit
        # train-end logs (if any) take precedence.
        merged_logs = dict(self._last_epoch_logs)
        merged_logs.update(logs or {})
        epoch = self.current_epoch + 1  # 1-based number of the last epoch
        filepath = self._format_filepath(epoch, merged_logs)
        if self.verbose > 0:
            print('\nEpoch %05d: saving model to %s' % (epoch, filepath))
        self._save_model(filepath)

    def on_epoch_end(self, epoch, logs=None):
        """Checkpoint the detached model every ``period`` epochs.

        Follows the ``ModelCheckpoint`` policy: when ``save_best_only`` is
        set, the model is saved only if the monitored quantity improved
        on ``self.best``; otherwise it is saved unconditionally.
        """
        logs = logs or {}
        self._last_epoch_logs = dict(logs)

        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0

        filepath = self.filepath.format(epoch=epoch + 1, **logs)

        if not self.save_best_only:
            if self.verbose > 0:
                print('\nEpoch %05d: saving model to %s'
                      % (epoch + 1, filepath))
            self._save_model(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            import warnings

            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
            return

        if self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s'
                      % (epoch + 1, self.monitor, self.best,
                         current, filepath))
            self.best = current
            self._save_model(filepath)
        elif self.verbose > 0:
            print('\nEpoch %05d: %s did not improve from %0.5f' %
                  (epoch + 1, self.monitor, self.best))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment