Skip to content

Instantly share code, notes, and snippets.

@devforfu
Created December 7, 2018 09:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save devforfu/3bdefe1e09470da01216850f43bf0f85 to your computer and use it in GitHub Desktop.
Save devforfu/3bdefe1e09470da01216850f43bf0f85 to your computer and use it in GitHub Desktop.
A better training loop
class _NoCallbacks:
    """Stand-in used by ``train`` when no callbacks object is supplied.

    Looking up any public attribute returns a function that accepts arbitrary
    arguments and does nothing, so every ``cb.<hook>(...)`` call in ``train``
    becomes a no-op. Private/dunder names raise ``AttributeError`` as usual so
    that introspection (``copy``, ``pickle``, ``hasattr`` probes) is not fooled.
    """

    def __getattr__(self, name):
        if name.startswith('_'):
            raise AttributeError(name)

        def _noop(*args, **kwargs):
            return None

        return _noop


def train(model, opt, phases, callbacks=None, epochs=1, device=default_device, loss_fn=F.nll_loss):
    """Run a callback-driven training loop over one or more phases.

    For each epoch, every phase in ``phases`` is processed in order. A phase
    whose ``grad`` attribute is truthy runs with the model in training mode
    and gradients enabled, and the optimizer is stepped after every batch.
    Otherwise the model is put in eval mode, gradients are disabled and no
    optimizer step is taken.

    Args:
        model: The ``torch.nn.Module`` to train. It is moved to ``device``.
        opt: Optimizer whose ``zero_grad()``/``step()`` are called on
            training batches.
        phases: Iterable of phase objects. Each must expose ``loader``
            (sized and iterable of batches), ``grad`` (bool-like),
            ``batch_index`` (incremented per batch) and a writable
            ``batch_loss`` attribute.
        callbacks: Object providing the hook methods listed below. When
            ``None`` (the default), all hooks are silently skipped.
        epochs: Number of passes over all phases.
        device: Target device for the model and for each batch.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar
            tensor.

    Hooks called on ``callbacks`` (in order of first occurrence):
        ``training_started``, ``epoch_started``, ``phase_started``,
        ``batch_started``, ``before_forward_pass``, ``after_forward_pass``,
        ``before_backward_pass``, ``after_backward_pass`` (training phases
        only), ``batch_ended``, ``phase_ended``, ``epoch_ended`` and
        ``training_ended``.
    """
    model.to(device)

    # ``callbacks`` defaults to None, yet every hook is invoked
    # unconditionally; substitute a no-op object instead of crashing with
    # AttributeError on the very first hook call.
    cb = _NoCallbacks() if callbacks is None else callbacks

    cb.training_started(phases=phases, optimizer=opt)

    for epoch in range(1, epochs + 1):
        cb.epoch_started(epoch=epoch)

        for phase in phases:
            n = len(phase.loader)
            cb.phase_started(phase=phase, total_batches=n)
            is_training = phase.grad
            model.train(is_training)

            for batch in phase.loader:
                # NOTE(review): batch_index is only ever incremented here;
                # presumably the Phase object or a callback resets it between
                # epochs — confirm.
                phase.batch_index += 1
                cb.batch_started(phase=phase, total_batches=n)
                # Assumes place_and_unwrap returns an (inputs, targets) pair
                # already moved to ``device``.
                x, y = place_and_unwrap(batch, device)

                # Only build the autograd graph for training phases.
                with torch.set_grad_enabled(is_training):
                    cb.before_forward_pass()
                    out = model(x)
                    cb.after_forward_pass()
                    loss = loss_fn(out, y)

                if is_training:
                    opt.zero_grad()
                    cb.before_backward_pass()
                    loss.backward()
                    cb.after_backward_pass()
                    opt.step()

                phase.batch_loss = loss.item()
                cb.batch_ended(phase=phase, output=out, target=y)

            cb.phase_ended(phase=phase)

        cb.epoch_ended(phases=phases, epoch=epoch)

    cb.training_ended(phases=phases)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment