Skip to content

Instantly share code, notes, and snippets.

@edwardeasling
Created March 29, 2019 06:30
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save edwardeasling/e62365c0c5a202e124c495fcd25c561a to your computer and use it in GitHub Desktop.
Save edwardeasling/e62365c0c5a202e124c495fcd25c561a to your computer and use it in GitHub Desktop.
BP1 - 06
class CallbackHandler():
    """Fan training-loop events out to a list of callback objects.

    Each hook calls the same-named method on every callback in ``self.cbs``
    and folds the results as ``res = res and cb.<hook>(...)``.  Once one
    callback returns a falsy value, the callbacks after it are NOT called
    for that event.  The folded value is returned to the caller.
    NOTE(review): presumably the training loop treats a falsy result as
    "skip this step" -- confirm against the caller.

    The learner passed to ``begin_fit`` is expected to expose ``model``
    (with ``train()`` / ``eval()`` methods) and a writable ``stop`` attribute.
    """

    def __init__(self, cbs=None):
        # A fresh list per handler; avoids a shared mutable default.
        self.cbs = cbs if cbs else []

    def _run_hook(self, name, *args, res=True):
        """Call ``cb.<name>(*args)`` on each callback, short-circuiting on falsy.

        ``res`` is the initial value of the fold.  Some hooks seed it with
        handler state instead of ``True``.
        """
        for cb in self.cbs:
            res = res and getattr(cb, name)(*args)
        return res

    def begin_fit(self, learn):
        """Remember ``learn``, clear its stop flag and notify callbacks."""
        self.learn, self.in_train = learn, True
        learn.stop = False
        return self._run_hook('begin_fit', learn)

    def after_fit(self):
        # Seeded with ``not self.in_train``: the result is falsy if fitting
        # ends while still in the training phase.
        return self._run_hook('after_fit', res=not self.in_train)

    def begin_epoch(self, epoch):
        """Put the stored learner's model in training mode, then notify."""
        # Bug fix: previously used an undefined global ``learn``.
        self.learn.model.train()
        self.in_train = True
        return self._run_hook('begin_epoch', epoch)

    def begin_validate(self):
        """Put the stored learner's model in eval mode, then notify."""
        self.learn.model.eval()
        self.in_train = False
        return self._run_hook('begin_validate')

    def after_epoch(self):
        return self._run_hook('after_epoch')

    def begin_batch(self, xb, yb):
        return self._run_hook('begin_batch', xb, yb)

    def after_loss(self, loss):
        # Seeded with ``self.in_train``: falsy during validation, in which
        # case no after_loss callbacks are invoked.
        return self._run_hook('after_loss', loss, res=self.in_train)

    def after_backward(self):
        return self._run_hook('after_backward')

    def after_step(self):
        return self._run_hook('after_step')

    def do_stop(self):
        """Return the learner's ``stop`` flag and reset it to ``False``."""
        # Bug fix: previously read/reset an undefined global ``learn``.
        # ``finally`` runs after the return value is computed, so the
        # original flag value is returned and then cleared.
        try:
            return self.learn.stop
        finally:
            self.learn.stop = False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment