@edwardeasling · Created March 29, 2019
BP1 - 04
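A minimal training loop in which every stage is mediated by a callback handler cb: each begin_* hook can veto the work that follows it, and each after_* hook can veto the next step or stop training early. A sketch of the handler interface the loop assumes follows the code.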
import torch

def one_batch(xb, yb, cb):
    # Give callbacks a chance to skip or transform the batch.
    if not cb.begin_batch(xb, yb): return
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    if not cb.after_loss(loss): return
    loss.backward()
    # Callbacks can veto the optimizer step and the gradient reset.
    if cb.after_backward(): cb.learn.opt.step()
    if cb.after_step(): cb.learn.opt.zero_grad()

def all_batches(dl, cb):
    for xb, yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop(): return

def fit(epochs, learn, cb):
    if not cb.begin_fit(learn): return
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue
        all_batches(learn.data.train_dl, cb)
        # Validation runs without gradient tracking.
        if cb.begin_validate():
            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch(): break
    cb.after_fit()
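The loop assumes cb exposes the hooks called above and keeps a reference to the Learner, so the loop can reach cb.learn.model, cb.learn.opt, cb.learn.loss_func, and cb.learn.data. Below is a minimal no-op handler sketch for illustration; the class name and the train/eval toggling are assumptions, not part of this gist:

# Hypothetical minimal handler: every hook returns True so the loop runs
# unmodified; subclass and override hooks to add behaviour.
class CallbackHandler():
    def begin_fit(self, learn):
        self.learn, self.stop = learn, False   # keep a handle on the Learner
        return True
    def begin_epoch(self, epoch):
        self.learn.model.train()               # training mode (dropout, batchnorm)
        return True
    def begin_validate(self):
        self.learn.model.eval()                # eval mode for the validation pass
        return True
    def begin_batch(self, xb, yb): return True
    def after_loss(self, loss):    return True
    def after_backward(self):      return True
    def after_step(self):          return True
    def after_epoch(self):         return True
    def after_fit(self):           return True
    def do_stop(self):             return self.stop

With that in place, a call would look like this (assuming learn bundles a model, optimizer, loss function, and a data object with train_dl/valid_dl):

fit(1, learn, CallbackHandler())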