@cflamant
Created December 11, 2019 07:51
PyTorch Training loop example using tqdm to monitor progress (won't run by itself, needs to be in a class)
## Example training loop, written as the body of a method on a class that
## holds a PyTorch nn.Module in the attribute self.net.
## num_batch and loss_freq are assumed to be defined by the enclosing code.
import torch
import torch.optim as optim
import tqdm

optimizer = optim.SGD(self.net.parameters(), lr=0.001)
train_losses = []
with tqdm.trange(num_batch) as batches:
    for b in batches:
        x = self.getdata()
        # forward pass: the network is expected to return a tensor of losses
        losses = self.net(x)
        # clear gradients left over from the previous step
        optimizer.zero_grad()
        # reduce to a scalar loss
        loss = torch.mean(losses)
        # backprop, step
        loss.backward()
        optimizer.step()
        # record loss every loss_freq batches
        loss_val = loss.item()
        if b % loss_freq == 0:
            train_losses.append(loss_val)
        # show the current loss next to the progress bar
        batches.set_postfix(loss=f'{loss_val:.2e}')
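
Since the fragment above needs a surrounding class to run, here is a minimal self-contained sketch of one way to wrap it. The `QuadraticLoss` module, the `Trainer` class, its `getdata()` body, and the default values for `num_batch`, `loss_freq`, and `lr` are all hypothetical stand-ins chosen for illustration; they are not part of the original gist.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm


class QuadraticLoss(nn.Module):
    """Hypothetical stand-in network: returns one squared error per sample."""

    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Per-sample loss, minimized when w == 3.
        return (self.w * x - 3.0 * x) ** 2


class Trainer:
    """Hypothetical wrapper class supplying self.net and self.getdata()."""

    def __init__(self):
        self.net = QuadraticLoss()

    def getdata(self, batch_size=32):
        # Assumption: each call yields a fresh random batch.
        return torch.randn(batch_size)

    def train(self, num_batch=200, loss_freq=10, lr=0.05):
        optimizer = optim.SGD(self.net.parameters(), lr=lr)
        train_losses = []
        with tqdm.trange(num_batch) as batches:
            for b in batches:
                x = self.getdata()
                optimizer.zero_grad()
                loss = torch.mean(self.net(x))
                loss.backward()
                optimizer.step()
                loss_val = loss.item()
                # Keep only every loss_freq-th value to bound memory use.
                if b % loss_freq == 0:
                    train_losses.append(loss_val)
                batches.set_postfix(loss=f"{loss_val:.2e}")
        return train_losses


torch.manual_seed(0)
trainer = Trainer()
losses = trainer.train()
```

Creating the optimizer as a local variable inside `train()` mirrors the fragment. If training should resume across several calls, storing it on `self` instead would preserve optimizer state such as momentum.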