Skip to content

Instantly share code, notes, and snippets.

@piEsposito
Created April 28, 2020 18:04
Show Gist options
  • Save piEsposito/6eb11b8a5f2023ad750f6e8fae6bf882 to your computer and use it in GitHub Desktop.
Save piEsposito/6eb11b8a5f2023ad750f6e8fae6bf882 to your computer and use it in GitHub Desktop.
# Train `classifier` for NUM_EPOCHS epochs with two nested tqdm progress bars:
# an outer bar over epochs (position 0) and an inner bar over the batches of
# `train_loader` (position 1). Every EVAL_EVERY batches the model is scored on
# (X_test, y_test) and the latest loss/accuracy are shown in both bars' postfix.
#
# Relies on names defined elsewhere: tqdm, train_loader, classifier,
# optimizer, criterion, X_test, y_test.
import torch

NUM_EPOCHS = 10
EVAL_EVERY = 500  # evaluate on the test set every this many batches

# Iterating a tqdm object advances it automatically, so no manual .update()
# calls are made below (they would double-count progress), and the total is
# inferred from range(NUM_EPOCHS) rather than hard-coded.
epoch_bar = tqdm(range(NUM_EPOCHS),
                 desc="Training",
                 position=0)
acc = 0
# Most recent training loss; NaN until the first batch runs so the epoch
# postfix never references an undefined name when train_loader is empty.
last_loss = float("nan")
for epoch in epoch_bar:
    batch_bar = tqdm(enumerate(train_loader),
                     desc=f"Epoch: {epoch}",
                     position=1,
                     total=len(train_loader))
    for i, (datapoints, labels) in batch_bar:
        optimizer.zero_grad()
        preds = classifier(datapoints.long())
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        # .item() already copies the scalar to the host; a prior .cpu() is redundant.
        last_loss = loss.item()
        if (i + 1) % EVAL_EVERY == 0:
            # Score without autograd (no graph is built for the test forward
            # pass) and in eval mode (so dropout/batch-norm, if the classifier
            # has any, behave deterministically); the previous mode is restored.
            was_training = classifier.training
            classifier.eval()
            with torch.no_grad():
                # Cast like the training inputs. Assumes the classifier consumes
                # integer inputs, as the training call implies -- TODO confirm.
                preds = classifier(X_test.long())
            classifier.train(was_training)
            acc = (preds.argmax(dim=1) == y_test).float().mean().item()
        batch_bar.set_postfix(loss=last_loss,
                              accuracy="{:.2f}".format(acc),
                              epoch=epoch)
    epoch_bar.set_postfix(loss=last_loss,
                          accuracy="{:.2f}".format(acc),
                          epoch=epoch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment