Skip to content

Instantly share code, notes, and snippets.

@nuqz
Created September 17, 2018 13:12
Show Gist options
  • Save nuqz/0f25030d2783af23eb3fc11a938ef030 to your computer and use it in GitHub Desktop.
Save nuqz/0f25030d2783af23eb3fc11a938ef030 to your computer and use it in GitHub Desktop.
from keras.callbacks import Callback
class PatientTrainingCallback(Callback):
    """Keras callback that tracks whether an epoch ("episode") is in progress.

    If a training run ends while an epoch is still open (``on_epoch_begin``
    was seen without a matching ``on_epoch_end``), the callback remembers it.
    From then on, every ``on_epoch_end`` sets ``self.model.stop_training``
    to ``True``. At the end of every training run it resets
    ``self.model.stop_training`` to ``False``.

    Attributes:
        episode_finished: ``True`` initially and after ``on_epoch_end``;
            ``False`` from ``on_epoch_begin`` until the matching
            ``on_epoch_end``.
        end_train_before_episode_end: set to ``True`` by ``on_train_end``
            when training ended mid-epoch.
    """

    def __init__(self):
        # Use super() instead of naming the base class so the MRO is honored
        # if this callback is combined with other classes via multiple
        # inheritance. The two-argument form works on Python 2 and 3.
        super(PatientTrainingCallback, self).__init__()
        self.episode_finished = True
        # NOTE(review): nothing in this class ever clears this flag, so once
        # it is set every subsequent epoch end requests a stop — confirm
        # that persistence is intended.
        self.end_train_before_episode_end = False

    # on_train_begin, on_batch_begin and on_batch_end are not overridden:
    # the base Callback already provides no-op implementations.

    def on_epoch_begin(self, epoch, logs=None):
        """Mark an epoch as in progress."""
        self.episode_finished = False

    def on_epoch_end(self, epoch, logs=None):
        """Mark the epoch finished; request a stop if a run ended mid-epoch."""
        self.episode_finished = True
        if self.end_train_before_episode_end:
            self.model.stop_training = True

    def on_train_end(self, logs=None):
        """Record a mid-epoch end of training and clear ``stop_training``."""
        if not self.episode_finished:
            self.end_train_before_episode_end = True
        # NOTE(review): presumably reset so the model can be trained again
        # after this callback requested a stop — confirm against the caller.
        self.model.stop_training = False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment