@gngdb
Created August 17, 2018 13:54
Log training metrics with your progress bar object: after each `pbar.update` call, call `pbar.log` with the values you want to store as keyword arguments.
from tqdm import tqdm


class TrainingProgress(tqdm):
    """Progress bar that stores logged values while displaying them.

    After each `pbar.update(...)`, call `pbar.log(name=value, ...)`. Numeric
    values are appended to `self.trace[name]` as (iteration, value) pairs,
    and the latest value of every logged name is shown in the bar's postfix.
    """

    def log(self, *args, **kwargs):
        if 'trace' not in self.__dict__:
            self.trace = {}
        # store every plain int or float kwarg (not bool or numpy scalars) in
        # the trace dictionary as an (iteration, value) pair, where the
        # iteration is the bar's global counter self.n
        for k, v in kwargs.items():
            if type(v) in (int, float):
                # setdefault starts a new list the first time a kwarg is logged
                self.trace.setdefault(k, []).append((self.n, v))
        # fill in the most recent stored value of every traced kwarg that was
        # not passed in this call, so all of them stay in the postfix
        for k in self.trace:
            if k not in kwargs:
                kwargs[k] = self.trace[k][-1][1]
        try:
            super().set_postfix(*args, **kwargs)
        except Exception:
            # KeyboardInterrupt cannot be caught here when updating tqdm
            # manually, as the interrupt will most likely happen in another
            # statement. Only notebook bars (tqdm.notebook) accept bar_style,
            # so turn the bar red there and skip this step for console bars.
            try:
                self.display(bar_style='danger')
            except TypeError:
                pass
            raise
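After training, `pbar.trace` maps each logged name to a list of `(iteration, value)` tuples. Below is a minimal sketch of unpacking it into separate step and value lists (for example, for plotting), assuming the trace has exactly that shape. `split_trace` is a hypothetical helper, not part of tqdm or of this gist, and the numbers are placeholders rather than real results.

```python
from typing import Dict, List, Tuple, Union

Number = Union[int, float]
Trace = Dict[str, List[Tuple[int, Number]]]


def split_trace(trace: Trace) -> Dict[str, Tuple[List[int], List[Number]]]:
    """Convert {name: [(step, value), ...]} into {name: ([steps], [values])}."""
    series = {}
    for name, pairs in trace.items():
        steps = [step for step, _ in pairs]
        values = [value for _, value in pairs]
        series[name] = (steps, values)
    return series


# Placeholder trace shaped like `pbar.trace` after three `log` calls;
# `lr` was passed only on the first call (illustrative numbers only).
trace = {"loss": [(1, 2.3), (2, 1.9), (3, 1.7)], "lr": [(1, 0.1)]}
series = split_trace(trace)
print(series["loss"])  # ([1, 2, 3], [2.3, 1.9, 1.7])
print(series["lr"])  # ([1], [0.1])
```

Because `log` only appends the values passed in that particular call (earlier values are re-displayed but not stored again), metrics logged at different rates end up with different step lists, so each metric's steps should stay paired with its own values.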