Skip to content

Instantly share code, notes, and snippets.

@johnhw
Created June 16, 2019 10:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save johnhw/d53c1a4ae4744858dc91e62bbd3eaf88 to your computer and use it in GitHub Desktop.
Save johnhw/d53c1a4ae4744858dc91e62bbd3eaf88 to your computer and use it in GitHub Desktop.
from tqdm.auto import trange, tqdm
def tqdm_train_loop(train_fn, num_epochs, train_loader, val_fn=None):
    """Run a training loop with nested tqdm progress bars (epochs, then batches).

    For each of ``num_epochs`` epochs, iterate ``train_loader`` and call
    ``train_fn`` on every ``(features, targets)`` pair. The inner (batch) bar
    shows the most recent batch cost. At the end of each epoch, the outer
    (epoch) bar shows the last training cost and, if ``val_fn`` is given,
    its cost as well.

    Args:
        train_fn: Callable ``train_fn(features, targets)`` that performs one
            training step. Its return value must provide ``.item()`` (e.g. a
            scalar tensor).
        num_epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(features, targets)`` pairs. It is wrapped
            in a fresh tqdm bar every epoch, so it must be re-iterable; a
            one-shot iterator yields nothing after the first epoch.
        val_fn: Optional callable ``val_fn(features, targets)``. Its return
            value must provide ``.item()``.

    NOTE(review): ``val_fn`` is evaluated on the *last training batch* of the
    epoch, not on held-out data. The "val" figure is therefore not a true
    validation loss unless ``val_fn`` ignores its arguments.
    """
    with trange(num_epochs, unit="epoch") as epoch_t:
        for _epoch in epoch_t:
            cost = None
            with tqdm(train_loader, leave=False, unit="batch", postfix="") as batch_t:
                # Iterating a tqdm bar advances it automatically. An extra
                # update() call would double-count each batch.
                for features, targets in batch_t:
                    cost = train_fn(features, targets)
                    # refresh=False: the next iteration redraws the bar
                    # (respecting tqdm's rate limiting), so don't force a
                    # redraw on every batch.
                    batch_t.set_postfix_str(f"cost {cost.item():.2f}", refresh=False)
            if cost is None:
                # Empty loader: there is no batch cost (or batch) to report.
                continue
            if val_fn is not None:
                val_cost = val_fn(features, targets)
                epoch_t.set_postfix_str(
                    f"train {cost.item():.2f}, val {val_cost.item():.2f} "
                )
            else:
                epoch_t.set_postfix_str(f"train {cost.item():.2f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment