@mjhong0708 · Last active December 30, 2021 06:30
Show a Keras-like progress bar in PyTorch
def make_progbar(num_done, total_length):
    """Render a 30-slot, Keras-style progress bar for `num_done` finished steps."""
    num_done_syms = int((num_done + 1) / total_length * 30)
    if total_length > 1 and num_done == 0:
        # Nothing finished yet: show only the leading ">" head.
        progbar = "[" + ">".ljust(30, ".") + "]"
    elif num_done < total_length - 1:
        # Partially done: "=" for finished slots, ">" head, "." for the remainder.
        progbar = "[" + "=" * num_done_syms + ">".ljust(30 - num_done_syms, ".") + "]"
    else:
        # Last step: bar is completely filled.
        progbar = "[" + ">".rjust(30, "=") + "]"
    return progbar
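# Example output (hypothetical batch counts): with 20 batches, make_progbar(0, 20)
# shows only the leading ">", make_progbar(9, 20) fills half of the 30 slots with
# "=", and make_progbar(19, 20) renders a completely filled bar.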
# example function
def fit_model(model, dataloader, optimizer, loss_fn, num_epochs=10):
    num_batches = len(dataloader)
    for epoch in range(num_epochs):
        batch_losses = []
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for i, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.cpu().detach())
            # Overwrite the same line with "\r" until the last batch of the epoch.
            end = "\r" if i < num_batches - 1 else None
            curr_batch_idx = str(i + 1).rjust(len(str(num_batches)))
            progbar = make_progbar(i, num_batches)
            mean_loss = sum(batch_losses) / (i + 1)
            print(f"{curr_batch_idx}/{num_batches} {progbar} - loss: {mean_loss:.3f}", end=end)
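# Minimal usage sketch. The toy tensors, linear model, and hyperparameters below
# are illustrative assumptions, not part of the original snippet.
if __name__ == "__main__":
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    X, y = torch.randn(256, 10), torch.randn(256, 1)
    loader = DataLoader(TensorDataset(X, y), batch_size=32)
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    fit_model(model, loader, optimizer, nn.MSELoss(), num_epochs=5)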