Last active
January 8, 2022 21:16
-
-
Save jrbergen/86756afbbd73ce34740948d82d12c023 to your computer and use it in GitHub Desktop.
A `multiprocessing.Pool`-like context manager whose worker processes can update a shared tqdm progress bar.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import uuid | |
from typing import Callable, Optional, Iterable | |
import psutil | |
import tqdm | |
import multiprocessing as mp | |
class MultiprocessingPoolTQDM:
    """Pool-like context manager whose workers can advance one shared tqdm bar.

    Each callable registered with :meth:`append` runs in its own
    ``multiprocessing.Process`` and receives the shared queue as its
    ``queue`` keyword argument. A dedicated listener process drains that
    queue:

    * the string ``TQDM_UPDATE_SIGNAL`` advances the progress bar by one;
    * a list has its items added to the collected results;
    * any other value is added to the results as a single item.

    At most ``nproc_nonlistener`` workers run at once. One of ``nproc``
    slots is reserved for the listener, but at least one worker always runs.
    Results come back in arrival order, not in the order workers were appended.

    Example::

        def work(x, queue):
            queue.put(x * x)
            queue.put(MultiprocessingPoolTQDM.TQDM_UPDATE_SIGNAL)

        with MultiprocessingPoolTQDM(total=10) as pool:
            for i in range(10):
                pool.append(work, x=i)
            results = pool.run_and_get_results()
    """

    ALL_COMPLETED_SENTINEL = "__ALLDONE"
    TQDM_UPDATE_SIGNAL: str = "__TQDM_UPDATE"
    RESULT_SIGNAL: str = "__RESULT"
    ERROR_SIGNAL: str = "__ERROR"
    LOOP_PAUSE_SECONDS: float = .1

    # Per-worker lifecycle states stored in ``p_status``.
    P_PENDING: int = 0
    P_BUSY: int = 1
    P_DONE: int = 2
    P_ERRORED: int = 3

    def __init__(self,
                 nproc: Optional[int] = None,
                 physical_cores: bool = False,
                 debug: bool = False,
                 **tqdmkwargs):
        """Create the pool and immediately start the progress-bar listener.

        :param nproc: total process budget, including the listener. Defaults
            to ``psutil.cpu_count()``. When psutil cannot determine the count,
            it falls back to ``multiprocessing.cpu_count()``.
        :param physical_cores: when ``nproc`` is None, count physical rather
            than logical cores.
        :param debug: print scheduling and listener diagnostics.
        :param tqdmkwargs: forwarded to ``tqdm.tqdm`` inside the listener.
        """
        self.total_iter: Optional[int] = tqdmkwargs.get('total', None)
        self.debug: bool = debug
        if nproc is None:
            # psutil.cpu_count() returns None when the count is undeterminable.
            nproc = psutil.cpu_count(logical=not physical_cores) or mp.cpu_count()
        self.nproc = nproc
        # Reserve one slot for the listener, but never drop to zero workers:
        # with no worker slots the scheduler would never start anything.
        self.nproc_nonlistener = max(1, self.nproc - 1)
        self.queue: mp.Queue = mp.Queue()
        # Dedicated listener -> parent channel. The parent must not read the
        # shared worker queue, or it would race the listener for items.
        self._result_queue: mp.Queue = mp.Queue()
        self._closed: bool = False
        self.queueproc: mp.Process = mp.Process(
            target=MultiprocessingPoolTQDM._listener,
            # Internal kwargs come last so tqdm kwargs cannot clobber them.
            kwargs={**tqdmkwargs,
                    'queue': self.queue,
                    'debug': self.debug,
                    'result_queue': self._result_queue},
        )
        self.processes: list[mp.Process] = []
        self.p_status: dict[str, int] = dict()
        self.results: list = []
        self.queueproc.start()

    def append(self, func: Callable[..., object], **kwargs):
        """Register ``func(**kwargs, queue=<shared queue>)`` as a pending worker.

        The worker is only started by :meth:`run_and_get_results`. A ``queue``
        entry in *kwargs* is replaced by the shared queue.
        """
        process = mp.Process(target=func,
                             kwargs={**kwargs, 'queue': self.queue},
                             name=uuid.uuid4().hex)
        self.processes.append(process)
        self.p_status[process.name] = MultiprocessingPoolTQDM.P_PENDING

    def run_and_get_results(self) -> list:
        """Run every registered worker and return all results they reported.

        :returns: the listener's collected results, also stored in
            ``self.results``.
        :raises tqdm.std.TqdmKeyError: if the listener rejected the tqdm kwargs.
        :raises RuntimeError: if the listener died unexpectedly, or if any
            worker exited with a non-zero exit code. In the latter case the
            collected results are still available in ``self.results``.
        """
        while True:
            self._check_listener()
            self._refresh_statuses()
            self._start_pending()
            if self._all_finished():
                if self.debug:
                    print("Asyncloop:", "ALL DONE")
                break
            time.sleep(MultiprocessingPoolTQDM.LOOP_PAUSE_SECONDS)

        # Workers flush their queued items before exiting, so the sentinel is
        # guaranteed to arrive at the listener after all worker messages.
        self.queue.put(MultiprocessingPoolTQDM.ALL_COMPLETED_SENTINEL)
        reply = self._receive_from_listener()
        self.queueproc.join()
        self.results = self._unpack_reply(reply)

        failed = {proc.name: proc.exitcode
                  for proc in self.processes
                  if self.p_status[proc.name] == MultiprocessingPoolTQDM.P_ERRORED}
        if failed:
            raise RuntimeError(f"{len(failed)} worker process(es) failed "
                               f"(name -> exit code): {failed}")
        return self.results

    def _refresh_statuses(self) -> None:
        """Mark exited workers as DONE (exit code 0) or ERRORED (anything else)."""
        for proc in self.processes:
            code = proc.exitcode
            if code is not None:
                self.p_status[proc.name] = (MultiprocessingPoolTQDM.P_DONE if code == 0
                                            else MultiprocessingPoolTQDM.P_ERRORED)

    def _start_pending(self) -> None:
        """Start pending workers, in append order, until all worker slots are busy."""
        n_busy = sum(1 for state in self.p_status.values()
                     if state == MultiprocessingPoolTQDM.P_BUSY)
        n_free = self.nproc_nonlistener - n_busy
        for proc in self.processes:
            if n_free <= 0:
                break
            if self.p_status[proc.name] == MultiprocessingPoolTQDM.P_PENDING:
                proc.start()
                if self.debug:
                    print("Asyncloop:", f"Started {proc.name}")
                self.p_status[proc.name] = MultiprocessingPoolTQDM.P_BUSY
                n_free -= 1

    def _all_finished(self) -> bool:
        """Return True once every worker has exited, successfully or not."""
        finished = (MultiprocessingPoolTQDM.P_DONE, MultiprocessingPoolTQDM.P_ERRORED)
        return all(state in finished for state in self.p_status.values())

    def _check_listener(self) -> None:
        """Raise if the listener exited before it was told to finish.

        Without this check, workers could block forever on a queue nobody
        drains.
        """
        if self.queueproc.exitcode is None:
            return
        # Raises the listener's reported error, if it sent one.
        self._unpack_reply(self._receive_from_listener())
        raise RuntimeError("Listener process exited before all workers finished.")

    def _receive_from_listener(self):
        """Block until the listener posts a reply on the result queue.

        :raises RuntimeError: if the listener exits without replying.
        """
        from queue import Empty

        timeout = MultiprocessingPoolTQDM.LOOP_PAUSE_SECONDS
        while True:
            try:
                return self._result_queue.get(timeout=timeout)
            except Empty:
                if self.queueproc.exitcode is None:
                    continue
            # The listener is gone. A process flushes its queued data before
            # exiting, so give a just-sent reply one last chance to be read.
            try:
                return self._result_queue.get(timeout=timeout)
            except Empty:
                raise RuntimeError(
                    f"Listener process exited with code {self.queueproc.exitcode} "
                    "without replying."
                ) from None

    @staticmethod
    def _unpack_reply(reply) -> list:
        """Return the results carried by a listener reply, or raise its error."""
        if isinstance(reply, dict):
            if MultiprocessingPoolTQDM.ERROR_SIGNAL in reply:
                raise reply[MultiprocessingPoolTQDM.ERROR_SIGNAL]
            if MultiprocessingPoolTQDM.RESULT_SIGNAL in reply:
                return reply[MultiprocessingPoolTQDM.RESULT_SIGNAL]
        raise RuntimeError(f"Unexpected message from listener: {reply!r}")

    @staticmethod
    def _listener(queue: mp.Queue,
                  debug: bool = False,
                  result_queue: Optional[mp.Queue] = None,
                  **tqdmkwargs
                  ) -> None:
        """Listener-process body: drain worker messages and drive the bar.

        Messages read from *queue*:

        * ``TQDM_UPDATE_SIGNAL``: advance the bar by one;
        * ``ALL_COMPLETED_SENTINEL``: reply ``{RESULT_SIGNAL: results}`` and stop;
        * a list: extend the results with its items;
        * anything else: append it to the results.

        Replies, including ``{ERROR_SIGNAL: err}`` when tqdm rejects
        *tqdmkwargs*, go to *result_queue*. When no result queue is given,
        they go back onto *queue*.
        """
        reply_queue = queue if result_queue is None else result_queue
        try:
            pbar = tqdm.tqdm(**tqdmkwargs)
        except tqdm.std.TqdmKeyError as err:
            reply_queue.put({MultiprocessingPoolTQDM.ERROR_SIGNAL: err})
            return

        results = []
        try:
            while True:
                val = queue.get()
                # Check the type first: ``==`` on arbitrary payloads (e.g.
                # arrays) may not yield a plain bool.
                is_str = isinstance(val, str)
                is_update = is_str and val == MultiprocessingPoolTQDM.TQDM_UPDATE_SIGNAL
                if debug and not is_update:
                    print("Listener:", str(val))
                if is_update:
                    pbar.update()
                elif is_str and val == MultiprocessingPoolTQDM.ALL_COMPLETED_SENTINEL:
                    break
                elif isinstance(val, list):
                    results.extend(val)
                else:
                    results.append(val)
        finally:
            pbar.close()
        reply_queue.put({MultiprocessingPoolTQDM.RESULT_SIGNAL: results})

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Wait for running workers on success, then always release resources.

        A ``TqdmKeyError`` raised in the block is re-raised as a generic
        ``Exception`` with a clearer message. Any other exception propagates
        unchanged.
        """
        try:
            if exc_type is None:
                for proc in self.processes:
                    if proc.is_alive():
                        proc.join()
            elif issubclass(exc_type, tqdm.std.TqdmKeyError):
                raise Exception("Invalid TQDM arguments...") from exc_val
        finally:
            self._close()
        return False

    @staticmethod
    def _shutdown_process(proc: mp.Process) -> None:
        """Terminate *proc* if it is still running, reap it, and free its resources."""
        if proc.pid is not None:  # only started processes can be joined
            if proc.is_alive():
                proc.terminate()
            proc.join()
        # close() raises ValueError on a running process, hence the reaping above.
        proc.close()

    def _close(self):
        """Shut down workers, the listener and both queues; safe to call twice."""
        if getattr(self, '_closed', False):
            return
        self._closed = True
        procs = list(getattr(self, 'processes', None) or [])
        listener = getattr(self, 'queueproc', None)
        if listener is not None:
            procs.append(listener)
        for proc in procs:
            MultiprocessingPoolTQDM._shutdown_process(proc)
        for q in (getattr(self, 'queue', None), getattr(self, '_result_queue', None)):
            if q is not None:
                q.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment