Last active
January 8, 2021 14:40
-
-
Save martin-kokos/fa22562fa30cec304bb7b84a3de4e394 to your computer and use it in GitHub Desktop.
Simple multiprocessing task manager
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
If the tasks are unequally sized, it may happen with starmap and similar schedulers,
which schedule tasks ahead of time, that some workers are lucky and finish early and some unlucky
ones finish much later — with the effect that the task batch uses all CPU cores at first,
but then the lucky workers finish and cores sit idle.
This task manager serves as an example of how to have workers fetch tasks from a Queue.
Joblib, Pool.starmap or something else might be better, but maybe this is useful.
'''
import collections | |
import functools | |
import io | |
import os | |
import sys | |
import time | |
import multiprocessing as mp | |
from queue import Empty | |
import tqdm | |
def worker_wrap(func, task_queue, result_queue, max_tasks):
    '''
    Run up to ``max_tasks`` units of work pulled from ``task_queue``.

    Each queue item is a kwargs dict for ``func``; every return value is
    pushed onto ``result_queue``.  The worker stops early as soon as the
    task queue looks empty.
    '''
    # Redirect stderr into a per-process, unbuffered log file — an attempt
    # to avoid the logging-module deadlock on the inherited parent stream.
    raw_log = open(f'/tmp/worker-{os.getpid()}.log', 'wb', 0)
    sys.stderr = io.TextIOWrapper(raw_log, write_through=True)
    done = 0
    while done < max_tasks:
        # acquire work
        try:
            kwargs = task_queue.get_nowait()
        except Empty:
            break
        # do work, then submit the result
        result_queue.put_nowait(func(**kwargs))
        done += 1
def task_manager(tasks, func, workers, max_tasks):
    '''
    Simple task manager.

    Keeps a pool of ``workers`` processes fed from a shared task queue so
    that unequally sized tasks do not leave CPU cores idle, and spawns
    replacement workers as they exit (each worker stops after ``max_tasks``
    tasks, see ``worker_wrap``).

    Args:
        tasks: iterable of kwargs dicts, one per call to ``func``.
        func: callable executed in worker processes as ``func(**task)``;
            must be picklable (defined at module level).
        workers: number of worker processes to keep alive.
        max_tasks: maximum tasks a single worker handles before exiting.

    Returns:
        list of results, in completion order (not submission order).

    Example usage:
        tasks = [{'work': i} for i in range(14)]
        def func(work):
            res = work * 10
            time.sleep(6)
            return res
        results = task_manager(
            tasks=tasks,
            func=func,
            workers=4,
            max_tasks=2,
        )
    '''
    task_queue = mp.Queue()
    result_queue = mp.Queue()
    workers_exited = []
    for t in tasks:
        task_queue.put(t)
    running_procs = []
    progress_bar = tqdm.tqdm(total=task_queue.qsize(), desc=f'{workers} workers')
    new_proc = functools.partial(
        mp.Process,
        target=worker_wrap, args=(func, task_queue, result_queue, max_tasks),
    )
    # Maintenance loop: keep the pool at full strength until all tasks are taken
    while not task_queue.empty():
        progress_bar.update(result_queue.qsize() - progress_bar.n)
        # Collect exited and crashed workers.
        # `exitcode is None` means still running; 0 is a clean exit and must
        # be joined and counted too (a bare truthiness test would skip it).
        for p in running_procs:
            if p.exitcode is not None:
                workers_exited.append(p.exitcode)
                p.join()
        running_procs = [p for p in running_procs if p.is_alive()]
        # Spawn new workers to replace those that exited
        missing = workers - len(running_procs)
        for _ in range(missing):
            new_p = new_proc()
            new_p.start()
            running_procs.append(new_p)
        time.sleep(1)
    # Join workers.
    # NOTE(review): joining before draining result_queue can deadlock if the
    # accumulated results overfill the queue's underlying pipe buffer — fine
    # for small results, but confirm for heavy payloads (see multiprocessing
    # docs, "Joining processes that use queues").
    print('Submitted last task. Waiting for workers to finish')
    for p in running_procs:
        p.join()
    progress_bar.update(result_queue.qsize() - progress_bar.n)
    progress_bar.close()
    workers_exited = dict(collections.Counter(workers_exited))
    print(f'Worker exit codes: {workers_exited}')
    # Collect results
    print('Collecting results')
    results = []
    while not result_queue.empty():
        results.append(result_queue.get())
    print('Scheduled tasks done')
    return results
if __name__ == "__main__": | |
tasks = [{'work': i} for i in range(14)] | |
def func(work): | |
res = work * 10 | |
time.sleep(6) | |
return res | |
results = task_manager( | |
tasks=tasks, | |
func=func, | |
workers=4, | |
max_tasks=2, | |
) | |
print(f'Results: {results}') |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment