@devforfu
Last active December 7, 2018 07:27
Simple training loop pseudocode
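import torch
from torch import optim

# Assumed to be defined elsewhere (not part of this gist): create_model,
# create_train_valid_data, place_and_unwrap, loss_fn, params, epochs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = create_model(params)
model.to(device)  # move the model before building its optimizer, as the PyTorch docs recommend
phases = create_train_valid_data()  # e.g. a training phase and a validation phase
opt = optim.SGD(model.parameters(), lr=1e-3)  # nn.Module exposes parameters(), not .params

for epoch in range(1, epochs + 1):
    for phase in phases:
        n = len(phase.loader)  # number of batches in this phase
        is_training = phase.grad  # True only for the phase that updates weights
        model.train(is_training)  # toggles dropout / batch-norm behaviour
        for batch in phase.loader:
            x, y = place_and_unwrap(batch, device)
            # Record the autograd graph only when we are going to backpropagate.
            with torch.set_grad_enabled(is_training):
                out = model(x)
                loss = loss_fn(out, y)
                if is_training:
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
            phase.batch_loss = loss.item()  # loss of the most recent batch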
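For readers who want something that actually runs, here is a minimal sketch that fills in the pieces the pseudocode leaves undefined. Everything outside the loop body is an assumption: `Phase`, `create_train_valid_data` and `place_and_unwrap` are hypothetical helpers written for this example, the model is a small `nn.Sequential` classifier, the data is random, and `nn.CrossEntropyLoss` stands in for `loss_fn`. It also divides the summed batch losses by the batch count to get an epoch average, which is one plausible use of `n` in the original.

```python
from dataclasses import dataclass

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset


@dataclass
class Phase:
    """Hypothetical container for one pass over a dataset."""
    name: str
    loader: DataLoader
    grad: bool  # True -> compute gradients and update weights
    batch_loss: float = float("nan")


def place_and_unwrap(batch, device):
    """Hypothetical helper: move an (inputs, targets) batch to `device`."""
    x, y = batch
    return x.to(device), y.to(device)


def create_train_valid_data(n_train=256, n_valid=64, n_features=10,
                            n_classes=3, batch_size=32):
    """Hypothetical helper: random classification data split into two phases."""
    g = torch.Generator().manual_seed(0)

    def make(n):
        x = torch.randn(n, n_features, generator=g)
        y = torch.randint(0, n_classes, (n,), generator=g)
        return TensorDataset(x, y)

    return [
        Phase("train", DataLoader(make(n_train), batch_size=batch_size, shuffle=True), grad=True),
        Phase("valid", DataLoader(make(n_valid), batch_size=batch_size), grad=False),
    ]


def train(epochs=3, lr=1e-1):
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Assumed toy model: 10 input features, 3 classes.
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3)).to(device)
    phases = create_train_valid_data()
    opt = optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    history = []  # (epoch, phase name, mean batch loss)

    for epoch in range(1, epochs + 1):
        for phase in phases:
            n = len(phase.loader)
            model.train(phase.grad)
            total = 0.0
            for batch in phase.loader:
                x, y = place_and_unwrap(batch, device)
                with torch.set_grad_enabled(phase.grad):
                    out = model(x)
                    loss = loss_fn(out, y)
                    if phase.grad:
                        opt.zero_grad()
                        loss.backward()
                        opt.step()
                phase.batch_loss = loss.item()
                total += phase.batch_loss
            history.append((epoch, phase.name, total / n))
    return model, history


model, history = train(epochs=2)
for epoch, name, avg in history:
    print(f"epoch {epoch} {name}: mean batch loss {avg:.4f}")
```

Because the labels are random, the validation loss is not expected to improve much; the sketch only exercises the mechanics of the loop.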
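The loop flips two independent switches per phase: `model.train(is_training)` changes how layers such as dropout and batch norm behave, and `torch.set_grad_enabled(is_training)` controls whether autograd records a graph. The sketch below, using a hypothetical toy model with a dropout layer, shows the effect of each on a single forward pass.

```python
import torch
from torch import nn

# Hypothetical tiny model with dropout so that train/eval mode makes a visible difference.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(8, 4)


def forward(is_training):
    model.train(is_training)  # same call the loop makes for each phase
    with torch.set_grad_enabled(is_training):
        return model(x)


train_out = forward(True)
eval_out = forward(False)

# Training mode with gradients enabled: the output is part of an autograd graph.
print(train_out.requires_grad)  # True
# Eval mode with gradients disabled: no graph is recorded.
print(eval_out.requires_grad)  # False
# Dropout is inactive in eval mode, so repeated eval passes agree exactly.
print(torch.equal(eval_out, forward(False)))  # True
print(model.training)  # False
```

Skipping graph construction in the validation phase saves memory and time, while switching to eval mode makes validation outputs deterministic for layers like dropout.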