This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CallbackHandler():
    """Dispatches training-loop events to a list of callbacks."""

    def __init__(self, cbs=None):
        """Store the callbacks to dispatch to.

        Args:
            cbs: list of callback objects; ``None`` (or any falsy value)
                is replaced by a fresh empty list.
        """
        self.cbs = cbs if cbs else []

    def begin_fit(self, learn):
        """Announce the start of fitting to the callbacks.

        Records ``learn`` on the handler, sets ``in_train`` to ``True``
        and resets ``learn.stop`` to ``False`` before calling the
        callbacks.

        Args:
            learn: object holding the training state; it is handed to
                each callback's ``begin_fit``.

        Returns:
            ``True`` when there are no callbacks; otherwise the combined
            result of the callbacks' ``begin_fit`` return values.
        """
        self.learn, self.in_train = learn, True
        learn.stop = False
        res = True
        # `and` short-circuits: once a callback returns a falsy value, the
        # begin_fit hooks of the remaining callbacks are not invoked.
        for cb in self.cbs:
            res = res and cb.begin_fit(learn)
        return res
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Callback():
    """Base callback whose hooks do nothing except return ``True``.

    Subclasses override the hooks they care about.
    """

    def begin_fit(self, learn):
        """Remember ``learn`` on the callback and return ``True``."""
        self.learn = learn
        return True

    def after_fit(self):
        """Hook with no default behaviour; returns ``True``."""
        return True

    def begin_epoch(self, epoch):
        """Remember the ``epoch`` value on the callback and return ``True``."""
        self.epoch = epoch
        return True

    def begin_validate(self):
        """Hook with no default behaviour; returns ``True``."""
        return True

    def after_epoch(self):
        """Hook with no default behaviour; returns ``True``."""
        return True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def one_batch(xb, yb, cb):
    """Process a single mini-batch, letting ``cb`` gate every stage.

    Args:
        xb: batch of inputs passed to ``cb.learn.model``.
        yb: batch of targets passed to ``cb.learn.loss_func``.
        cb: handler exposing ``learn`` (with ``model``, ``loss_func`` and
            ``opt``) plus the hooks ``begin_batch``, ``after_loss``,
            ``after_backward`` and ``after_step``.

    Returns:
        None. The function returns early, skipping the remaining stages,
        when ``begin_batch`` or ``after_loss`` returns a falsy value.
    """
    if not cb.begin_batch(xb, yb):
        return
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    if not cb.after_loss(loss):
        return
    loss.backward()
    # The optimizer step and the gradient reset are gated independently:
    # a falsy after_backward skips only step(), a falsy after_step skips
    # only zero_grad().
    if cb.after_backward():
        cb.learn.opt.step()
    if cb.after_step():
        cb.learn.opt.zero_grad()
def all_batches(dl, cb): | |
for xb,yb in dl: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CallbackHandler():
    """Dispatches training-loop events to a list of callbacks."""

    def __init__(self, cbs=None):
        """Store the callbacks to dispatch to.

        Args:
            cbs: list of callback objects; ``None`` (or any falsy value)
                is replaced by a fresh empty list.
        """
        self.cbs = cbs if cbs else []

    def begin_fit(self, learn):
        """Announce the start of fitting to the callbacks.

        Records ``learn`` on the handler, sets ``in_train`` to ``True``
        and resets ``learn.stop`` to ``False`` before calling the
        callbacks.

        Args:
            learn: object holding the training state; it is handed to
                each callback's ``begin_fit``.

        Returns:
            ``True`` when there are no callbacks; otherwise the combined
            result of the callbacks' ``begin_fit`` return values.
        """
        self.learn, self.in_train = learn, True
        learn.stop = False
        res = True
        # `and` short-circuits: once a callback returns a falsy value, the
        # begin_fit hooks of the remaining callbacks are not invoked.
        for cb in self.cbs:
            res = res and cb.begin_fit(learn)
        return res
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BatchCounter(Callback):
    """Callback that prints a progress line every ``print_every`` batches."""

    # Number of after_step calls between progress messages.
    print_every = 200
    # Batches counted so far in the current epoch; the class-level default
    # keeps after_step usable even if begin_epoch has not run yet.
    batch_counter = 0

    def begin_epoch(self, epoch):
        """Record ``epoch`` and restart the batch count.

        Args:
            epoch: value identifying the epoch that is starting.

        Returns:
            Always ``True``.
        """
        self.epoch = epoch
        # Start at zero (not one) so the number printed by after_step is
        # the number of batches actually completed.
        self.batch_counter = 0
        return True

    def after_step(self):
        """Count one more batch and report progress when due.

        Returns:
            Always ``True``.
        """
        self.batch_counter += 1
        if self.batch_counter % self.print_every == 0:
            print(f'Batch {self.batch_counter} completed')
        return True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def one_batch(xb, yb):
    """Run one training step on a single mini-batch.

    ``model``, ``loss_func`` and ``opt`` are not parameters; they are
    looked up in the surrounding scope when the function runs.

    Args:
        xb: batch of inputs passed to ``model``.
        yb: batch of targets passed to ``loss_func``.
    """
    pred = model(xb)
    loss = loss_func(pred, yb)
    loss.backward()
    opt.step()
    opt.zero_grad()
def fit():
    """Call ``one_batch`` on every batch of ``train_dl``, ``epochs`` times.

    ``epochs`` and ``train_dl`` are not parameters; they are looked up in
    the surrounding scope when the function runs. Each item yielded by
    ``train_dl`` is unpacked into the arguments of ``one_batch``.
    """
    for epoch in range(epochs):
        for b in train_dl:
            one_batch(*b)
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
NewerOlder