Last active
March 22, 2019 10:25
-
-
Save priancho/c1b5facdd479d282284c513f8b946d51 to your computer and use it in GitHub Desktop.
Keras ModelCheckpoint callback for multi_gpu_model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MultiGPUModelCheckpoint(ModelCheckpoint):
    """MultiGPU model checkpointing.

    Save a template model, not the ``multi_gpu_model`` instance. Besides the
    periodic per-epoch checkpoints, the model is also saved at the beginning
    of training (a random init model) and at the end of training.

    > https://github.com/keras-team/keras/issues/8463

    All positional and keyword arguments are forwarded to
    ``ModelCheckpoint``, except the following optional keywords, which are
    consumed here:

    Args:
        template_model: model to save instead of ``self.model``. Passing the
            template model given to ``multi_gpu_model`` avoids relying on
            layer names. Defaults to None (look the template up by name,
            see ``template_name``).
        template_name: name of the layer of ``self.model`` that holds the
            template model; only used when ``template_model`` is None.
            Defaults to ``'model_1'``.
    """

    def __init__(self, *args, **kwargs):
        # Pop our own options first: the parent __init__ is not given them.
        template_model = kwargs.pop('template_model', None)
        template_name = kwargs.pop('template_name', 'model_1')
        super(MultiGPUModelCheckpoint, self).__init__(*args, **kwargs)
        self.template_model = template_model
        self.template_name = template_name
        # 0-based index of the most recent epoch that began
        # (updated by on_epoch_begin).
        self.current_epoch = 0
        # Path of the most recent "best" checkpoint written by on_epoch_end in
        # save_best_only mode; on_train_end uses it to avoid overwriting it.
        self._best_filepath = None

    def detachmodel(self):
        """Return the model that is actually written to disk.

        Returns:
            ``self.template_model`` if one was supplied; otherwise the first
            layer of ``self.model`` named ``self.template_name``; otherwise
            ``self.model`` itself (e.g. when no such wrapped layer exists).
        """
        if self.template_model is not None:
            return self.template_model
        for layer in self.model.layers:
            if layer.name == self.template_name:
                return layer
        return self.model

    def _save_detached(self, filepath):
        """Write the detached model (or only its weights) to ``filepath``."""
        model = self.detachmodel()
        if self.save_weights_only:
            model.save_weights(filepath, overwrite=True)
        else:
            model.save(filepath, overwrite=True)

    def _format_boundary_filepath(self, epoch, logs, stage):
        """Format ``self.filepath`` for a train-begin/end save, or return None.

        ``logs`` may lack metric keys the filepath references (for example
        ``'{val_loss:.2f}'``) at these points. Rather than crashing training
        with a KeyError, warn and let the caller skip this save.
        """
        try:
            return self.filepath.format(epoch=epoch, **logs)
        except KeyError as err:
            warnings.warn('Cannot format checkpoint path %r at %s '
                          '(missing key %s), skipping.'
                          % (self.filepath, stage, err), RuntimeWarning)
            return None

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the index of the epoch that is starting."""
        self.current_epoch = epoch

    def on_train_begin(self, logs=None):
        """Save the initialized model."""
        logs = logs or {}
        epoch = self.current_epoch
        filepath = self._format_boundary_filepath(epoch, logs, 'train begin')
        if filepath is None:
            return
        if self.verbose > 0:
            print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
        self._save_detached(filepath)

    def on_train_end(self, logs=None):
        """Save the last model.

        In ``save_best_only`` mode the save is skipped when it would target
        the same file as the most recent best checkpoint. Otherwise the best
        model would be overwritten by the (possibly worse) final one.
        """
        logs = logs or {}
        # 0-based index of the last epoch that began; +1 matches the
        # numbering on_epoch_end uses for its filenames.
        epoch = self.current_epoch
        filepath = self._format_boundary_filepath(epoch + 1, logs, 'train end')
        if filepath is None:
            return
        if self.save_best_only and filepath == self._best_filepath:
            return
        if self.verbose > 0:
            print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
        self._save_detached(filepath)

    def on_epoch_end(self, epoch, logs=None):
        """Save the detached model every ``self.period`` epochs.

        With ``save_best_only`` the save only happens when
        ``self.monitor_op(current, self.best)`` reports an improvement of the
        monitored value. In that case ``self.best`` is updated.
        """
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0
        filepath = self.filepath.format(epoch=epoch + 1, **logs)

        if not self.save_best_only:
            if self.verbose > 0:
                print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
            self._save_detached(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
            return

        if self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s'
                      % (epoch + 1, self.monitor, self.best,
                         current, filepath))
            self.best = current
            self._save_detached(filepath)
            self._best_filepath = filepath
        elif self.verbose > 0:
            print('\nEpoch %05d: %s did not improve from %0.5f' %
                  (epoch + 1, self.monitor, self.best))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment