Skip to content

Instantly share code, notes, and snippets.

@nicjac
Created December 27, 2021 21:12
Show Gist options
  • Save nicjac/b363d2454ea253570a54e5e178e7666a to your computer and use it in GitHub Desktop.
Save nicjac/b363d2454ea253570a54e5e178e7666a to your computer and use it in GitHub Desktop.
An updated SaveModelCallback for fastai that also saves the metrics tracked by the Recorder alongside each model checkpoint.
class SaveModelCallback(TrackerCallback):
    "A `TrackerCallback` that saves the model's best during training and loads it at the end."
    order = TrackerCallback.order+1

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, at_end=False,
                 with_opt=False, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        assert not (every_epoch and at_end), "every_epoch and at_end cannot both be set to True"
        # Keep track of the most recent save for loggers:
        # - `last_saved_path` is whatever `learn.save` returned for the latest save.
        # - `last_saved_metadata` is the metrics dict recorded alongside that save.
        self.last_saved_path = None
        self.last_saved_metadata = None
        store_attr('fname,every_epoch,at_end,with_opt')

    def _recorder_metrics(self):
        "Return a dict mapping each recorder metric name to its latest logged value (empty if there is no recorder)."
        recorder = getattr(self.learn, 'recorder', None)
        if recorder is None: return {}
        # `metric_names` and `log` are read in lockstep; zip stops at the shorter of the two.
        return dict(zip(recorder.metric_names, recorder.log))

    def _save(self, name, metadata=None):
        """Save the model as `name` and remember where it went and which metrics it was saved with.

        `metadata` defaults to a snapshot of the recorder's latest metrics, so every save path
        (best-model, every-epoch, at-end) records metadata consistently.
        """
        if metadata is None: metadata = self._recorder_metrics()
        self.last_saved_path = self.learn.save(name, with_opt=self.with_opt)
        self.last_saved_metadata = metadata

    def after_epoch(self):
        "Compare the value monitored to its best score and save if best."
        if self.every_epoch:
            # Save a separately named checkpoint every `every_epoch` epochs.
            if (self.epoch%self.every_epoch) == 0: self._save(f'{self.fname}_{self.epoch}')
        else: # every improvement
            super().after_epoch()
            if self.new_best:
                print(f'Better model found at epoch {self.epoch} with {self.monitor} value: {self.best}.')
                self._save(f'{self.fname}')

    def after_fit(self, **kwargs):
        "Load the best model."
        if self.at_end: self._save(f'{self.fname}')
        # In best-model mode, reload the checkpoint written at the best epoch.
        elif not self.every_epoch: self.learn.load(f'{self.fname}', with_opt=self.with_opt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment