Created
August 17, 2018 13:54
-
-
Save gngdb/b5586040ee1430936aebc78c41be2905 to your computer and use it in GitHub Desktop.
Log training metrics through your progress bar object: after calling `pbar.update`, call `pbar.log` with the values you want to store, passed as keyword arguments.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tqdm import tqdm | |
class TrainingProgress(tqdm):
    """tqdm progress bar that also records the numeric values it displays.

    Call :meth:`log` (typically right after ``update``) with keyword
    arguments. Each ``int``/``float`` value is appended to
    ``self.trace[name]`` as an ``(self.n, value)`` pair, where ``self.n`` is
    the bar's current iteration count. The latest value of every traced key
    is always shown in the postfix, so metrics logged less often stay visible.
    """

    def log(self, *args, **kwargs):
        """Store numeric kwargs in ``self.trace`` and display them as the postfix.

        Args:
            *args: Forwarded unchanged to ``tqdm.set_postfix``.
            **kwargs: Values to display. ``int`` and ``float`` values
                (including float subclasses such as ``numpy.float64``, but
                not ``bool``) are also recorded in ``self.trace``.

        Raises:
            Exception: anything raised by ``set_postfix`` is re-raised
                unchanged after a best-effort attempt to mark the bar as failed.
        """
        # ``trace`` is created lazily because ``__init__`` is not overridden.
        if 'trace' not in self.__dict__:
            self.trace = {}
        trace = self.trace

        # Record numeric values keyed by the bar's global step count.
        # ``bool`` is a subclass of ``int`` but is not a metric, so skip it.
        for key, value in kwargs.items():
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                trace.setdefault(key, []).append((self.n, value))

        # Fill in the most recent value of every traced key not passed this
        # call, so all logged metrics stay on screen.
        for key, history in trace.items():
            kwargs.setdefault(key, history[-1][1])

        try:
            super(TrainingProgress, self).set_postfix(*args, **kwargs)
        except Exception:
            # KeyboardInterrupt is not an Exception subclass, and with a
            # manually driven bar it would most likely arrive during another
            # statement anyway, so only ordinary errors are handled here.
            try:
                # Only the notebook status printer accepts ``bar_style``; the
                # console ``sp`` takes a single status string (and may be
                # absent), and its error must not mask the real one.
                self.sp(bar_style='danger')
            except (AttributeError, TypeError):
                pass
            raise
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment