@raytroop
Last active October 23, 2018 07:31
tqdm pattern

tqdm is very versatile and can be used in a number of ways. The three main ones are given below. The examples assume `from tqdm import tqdm, trange` and `import time`.

Iterable-based

Wrap tqdm() around any iterable:

    text = ""
    for char in tqdm(["a", "b", "c", "d"]):
        text = text + char

trange(i) is a special optimised instance of tqdm(range(i)):

    for i in trange(100):
        pass

Instantiation outside of the loop allows for manual control over tqdm():

    pbar = tqdm(["a", "b", "c", "d"])
    for char in pbar:
        pbar.set_description("Processing %s" % char)
Manual

Manual control on tqdm() updates by using a with statement:

    with tqdm(total=100) as pbar:
        for i in range(10):
            pbar.update(10)

If the optional variable total (or an iterable with len()) is provided, predictive stats are displayed.
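For instance, a plain generator has no len(), so tqdm falls back to a bare counter unless total is passed explicitly; a minimal sketch:

```python
from tqdm import tqdm

def gen():
    # A plain generator: no __len__, so tqdm cannot infer total.
    yield from range(5)

# Without total: only a raw iteration counter, no percentage or ETA.
list(tqdm(gen()))

# With total: full bar, percentage, and predictive ETA.
list(tqdm(gen(), total=5))
```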

with is also optional (you can just assign tqdm() to a variable, but in this case don't forget to del or close() at the end):

    pbar = tqdm(total=100)
    for i in range(10):
        pbar.update(10)
    pbar.close()
Concrete example

    num_steps = 10
    loss_val = 0.1
    # Use tqdm for progress bar
    t = trange(num_steps)
    for i in t:
        time.sleep(0.1)
        loss_val *= 0.1
        # Log the loss in the tqdm progress bar
        t.set_postfix(loss='{:05.3f}'.format(loss_val))


    def my_generator0(n):
        for i in range(1, 100):
            time.sleep(0.1)
            yield i
            if i > n - 1:
                return

    dataloader = my_generator0(10)
    with tqdm(total=10) as t:
        for labels_batch in dataloader:
            loss_avg = labels_batch * 0.1
            # Log the loss in the tqdm progress bar
            t.set_postfix(loss='{:05.3f}'.format(loss_avg))
            t.update()

Output of the trange example:

    100%|██████| 10/10 [00:01<00:00, 9.97it/s, loss=0.000]

Output of the generator example:

    100%|██████| 10/10 [00:01<00:00, 9.97it/s, loss=1.000]


set_description vs. set_postfix

    from tqdm import tqdm
    import numpy as np
    import time

    ax = np.random.normal(size=(100, 10))

    dataset_iterator = tqdm(ax, total=len(ax))
    for i, data in enumerate(dataset_iterator):
        status = f'{i}: {np.mean(data):.3f}'
        dataset_iterator.set_description(status)
        dataset_iterator.set_postfix(std='{:05.3f}'.format(np.std(data)))
        time.sleep(0.1)

"""
$ python logtqdm.py 
99: -0.424: 100%|██████████████████| 100/100 [00:10<00:00,  9.90it/s, std=0.479]
"""


raytroop commented Oct 23, 2018

trange

    # Snippet from a larger training script: args, lr_scheduler, train(),
    # get_lr() and the best_* trackers are defined elsewhere.
    print("training %s..." % args.model)
    pbar_epoch = trange(start_epoch, args.max_epochs)

    for epoch in pbar_epoch:
        if args.lr_scheduler != 'plateau':
            if args.lr_scheduler == 'clr':
                if epoch % args.lr_scheduler_step_size == 0:
                    # reset best loss and metric for every cycle
                    best_loss = 1e10
                    best_metric = 0
                lr_scheduler.step(epoch % args.lr_scheduler_step_size)
            else:
                lr_scheduler.step()

        train_epoch_loss, train_epoch_metric, train_epoch_accuracy = train(epoch, phase='train')
        valid_epoch_loss, valid_epoch_metric, valid_epoch_accuracy = train(epoch, phase='valid')

        if args.lr_scheduler == 'plateau':
            lr_scheduler.step(metrics=valid_epoch_loss)

        pbar_epoch.set_postfix({'lr': '%.02e' % get_lr(),
                                'train': '%.03f/%.03f/%.03f' % (
                                    train_epoch_loss, train_epoch_metric, train_epoch_accuracy),
                                'val': '%.03f/%.03f/%.03f' % (
                                    valid_epoch_loss, valid_epoch_metric, valid_epoch_accuracy),
                                'best val': '%.03f/%.03f/%.03f' % (best_loss, best_metric, best_accuracy)},
                               refresh=False)


    tqdm(iterable=None, desc=None, total=None, leave=True, file=None, ncols=None, mininterval=0.1, maxinterval=10.0,
         miniters=None, ascii=None, disable=False, unit='it', unit_scale=False, dynamic_ncols=False,
         smoothing=0.3, bar_format=None, initial=0, position=None, postfix=None, unit_divisor=1000, gui=False, **kwargs)

Decorate an iterable object, returning an iterator which acts exactly like the original iterable,
but prints a dynamically updating progressbar every time a value is requested.
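A small sketch combining a few of the common keyword arguments (the filenames and values here are arbitrary placeholders):

```python
import time
from tqdm import tqdm

# desc prefixes the bar, unit replaces the default "it" suffix,
# ncols fixes the total bar width in characters,
# leave=False clears the bar once the loop finishes.
for name in tqdm(["a.txt", "b.txt", "c.txt"],
                 desc="copying", unit="file", ncols=70, leave=False):
    time.sleep(0.1)
```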


total : int, optional
    The number of expected iterations. If unspecified, len(iterable) is used if possible.
    As a last resort, only basic progress statistics are displayed (no ETA, no progressbar)

disable : bool, optional
    Whether to disable the entire progressbar wrapper [default: False]. If set to None, disable on non-TTY.
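For example, with disable=True the loop still runs but nothing is drawn, which is useful when output goes to a log file:

```python
import io
from tqdm import tqdm

out = io.StringIO()
# disable=True: iteration is unaffected, but no bar is written to `out`.
vals = list(tqdm(range(3), disable=True, file=out))
# vals == [0, 1, 2]; out.getvalue() is empty
```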

unit : str, optional
    String that will be used to define the unit of each iteration [default: it].

unit_scale : bool or int or float, optional
    If 1 or True, the number of iterations will be reduced/scaled automatically and a metric prefix following
    the International System of Units standard will be added (kilo, mega, etc.) [default: False].
    If any other non-zero number, will scale total and n.
    import time
    from tqdm import tqdm
    import numpy as np

    np.random.seed(42)

    batch_size = 4
    dataloader = np.random.randn(20, batch_size, 10)
    pbar_disable = False  # whether to disable the entire progressbar wrapper
    pbar = tqdm(dataloader, unit="images",
                unit_scale=batch_size, disable=pbar_disable)
    for batch in pbar:
        time.sleep(0.5)

        # update the progress bar
        pbar.set_postfix({
            'loss': "%.05f" % (np.mean(batch))
        })

By default total = len(iterable) = 20. The bar reads 80/80 because unit_scale multiplies both total and the counter: 20 * 4 = 80, and unit="images" replaces the default "it" suffix in the rate.

    100%|█████████████████████████████| 80/80 [00:10<00:00, 1.99images/s, loss=0.03666]
