Multiprocessing wrapper with chunking and ETA using tqdm
""" MP_TQDM v4.1 by twmicrosheep
This module is a wrapper for easy multiprocessing with tqdm
"""
import math
import itertools
import multiprocessing
from functools import wraps
from typing import List, Dict, Any, Callable, Iterable, Optional
from tqdm.auto import tqdm
ParamList = Dict[str, Any]
def mp_tqdm(func: Callable,
args: Iterable[ParamList],
args_len: Optional[int] = None,
shared: Optional[ParamList] = None,
task_size: int = 1,
process_cnt: int = 1,
ordered: bool = False,
reset: bool = True) -> List[Any]:
"""This function parallelize the workload using multiprocessing
Args:
func: A function that is decorated by MP_TQDM_WORKER
args: Iterable of parameters for each task
args_len: Length of iterable of parameters (Optional if args is not a generator)
shared: Optional shared parameters for each task
task_size: Size of a single batch
process_cnt: Number of worker processes
ordered: Return the output in order
reset: Do workers need to be reset between batches
Returns:
Returns a list of worker function returns
Ordered according to original args if the ordered parameter is True
"""
def grouper(iterable, n):
iterable = iter(iterable)
def add_param(x):
return (process_cnt, shared, x)
return iter(lambda: add_param(list(itertools.islice(iterable, n))), add_param([]))
rets: List[Any] = []
with multiprocessing.Pool(process_cnt, maxtasksperchild=1 if reset else None) as p:
# The master process tqdm bar is at Position 0
if args_len is None:
try:
args_len = len(args) # type: ignore
except Exception:
args_len = None
total_chunks = None if args_len is None else math.ceil(args_len / task_size)
mapmethod = p.imap if ordered else p.imap_unordered
for ret in tqdm(mapmethod(func, grouper(args, task_size)),
total=total_chunks, dynamic_ncols=True):
rets += ret
return rets
def mp_tqdm_worker(func: Callable) -> Callable:
"""This is a decorator function to decorate worker functions
Args:
Callable: A Callable that takes in shared args and a single task in list of args
and do necessary processing before returning results
Note:
Do not include tqdm in worker callable
Returns:
Returns a List of Function Returns which is in order of original Args
"""
@wraps(func)
def d_func(args):
process_cnt, shared, argset = args
shared = shared if shared is not None else {}
# pylint: disable=protected-access
worker_id = (multiprocessing.current_process()._identity[0] - 1) % process_cnt + 1
# pylint: enable=protected-access
rets = []
for arg in argset:
rets.append(func(worker_id=worker_id, **shared, **arg))
return rets
return d_func

Quickstart

This module is not meant to be maximally efficient; the goal is an easy-to-use module that does basic parallelization with chunking and an ETA.

Making parallelization easy to use means we reach for it much more often, and an imperfect speedup still beats none.

For example:
    No parallelization -> 1x
    Efficient parallelization -> 3.7x
    Easy-to-use parallelization -> 3.2x

Basic Usage

The example here calculates a list of numbers multiplied by a shared constant
(e.g. list=[1,2,3] and constant=5 -> output=[5,10,15]).

Write the worker function that completes one task:

  • Add the decorator @mp_tqdm_worker
  • Add **kwargs at the end, as the wrapper injects an additional parameter, worker_id
@mp_tqdm_worker
def my_job(c, x, **kwargs):
    return c*x

Due to a bug in tqdm, do not use tqdm inside the worker function (tqdm/tqdm#627).

Call mp_tqdm to run the jobs in parallel and show progress with an ETA using tqdm.

Using a list of parameters

# All the arguments (These are the jobs to finish)
my_args = [{"x": i} for i in range(100)]

# Run all the jobs and get the return values in result
result = mp_tqdm(my_job,
                 args=my_args,
                 shared={"c": 5},
                 task_size=10,
                 process_cnt=4,
                 ordered=True)
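
Since ordered=True, the returns line up with my_args; a quick sanity check of the expected output (just to illustrate):

assert result == [5 * i for i in range(100)]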

Using a generator

Using a generator avoids building every parameter up front, which can save a lot of memory in some cases.
args_len is needed for tqdm to calculate the ETA, as the length of a generator is not directly available.

# Write a generator to generate all the arguments (These are the jobs to finish)
def gen_arg(x):
    for i in range(x):
        yield {"x": i}

# Run all the jobs and get the return values in result (args_len is needed for ETA)
result = mp_tqdm(my_job,
                 args=gen_arg(100),
                 args_len=100,
                 shared={"c": 5},
                 task_size=10,
                 process_cnt=4,
                 ordered=True)  
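
If args_len is omitted for a generator, mp_tqdm still works: tqdm falls back to a plain chunk counter without an ETA, since the total is unknown. A sketch of the same call (note that ordered is also left at its default here, so results arrive in completion order):

result = mp_tqdm(my_job,
                 args=gen_arg(100),
                 shared={"c": 5},
                 task_size=10,
                 process_cnt=4)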

A good task_size

There is always some overhead when a process starts a new batch of tasks.
So we want each worker to process a batch of tasks before asking for more, instead of running one task and returning for the next.

Having an ETA for long-running jobs is important; it tells us the job is still making progress.
Each time a process finishes a batch of jobs (task_size), the tqdm progress bar ticks once.
To keep the progress bar and ETA accurate, we want it to update frequently,
which means task_size should not be too large.

Finding a good balance is therefore necessary. With this module, we simply write how one task is done, set task_size, and mp_tqdm groups the jobs into batches for us.
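
For instance, the number of progress-bar updates equals the number of chunks, mirroring the math.ceil(args_len / task_size) that mp_tqdm computes internally (a quick check with the numbers from the examples above):

import math

args_len, task_size = 100, 10
ticks = math.ceil(args_len / task_size)  # the bar updates 10 times in total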

Advanced Usage

Using the worker_id

The worker function receives an additional argument, worker_id, a number between 1 and process_cnt.
This worker_id might be helpful in some special use cases.
Beware that two processes may hold the same worker_id at the same time: with reset=True each batch runs in a fresh process, and the id is derived from the process identity modulo process_cnt, so a new process can be assigned an id whose previous holder hasn't finished its batch yet. (See the implementation in the code to understand why this can happen.)

@mp_tqdm_worker
def my_job(config, worker_id, **kwargs):
    setup_gpu(worker_id % 4)
    return run_experiment(config)
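
To see the id assignment in action, a minimal probe worker (hypothetical, not part of the module) can report which worker handled each task:

@mp_tqdm_worker
def which_worker(x, worker_id, **kwargs):
    return (worker_id, x)

tagged = mp_tqdm(which_worker,
                 args=[{"x": i} for i in range(8)],
                 task_size=2,
                 process_cnt=4)
# Each entry of tagged is (worker_id, x) with worker_id in 1..4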
