multiprocessing.Pool-like context manager which supports updating a tqdm progress bar from within the worker processes.
import multiprocessing as mp
import time
import uuid
from collections.abc import Iterable
from typing import Any, Callable, Optional

import psutil
import tqdm


class MultiprocessingPoolTQDM:
    """multiprocessing.Pool-like context manager whose workers report progress to a single tqdm bar."""

    # Sentinel strings exchanged over the shared queue.
    ALL_COMPLETED_SENTINEL: str = "__ALLDONE"
    TQDM_UPDATE_SIGNAL: str = "__TQDM_UPDATE"
    RESULT_SIGNAL: str = "__RESULT"
    ERROR_SIGNAL: str = "__ERROR"

    LOOP_PAUSE_SECONDS: float = 0.1

    # Worker process states.
    P_PENDING: int = 0
    P_BUSY: int = 1
    P_DONE: int = 2
    P_ERRORED: int = 3

    def __init__(self,
                 nproc: Optional[int] = None,
                 physical_cores: bool = False,
                 debug: bool = False,
                 **tqdmkwargs):
        self.total_iter: Optional[int] = tqdmkwargs.get('total', None)
        self.debug: bool = debug
        # Reserve one core for the listener process that owns the tqdm bar.
        self.nproc = psutil.cpu_count(logical=not physical_cores) if nproc is None else nproc
        self.nproc_nonlistener = self.nproc - 1
        self.queue: mp.Queue = mp.Queue()
        self.queueproc: mp.Process = mp.Process(target=MultiprocessingPoolTQDM._listener,
                                                kwargs={**{'queue': self.queue,
                                                           'debug': self.debug},
                                                        **tqdmkwargs})
        self.processes: list[mp.Process] = []
        self.p_status: dict[str, int] = dict()
        self.queueproc.start()
        self.results: list = []

    def append(self, func: Callable[..., Any], **kwargs) -> None:
        """Register a worker; ``func`` must accept the shared ``queue`` passed as a keyword argument."""
        process = mp.Process(target=func,
                             kwargs={**kwargs, **{'queue': self.queue}})
        process.name = uuid.uuid4().hex
        self.processes.append(process)
        self.p_status[process.name] = MultiprocessingPoolTQDM.P_PENDING

    def run_and_get_results(self) -> Any:
        """Run all registered workers, at most ``nproc - 1`` at a time, and return the collected results."""
        P_PENDING = MultiprocessingPoolTQDM.P_PENDING
        P_BUSY = MultiprocessingPoolTQDM.P_BUSY
        P_DONE = MultiprocessingPoolTQDM.P_DONE
        P_ERRORED = MultiprocessingPoolTQDM.P_ERRORED
        while True:
            # Record which workers have finished (successfully or not).
            for proc in self.processes:
                if proc.exitcode is not None:
                    if proc.exitcode == 0 and self.p_status[proc.name] != P_DONE:
                        self.p_status[proc.name] = P_DONE
                    elif proc.exitcode != 0 and self.p_status[proc.name] != P_ERRORED:
                        self.p_status[proc.name] = P_ERRORED
            n_busy = len([pstate for pstate in self.p_status.values() if pstate == P_BUSY])
            n_free = self.nproc_nonlistener - n_busy
            # Start pending workers on the free slots.
            for proc in self.processes:
                if n_free > 0 and self.p_status[proc.name] == P_PENDING:
                    proc.start()
                    if self.debug:
                        print("Asyncloop:", f"Started {proc.name}")
                    self.p_status[proc.name] = P_BUSY
                    n_free -= 1
                if n_free == 0:
                    break
            # All workers have finished (or errored); stop dispatching.
            if all(pstate in (P_DONE, P_ERRORED) for pstate in self.p_status.values()):
                if self.debug:
                    print("Asyncloop:", "ALL DONE")
                break
            time.sleep(MultiprocessingPoolTQDM.LOOP_PAUSE_SECONDS)
        # Tell the listener that all workers finished, then wait for it to return the collected results.
        self.queue.put(MultiprocessingPoolTQDM.ALL_COMPLETED_SENTINEL)
        return self.queue.get()[MultiprocessingPoolTQDM.RESULT_SIGNAL]

    @staticmethod
    def _listener(queue: mp.Queue,
                  debug: bool = False,
                  **tqdmkwargs) -> None:
        """Listener process: owns the tqdm bar, consumes queue messages, and collects results."""
        try:
            pbar = tqdm.tqdm(**tqdmkwargs)
        except tqdm.std.TqdmKeyError as err:
            queue.put({MultiprocessingPoolTQDM.ERROR_SIGNAL: err})
            return
        results = []
        while True:
            val = queue.get()
            if debug and val != MultiprocessingPoolTQDM.TQDM_UPDATE_SIGNAL:
                print("Listener:", str(val))
            if isinstance(val, list):
                results.extend(val)
            elif isinstance(val, str):
                if val == MultiprocessingPoolTQDM.TQDM_UPDATE_SIGNAL:
                    pbar.update()
                elif val == MultiprocessingPoolTQDM.ALL_COMPLETED_SENTINEL:
                    # Hand the collected results back to the main process.
                    queue.put({MultiprocessingPoolTQDM.RESULT_SIGNAL: results})
                    break
                else:
                    results.append(val)
            else:
                results.append(val)
        pbar.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is None:
                for proc in self.processes:
                    if proc is not None and proc.is_alive():
                        proc.join()
            elif exc_type == tqdm.std.TqdmKeyError:
                raise Exception("Invalid TQDM arguments...") from exc_val
            else:
                return False  # propagate the original exception from the with-block
        finally:
            self._close()

    def _close(self):
        if hasattr(self, 'processes') and isinstance(self.processes, Iterable):
            for proc in self.processes:
                if proc is not None and proc.is_alive():
                    # Process.close() raises while a process is running, so stop it first.
                    proc.terminate()
                    proc.join()
        if hasattr(self, 'queueproc') and self.queueproc.is_alive():
            self.queueproc.terminate()
            self.queueproc.join()
        if hasattr(self, 'queue') and self.queue is not None:
            self.queue.close()
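

# Hedged usage sketch (not part of the original gist): shows how the pool might be
# driven, assuming each worker accepts the shared ``queue`` keyword injected by
# ``append`` and reports progress by putting TQDM_UPDATE_SIGNAL on it. The worker
# name, counts, and sleep duration below are illustrative only.
def _example_worker(queue: mp.Queue, n: int) -> None:
    time.sleep(0.2)                                         # stand-in for real work
    queue.put(n * n)                                        # non-signal values are collected as results
    queue.put(MultiprocessingPoolTQDM.TQDM_UPDATE_SIGNAL)   # advance the progress bar by one


if __name__ == "__main__":
    numbers = list(range(8))
    with MultiprocessingPoolTQDM(total=len(numbers), desc="squaring") as pool:
        for number in numbers:
            pool.append(_example_worker, n=number)
        results = pool.run_and_get_results()
    print(sorted(results))  # expected: [0, 1, 4, 9, 16, 25, 36, 49]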