Skip to content

Instantly share code, notes, and snippets.

@omarfsosa
Created September 5, 2022 08:51
Show Gist options
  • Save omarfsosa/2dca44aed9b699d935bd479e6c051180 to your computer and use it in GitHub Desktop.
Save omarfsosa/2dca44aed9b699d935bd479e6c051180 to your computer and use it in GitHub Desktop.
Custom tqdm progress bar in a PyTorch training loop
from tqdm.auto import tqdm
def train_sinle_epoch(dataloader, model, loss_fn, optimizer, *, device=None):
    """Train ``model`` for one pass over ``dataloader`` with a tqdm progress bar.

    Intended for a classification model whose last layer produces raw
    logits (i.e. without a softmax layer); the predicted class is the
    argmax over dimension 1 of the logits.

    The progress bar shows the loss of the most recent batch and the
    running accuracy over all samples seen so far in this epoch.

    Note: this function does not call ``model.train()``; put the model in
    the desired mode before calling it.

    Args:
        dataloader: Iterable yielding ``(x, y)`` batches.
        model: Callable mapping a batch ``x`` to logits.
        loss_fn: Callable ``loss_fn(logits, y)`` returning a scalar tensor.
        optimizer: Optimizer stepped once per batch.
        device: Optional device; when given, each ``x`` and ``y`` is moved
            there with ``.to(device)`` before the forward pass.

    Returns:
        A dict with ``"loss"`` (mean of the per-batch losses) and
        ``"accuracy"`` (fraction of correctly classified samples) for the
        epoch. Both are ``nan`` when ``dataloader`` yields no batches.
    """
    correct_total = 0
    size_total = 0
    loss_sum = 0.0
    num_batches = 0
    with tqdm(dataloader, unit="batch") as tepoch:
        # The description never changes, so set it once, not per batch.
        tepoch.set_description("Progress")
        for x, y in tepoch:
            if device is not None:
                x, y = x.to(device), y.to(device)
            # -- Forward pass
            logits = model(x)
            # -- Backprop
            optimizer.zero_grad()
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
            # -- Compute metrics
            # argmax over the class dim already gives one label per sample.
            # Avoid .squeeze(): it drops *every* size-1 dim, so a batch of
            # one would collapse to a 0-d tensor.
            # Assumes y holds integer class indices shaped like y_pred -- confirm.
            y_pred = logits.argmax(dim=1)
            correct_total += (y_pred == y).sum().item()
            size_total += len(y)
            accuracy = correct_total / size_total
            batch_loss = loss.item()
            loss_sum += batch_loss
            num_batches += 1
            # -- Update the progress bar values
            tepoch.set_postfix(
                loss=batch_loss,
                acc=format(accuracy, "3.2%"),
            )
    nan = float("nan")
    return {
        "loss": loss_sum / num_batches if num_batches else nan,
        "accuracy": correct_total / size_total if size_total else nan,
    }


# Correctly spelled alias; the original (misspelled) name is kept for callers.
train_single_epoch = train_sinle_epoch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment