Created
January 10, 2020 11:05
Star
You must be signed in to star a gist
Python ThreadPool with limited number of threads - adapted from https://github.com/cirosantilli/linux-kernel-module-cheat/blob/0486765421f5ad7ab21c6714cdc80e0d17458e9c/thread_pool.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from typing import Any, Callable, Dict, Iterable, Union | |
import multiprocessing | |
import queue | |
import sys | |
import threading | |
import time | |
class ThreadPool:
    '''
    Start a pool of a limited number of threads to do some work.

    This is similar to the stdlib concurrent, but I could not find
    how to reach all my design goals with that implementation:

    - the input function does not need to be modified
    - limit the number of threads
    - queue sizes closely follow number of threads
    - if an exception happens, optionally stop soon afterwards

    Functional form and further discussion at:
    https://stackoverflow.com/questions/19369724/the-right-way-to-limit-maximum-number-of-threads-running-at-once/55263676#55263676

    This class form allows to use your own while loops with submit().

    Quick test with:

        ./thread_limit.py 2 -10 20 0
        ./thread_limit.py 2 -10 20 1
        ./thread_limit.py 2 -10 20 2
        ./thread_limit.py 2 -10 20 3

    These ensure that execution stops neatly on error.
    '''

    def __init__(
        self,
        func: Callable,
        handle_output: Union[Callable[[Any, Any, Exception], Any], None] = None,
        nthreads: Union[int, None] = None
    ):
        '''
        Start in a thread pool immediately.

        join() must be called afterwards at some point.

        :param func: main work function to be evaluated. It is called as
            func(**work) for every work item given to submit().
        :param handle_output: called on func return values as they
            are returned.

            Signature is: handle_output(input, output, exception) where:

            - input: input given to func
            - output: return value of func, or None if func raised
            - exception: the exception that func raised, or None otherwise

            If this function returns non-None or raises, the value is
            recorded and reported by submit() and join(), so that the
            caller can stop feeding new input and exit ASAP when all
            currently running threads have finished.

            Default: a handler that does nothing and just exits on exception.
        :param nthreads: number of threads to use. Default: nproc.
        '''
        self.func = func
        if handle_output is None:
            # Default handler: report the exception (None on success).
            handle_output = lambda _input, _output, exception: exception
        self.handle_output = handle_output
        if nthreads is None:
            nthreads = multiprocessing.cpu_count()
        self.nthreads = nthreads
        # Last error reported by handle_output; written under the lock.
        self.error_output = None
        self.error_output_lock = threading.Lock()
        # Bounded queue: submit() blocks once about nthreads items are pending.
        self.in_queue = queue.Queue(maxsize=nthreads)
        self.threads = []
        for _ in range(self.nthreads):
            thread = threading.Thread(
                target=self._func_runner,
            )
            self.threads.append(thread)
            thread.start()

    def submit(self, work):
        '''
        Submit work. Block if there is already enough work scheduled (~nthreads).

        :param work: keyword arguments for func, passed as func(**work).
            None is reserved as the internal stop marker.
        :return: if an error occurred in some previously executed thread, the error.
            Otherwise, None. This allows the caller to stop submitting further
            work if desired.
        '''
        self.in_queue.put(work)
        return self.error_output

    def join(self):
        '''
        Request all threads to stop after they finish currently submitted work.

        :return: same as submit()
        '''
        # One stop marker per worker: each worker exits on the first None it sees.
        for _ in range(self.nthreads):
            self.in_queue.put(None)
        for thread in self.threads:
            thread.join()
        return self.error_output

    def _func_runner(self):
        '''
        Worker loop: take work items from the queue until a None stop
        marker arrives, run func on each one and pass the outcome to
        handle_output.
        '''
        while True:
            work = self.in_queue.get(block=True)
            if work is None:
                break
            # Reset for every item: if func raises, the handler must see
            # None as output, not an unbound name or the previous item's
            # stale result.
            out = None
            exception = None
            try:
                out = self.func(**work)
            except Exception as e:
                exception = e
            try:
                handle_output_return = self.handle_output(work, out, exception)
            except Exception as e:
                # The handler itself failed: record what it was processing.
                with self.error_output_lock:
                    self.error_output = (work, out, e)
            else:
                if handle_output_return is not None:
                    with self.error_output_lock:
                        self.error_output = handle_output_return
            finally:
                self.in_queue.task_done()
if __name__ == '__main__':
    # Demo / manual test driver.
    # Usage: <script> [nthreads [min [max [handler]]]]
    #   nthreads: worker count, 0 or missing means the ThreadPool default
    #   min, max: work is generated for i in range(min, max)
    #   handler: first character selects the output handler:
    #       1: print only, 2: store in a queue, 3: raise on i == 10,
    #       anything else: print and report the exception.

    # Results collected by handle_output_queue, printed at the end.
    out_queue = queue.Queue()

    def _show(input, output, exception):
        '''
        Print one (input, output, exception) triple.
        '''
        print('{!r} {!r} {!r}'.format(input, output, exception))

    def my_func(i):
        '''
        The work function under test.

        Sleeps for a short while to simulate an IO operation and then
        divides, so i == 0 raises ZeroDivisionError.
        '''
        delay = (abs(i) % 4) / 10.0
        time.sleep(delay)
        return 10.0 / i

    def get_work(min_, max_):
        '''
        Yield one keyword-argument dict for my_func per integer in
        range(min_, max_).
        '''
        yield from ({'i': i} for i in range(min_, max_))

    def handle_output_print(input, output, exception):
        '''
        Print outputs and report the exception, so that the caller
        stops on the first failure.
        '''
        _show(input, output, exception)
        return exception

    def handle_output_print_no_exit(input, output, exception):
        '''
        Print outputs and never report an error, so failures do not stop
        the run.
        '''
        _show(input, output, exception)

    def handle_output_queue(input, output, exception):
        '''
        Store outputs in out_queue for later usage and report the exception.
        '''
        out_queue.put((input, output, exception))
        return exception

    def handle_output_raise(input, output, exception):
        '''
        Print outputs and raise if input['i'] == 10, to test that
        execution stops nicely when the handler itself raises.
        '''
        _show(input, output, exception)
        if input['i'] == 10:
            raise Exception

    # CLI arguments.
    cli_args = sys.argv[1:]
    # An explicit 0 means the same as no value: use the pool default.
    nthreads = (int(cli_args[0]) or None) if len(cli_args) > 0 else None
    min_ = int(cli_args[1]) if len(cli_args) > 1 else 1
    max_ = int(cli_args[2]) if len(cli_args) > 2 else 100
    mode = cli_args[3][0] if len(cli_args) > 3 else '0'
    handlers_by_mode = {
        '1': handle_output_print_no_exit,
        '2': handle_output_queue,
        '3': handle_output_raise,
    }
    handle_output = handlers_by_mode.get(mode, handle_output_print)

    # Action.
    thread_pool = ThreadPool(my_func, handle_output, nthreads)
    for work in get_work(min_, max_):
        # Stop feeding work as soon as a previous item reported an error.
        if thread_pool.submit(work) is not None:
            break
    error = thread_pool.join()
    if error is not None:
        print('error: {!r}'.format(error))
    if handle_output is handle_output_queue:
        while not out_queue.empty():
            print(out_queue.get())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment