@adamoudad
Last active August 27, 2023 12:11
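
```python
import torch.nn.functional as F
from tqdm import tqdm

# Assumes `model`, `train_loader`, `optimizer` and `device` are defined earlier.
model.train()
for epoch in range(1, 5):  # epochs 1 through 4
    # Wrap the loader so tqdm draws one progress bar per epoch, counting batches.
    with tqdm(train_loader, unit="batch") as tepoch:
        for data, target in tepoch:
            tepoch.set_description(f"Epoch {epoch}")

            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            # Predicted class = index of the highest log-probability.
            predictions = output.argmax(dim=1)
            loss = F.nll_loss(output, target)
            # Divide by the real batch size: the last batch may be smaller.
            correct = (predictions == target).sum().item()
            accuracy = correct / target.size(0)

            loss.backward()
            optimizer.step()

            # Show this batch's loss and accuracy at the end of the progress bar.
            tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)
```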
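The accuracy shown in the progress bar is computed per batch, so its denominator has to be the size of that batch. If the dataset size is not a multiple of `batch_size`, the final batch is smaller, and dividing by a fixed `batch_size` under-reports it. The plain-Python sketch below illustrates this and also computes an epoch-level accuracy weighted by batch size. `batch_accuracy` and `epoch_accuracy` are hypothetical helper names used for illustration only; they are not part of PyTorch or tqdm, and the batches are made-up toy data.

```python
from typing import Iterable, Sequence, Tuple


def batch_accuracy(predictions: Sequence[int], targets: Sequence[int]) -> float:
    """Fraction of correct predictions, divided by this batch's real size."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if not targets:
        raise ValueError("empty batch")
    correct = sum(p == t for p, t in zip(predictions, targets))
    return correct / len(targets)


def epoch_accuracy(batches: Iterable[Tuple[Sequence[int], Sequence[int]]]) -> float:
    """Accuracy over a whole epoch: total correct / total samples seen.

    Taking the mean of per-batch accuracies instead would give a small
    final batch the same weight as a full one.
    """
    correct = 0
    seen = 0
    for predictions, targets in batches:
        correct += sum(p == t for p, t in zip(predictions, targets))
        seen += len(targets)
    return correct / seen


# Hypothetical epoch: 10 samples with batch_size = 4, so the last batch has 2.
batch_size = 4
batches = [
    ([0, 1, 2, 3], [0, 1, 2, 0]),  # 3 correct out of 4
    ([1, 1, 0, 0], [1, 1, 0, 0]),  # 4 correct out of 4
    ([2, 2], [2, 2]),              # 2 correct out of 2
]

last_preds, last_targets = batches[-1]
last_correct = sum(p == t for p, t in zip(last_preds, last_targets))
print(last_correct / batch_size)                   # 0.5 (fixed denominator, wrong)
print(batch_accuracy(last_preds, last_targets))    # 1.0 (real batch size)
print(epoch_accuracy(batches))                     # 0.9 (9 correct out of 10)
```

In the PyTorch loop, `target.size(0)` plays the role of `len(targets)` here.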