Skip to content

Instantly share code, notes, and snippets.

@priancho
Last active March 22, 2019 06:22
Show Gist options
  • Save priancho/63a2ab3072862247ace30e179669ff79 to your computer and use it in GitHub Desktop.
Save priancho/63a2ab3072862247ace30e179669ff79 to your computer and use it in GitHub Desktop.
Save the model at the beginning/end of training
class CustomModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that also saves at the beginning and end of training.

    In addition to the regular checkpoints written by ``ModelCheckpoint``,
    this callback writes:

    * the model as it is when training begins (e.g. a randomly initialized
      model on the first ``fit`` call), with the ``{epoch}`` placeholder of
      ``filepath`` formatted as the currently recorded epoch index (0 unless
      the callback is reused across ``fit`` calls), and
    * the model as it is when training ends, with ``{epoch}`` formatted as
      the 0-based index of the last started epoch plus one.

    NOTE(review): ``logs`` is often empty or ``None`` at train begin/end, so a
    ``filepath`` containing metric placeholders such as ``{val_loss}`` may
    raise ``KeyError`` in these hooks — confirm against the Keras version in
    use.
    """

    def __init__(self, *args, **kwargs):
        super(CustomModelCheckpoint, self).__init__(*args, **kwargs)
        # 0-based index of the most recently started epoch.
        self.current_epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        """Record the 0-based index of the epoch that is starting."""
        # Chain to the parent so any per-epoch bookkeeping it performs
        # still runs.
        super(CustomModelCheckpoint, self).on_epoch_begin(epoch, logs)
        self.current_epoch = epoch

    def on_train_begin(self, logs=None):
        """Save the model as it is before any training step runs."""
        super(CustomModelCheckpoint, self).on_train_begin(logs)
        self._save_snapshot(self.current_epoch, logs)

    def on_train_end(self, logs=None):
        """Save the model as it is when training finishes."""
        super(CustomModelCheckpoint, self).on_train_end(logs)
        # current_epoch is the 0-based index of the last epoch that was
        # started (not necessarily the configured maximum, e.g. after early
        # stopping); +1 turns it into the count of epochs run.
        self._save_snapshot(self.current_epoch + 1, logs)

    def _save_snapshot(self, epoch_number, logs=None):
        """Format ``self.filepath`` with ``epoch_number`` and save the model.

        Args:
            epoch_number: value substituted for the ``{epoch}`` placeholder
                and reported in the verbose message, so the log always names
                the same epoch as the file that was written.
            logs: mapping forwarded to ``str.format`` for any other
                placeholders; ``None`` is treated as an empty mapping.
        """
        logs = logs or {}
        filepath = self.filepath.format(epoch=epoch_number, **logs)
        if self.verbose > 0:
            print('\nEpoch %05d: saving model to %s' % (epoch_number, filepath))
        if self.save_weights_only:
            self.model.save_weights(filepath, overwrite=True)
        else:
            self.model.save(filepath, overwrite=True)
@priancho
Copy link
Author

Saves the model at both the beginning and the END of training.
No need to save the model manually once training finishes :-)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment