Last active
May 6, 2026 11:06
-
-
Save jonashaag/7a4d5627331c749f2b5b85869d9499e7 to your computer and use it in GitHub Desktop.
Python ThreadPoolExecutor Work Stealing (Python 3.14)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import concurrent.futures.thread as _thread_impl
import contextvars
import logging
import threading
import time
import weakref
from collections import deque
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import Future
from itertools import islice
from typing import Any, cast, overload
| class WorkStealThreadPoolExecutor(_thread_impl.ThreadPoolExecutor): | |
| """A ThreadPoolExecutor that supports work stealing. | |
| We use work stealing to prevent worker starvation. | |
| We use a custom `WorkStealFuture` to support work stealing upon calling `future.result()`. | |
| """ | |
| def __init__( | |
| self, | |
| max_workers: int | None = None, | |
| thread_name_prefix: str = "", | |
| initializer: Callable[..., object] | None = None, | |
| initargs: tuple[Any, ...] = (), | |
| **ctxkwargs: Any, | |
| ): | |
| # Reject any subclass override of prepare_context, even if the override is a | |
| # trivial passthrough. Work stealing creates a fresh WorkerContext per stolen | |
| # task and only the default WorkerContext semantics are safe in that path. | |
| prepare_context = getattr(type(self).prepare_context, "__func__", None) | |
| default_prepare_context = getattr(_thread_impl.ThreadPoolExecutor.prepare_context, "__func__", None) | |
| uses_custom_worker_context = prepare_context is not default_prepare_context | |
| uses_initializer = initializer is not None or initargs or ctxkwargs | |
| if uses_initializer or uses_custom_worker_context: | |
| raise TypeError("WorkStealThreadPoolExecutor does not support custom worker contexts") | |
| super().__init__(max_workers=max_workers, thread_name_prefix=thread_name_prefix) | |
| def _submit_wrapped(self, fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> "WorkStealFuture[R]": | |
| # NOTE: Code below is almost completely copy-pasted from _thread_impl.ThreadPoolExecutor.submit. | |
| with self._shutdown_lock, _thread_impl._global_shutdown_lock: | |
| if self._broken: | |
| raise _thread_impl.BrokenThreadPool(self._broken) | |
| if self._shutdown: | |
| raise RuntimeError("cannot schedule new futures after shutdown") | |
| if _thread_impl._shutdown: | |
| raise RuntimeError("cannot schedule new futures after interpreter shutdown") | |
| f: WorkStealFuture[R] = WorkStealFuture() | |
| task = self._resolve_work_item_task(fn, args, kwargs) | |
| w = _thread_impl._WorkItem(f, task) | |
| f._work_item_weakref = weakref.ref(w) | |
| f._create_worker_context = self._create_worker_context | |
| self._work_queue.put(w) | |
| self._adjust_thread_count() | |
| return f | |
| def submit(self, fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> "WorkStealFuture[R]": | |
| return cast(WorkStealFuture[R], super().submit(fn, *args, **kwargs)) | |
| def submit_fire_and_forget(self, fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> "WorkStealFuture[R]": | |
| return cast(WorkStealFuture[R], super().submit_fire_and_forget(fn, *args, **kwargs)) | |
| @overload | |
| def map[T1, R]( | |
| self, | |
| fn: Callable[[T1], R], | |
| iter1: Iterable[T1], | |
| /, | |
| *, | |
| timeout: float | None = ..., | |
| chunksize: int = ..., | |
| buffersize: int | None = ..., | |
| enable_work_stealing: bool = ..., | |
| ) -> Iterator[R]: ... | |
| @overload | |
| def map[T1, T2, R]( | |
| self, | |
| fn: Callable[[T1, T2], R], | |
| iter1: Iterable[T1], | |
| iter2: Iterable[T2], | |
| /, | |
| *, | |
| timeout: float | None = ..., | |
| chunksize: int = ..., | |
| buffersize: int | None = ..., | |
| enable_work_stealing: bool = ..., | |
| ) -> Iterator[R]: ... | |
| @overload | |
| def map[T1, T2, T3, R]( | |
| self, | |
| fn: Callable[[T1, T2, T3], R], | |
| iter1: Iterable[T1], | |
| iter2: Iterable[T2], | |
| iter3: Iterable[T3], | |
| /, | |
| *, | |
| timeout: float | None = ..., | |
| chunksize: int = ..., | |
| buffersize: int | None = ..., | |
| enable_work_stealing: bool = ..., | |
| ) -> Iterator[R]: ... | |
| @overload | |
| def map[T1, T2, T3, T4, R]( | |
| self, | |
| fn: Callable[[T1, T2, T3, T4], R], | |
| iter1: Iterable[T1], | |
| iter2: Iterable[T2], | |
| iter3: Iterable[T3], | |
| iter4: Iterable[T4], | |
| /, | |
| *, | |
| timeout: float | None = ..., | |
| chunksize: int = ..., | |
| buffersize: int | None = ..., | |
| enable_work_stealing: bool = ..., | |
| ) -> Iterator[R]: ... | |
| def map( | |
| self, | |
| fn: Callable[..., R], | |
| *iterables: Iterable[Any], | |
| timeout: float | None = None, | |
| chunksize: int = 1, | |
| buffersize: int | None = None, | |
| enable_work_stealing: bool = True, | |
| ) -> Iterator[R]: | |
| if buffersize is not None and buffersize < 1: | |
| raise ValueError("buffersize must be None or > 0") | |
| # Synced with CPython 3.14.4 ThreadPoolExecutor.map, with local work-stealing | |
| # and context-propagation changes. | |
| end_time = (timeout or 0) + time.monotonic() | |
| zipped_iterables = zip(*iterables) | |
| if buffersize: | |
| fs = deque(self.submit(fn, *args) for args in islice(zipped_iterables, buffersize)) | |
| else: | |
| args_lists = list(zipped_iterables) | |
| if enable_work_stealing and timeout is None and len(args_lists) < 2: | |
| # Ensure context propagation even for inline execution | |
| wrapped_fn = self._wrap_with_context(fn, log_exceptions=False) | |
| return (wrapped_fn(*args) for args in args_lists) | |
| fs = [self.submit(fn, *args) for args in args_lists] | |
| executor_weakref = weakref.ref(self) | |
| def result_iterator() -> Iterator[R]: | |
| try: | |
| fs.reverse() | |
| while fs: | |
| current_future = [fs.pop()] | |
| if buffersize and isinstance(fs, deque) and (executor := executor_weakref()) is not None: | |
| # Refill after popping but before yielding, so caller processing time does not starve the buffer. | |
| args = next(zipped_iterables, None) | |
| if args is not None: | |
| fs.appendleft(executor.submit(fn, *args)) | |
| yield _result_or_cancel( | |
| current_future.pop(), | |
| timeout=None if timeout is None else end_time - time.monotonic(), | |
| enable_work_stealing=enable_work_stealing, | |
| ) | |
| finally: | |
| for future in fs: | |
| future.cancel() | |
| return result_iterator() | |
| class WorkStealFuture[T](Future[T]): | |
| """A `Future` that supports work stealing, for use in `WorkStealThreadPoolExecutor`.""" | |
| _work_item_weakref: weakref.ref[_thread_impl._WorkItem] | |
| _create_worker_context: Callable[[], _thread_impl.WorkerContext] | |
| def __init__(self): | |
| super().__init__() | |
| self._lock = threading.Lock() | |
| self._starting_thread_id = None | |
| def result(self, timeout=None, enable_work_stealing=True): | |
| """Get the result of the future. | |
| If enable_work_stealing is True, attempt to execute the task itself instead of waiting. | |
| """ | |
| if not enable_work_stealing: | |
| return super().result(timeout) | |
| if timeout is not None and timeout > 0: | |
| raise NotImplementedError("Positive 'timeout' not supported with 'enable_work_stealing=True'") | |
| try: | |
| with self._lock: | |
| if self._starting_thread_id: | |
| should_start = False | |
| elif timeout is not None and timeout <= 0: | |
| raise TimeoutError | |
| else: | |
| should_start = True | |
| self._starting_thread_id = threading.get_ident() | |
| if should_start and (work_item := self._work_item_weakref()): | |
| try: | |
| worker_context = self._create_worker_context() | |
| except BaseException as e: | |
| self.set_exception(e) | |
| raise | |
| work_item.run(worker_context) | |
| return super().result(timeout) | |
| finally: | |
| # Break a reference cycle with the exception in self._exception. | |
| self = None | |
| def set_running_or_notify_cancel(self): | |
| with self._lock: | |
| if self._starting_thread_id and (self._starting_thread_id != threading.get_ident()) or (self._state != "PENDING"): | |
| return False | |
| self._starting_thread_id = threading.get_ident() | |
| return super().set_running_or_notify_cancel() | |
| def _result_or_cancel[R](fut: WorkStealFuture[R], *, timeout: float | None, enable_work_stealing: bool) -> R: | |
| try: | |
| try: | |
| return fut.result(timeout, enable_work_stealing) | |
| finally: | |
| fut.cancel() | |
| finally: | |
| # Break a reference cycle with the exception in self._exception | |
| del fut |
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note: battle-tested with millions of executions over the span of many months.