Created
September 5, 2022 08:51
-
-
Save omarfsosa/2dca44aed9b699d935bd479e6c051180 to your computer and use it in GitHub Desktop.
Custom tqdm progress bar in a PyTorch training loop
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tqdm.auto import tqdm | |
def train_sinle_epoch(dataloader, model, loss_fn, optimizer, device=None):
    """Train ``model`` for one pass over ``dataloader`` with a tqdm progress bar.

    This example targets a classification model whose last layer produces raw
    logits (i.e. there is no softmax layer). Predictions are taken as the
    argmax over dimension 1 of the model output. The bar shows the latest
    batch loss and the running accuracy over the epoch.

    Args:
        dataloader: Iterable yielding ``(x, y)`` batches. It is wrapped in a
            tqdm bar, so its length (if it has one) sets the bar's total.
        model: Callable mapping ``x`` to logits. Assumed to return shape
            ``(batch, num_classes)`` -- TODO confirm for your model.
        loss_fn: Callable ``loss_fn(logits, y)`` returning a scalar loss
            tensor (``.backward()`` and ``.item()`` are called on it).
        optimizer: Optimizer over ``model``'s parameters. ``zero_grad()`` and
            ``step()`` are called once per batch.
        device: Optional target device. When given, each ``x`` and ``y`` is
            moved there with ``.to(device)`` before the forward pass. The
            default ``None`` leaves batches wherever the dataloader put them.

    Returns:
        The running accuracy over the whole epoch as a float in ``[0, 1]``,
        or ``None`` if ``dataloader`` yielded no batches.
    """
    # NOTE(review): the model's train/eval mode is not changed here; call
    # model.train() beforehand if it uses dropout or batch norm.
    correct_total = 0
    size_total = 0
    accuracy = None
    # The description never changes, so set it once instead of on every batch.
    with tqdm(dataloader, unit="batch", desc="Progress") as tepoch:
        for x, y in tepoch:
            if device is not None:
                x, y = x.to(device), y.to(device)

            # -- Forward pass
            logits = model(x)

            # -- Backprop
            optimizer.zero_grad()
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()

            # -- Compute metrics
            # argmax over the class dimension already yields one label per
            # sample, shape (batch,). A keepdim/squeeze round-trip is
            # unnecessary, and squeeze() would collapse a batch of one to a
            # 0-d tensor.
            y_pred = logits.argmax(dim=1)
            correct_total += (y_pred == y).sum().item()
            size_total += len(y)
            accuracy = correct_total / size_total

            # -- Update the progress bar values
            tepoch.set_postfix(
                loss=loss.item(),
                acc=format(accuracy, "3.2%"),
            )
    return accuracy


# Correctly spelled alias; the original (misspelled) name is kept for callers.
train_single_epoch = train_sinle_epoch
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment