Skip to content

Instantly share code, notes, and snippets.

@jonashaag
Last active May 6, 2026 11:06
Show Gist options
  • Select an option

  • Save jonashaag/7a4d5627331c749f2b5b85869d9499e7 to your computer and use it in GitHub Desktop.

Select an option

Save jonashaag/7a4d5627331c749f2b5b85869d9499e7 to your computer and use it in GitHub Desktop.
Python ThreadPoolExecutor Work Stealing (Python 3.14)
import concurrent.futures.thread as _thread_impl
import threading
import time
import weakref
from collections import deque
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import Future
from itertools import islice
from typing import Any, cast, overload
class WorkStealThreadPoolExecutor(_thread_impl.ThreadPoolExecutor):
"""A ThreadPoolExecutor that supports work stealing.
We use work stealing to prevent worker starvation.
We use a custom `WorkStealFuture` to support work stealing upon calling `future.result()`.
"""
def __init__(
self,
max_workers: int | None = None,
thread_name_prefix: str = "",
initializer: Callable[..., object] | None = None,
initargs: tuple[Any, ...] = (),
**ctxkwargs: Any,
):
# Reject any subclass override of prepare_context, even if the override is a
# trivial passthrough. Work stealing creates a fresh WorkerContext per stolen
# task and only the default WorkerContext semantics are safe in that path.
prepare_context = getattr(type(self).prepare_context, "__func__", None)
default_prepare_context = getattr(_thread_impl.ThreadPoolExecutor.prepare_context, "__func__", None)
uses_custom_worker_context = prepare_context is not default_prepare_context
uses_initializer = initializer is not None or initargs or ctxkwargs
if uses_initializer or uses_custom_worker_context:
raise TypeError("WorkStealThreadPoolExecutor does not support custom worker contexts")
super().__init__(max_workers=max_workers, thread_name_prefix=thread_name_prefix)
def _submit_wrapped(self, fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> "WorkStealFuture[R]":
# NOTE: Code below is almost completely copy-pasted from _thread_impl.ThreadPoolExecutor.submit.
with self._shutdown_lock, _thread_impl._global_shutdown_lock:
if self._broken:
raise _thread_impl.BrokenThreadPool(self._broken)
if self._shutdown:
raise RuntimeError("cannot schedule new futures after shutdown")
if _thread_impl._shutdown:
raise RuntimeError("cannot schedule new futures after interpreter shutdown")
f: WorkStealFuture[R] = WorkStealFuture()
task = self._resolve_work_item_task(fn, args, kwargs)
w = _thread_impl._WorkItem(f, task)
f._work_item_weakref = weakref.ref(w)
f._create_worker_context = self._create_worker_context
self._work_queue.put(w)
self._adjust_thread_count()
return f
def submit(self, fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> "WorkStealFuture[R]":
return cast(WorkStealFuture[R], super().submit(fn, *args, **kwargs))
def submit_fire_and_forget(self, fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> "WorkStealFuture[R]":
return cast(WorkStealFuture[R], super().submit_fire_and_forget(fn, *args, **kwargs))
@overload
def map[T1, R](
self,
fn: Callable[[T1], R],
iter1: Iterable[T1],
/,
*,
timeout: float | None = ...,
chunksize: int = ...,
buffersize: int | None = ...,
enable_work_stealing: bool = ...,
) -> Iterator[R]: ...
@overload
def map[T1, T2, R](
self,
fn: Callable[[T1, T2], R],
iter1: Iterable[T1],
iter2: Iterable[T2],
/,
*,
timeout: float | None = ...,
chunksize: int = ...,
buffersize: int | None = ...,
enable_work_stealing: bool = ...,
) -> Iterator[R]: ...
@overload
def map[T1, T2, T3, R](
self,
fn: Callable[[T1, T2, T3], R],
iter1: Iterable[T1],
iter2: Iterable[T2],
iter3: Iterable[T3],
/,
*,
timeout: float | None = ...,
chunksize: int = ...,
buffersize: int | None = ...,
enable_work_stealing: bool = ...,
) -> Iterator[R]: ...
@overload
def map[T1, T2, T3, T4, R](
self,
fn: Callable[[T1, T2, T3, T4], R],
iter1: Iterable[T1],
iter2: Iterable[T2],
iter3: Iterable[T3],
iter4: Iterable[T4],
/,
*,
timeout: float | None = ...,
chunksize: int = ...,
buffersize: int | None = ...,
enable_work_stealing: bool = ...,
) -> Iterator[R]: ...
def map(
self,
fn: Callable[..., R],
*iterables: Iterable[Any],
timeout: float | None = None,
chunksize: int = 1,
buffersize: int | None = None,
enable_work_stealing: bool = True,
) -> Iterator[R]:
if buffersize is not None and buffersize < 1:
raise ValueError("buffersize must be None or > 0")
# Synced with CPython 3.14.4 ThreadPoolExecutor.map, with local work-stealing
# and context-propagation changes.
end_time = (timeout or 0) + time.monotonic()
zipped_iterables = zip(*iterables)
if buffersize:
fs = deque(self.submit(fn, *args) for args in islice(zipped_iterables, buffersize))
else:
args_lists = list(zipped_iterables)
if enable_work_stealing and timeout is None and len(args_lists) < 2:
# Ensure context propagation even for inline execution
wrapped_fn = self._wrap_with_context(fn, log_exceptions=False)
return (wrapped_fn(*args) for args in args_lists)
fs = [self.submit(fn, *args) for args in args_lists]
executor_weakref = weakref.ref(self)
def result_iterator() -> Iterator[R]:
try:
fs.reverse()
while fs:
current_future = [fs.pop()]
if buffersize and isinstance(fs, deque) and (executor := executor_weakref()) is not None:
# Refill after popping but before yielding, so caller processing time does not starve the buffer.
args = next(zipped_iterables, None)
if args is not None:
fs.appendleft(executor.submit(fn, *args))
yield _result_or_cancel(
current_future.pop(),
timeout=None if timeout is None else end_time - time.monotonic(),
enable_work_stealing=enable_work_stealing,
)
finally:
for future in fs:
future.cancel()
return result_iterator()
class WorkStealFuture[T](Future[T]):
"""A `Future` that supports work stealing, for use in `WorkStealThreadPoolExecutor`."""
_work_item_weakref: weakref.ref[_thread_impl._WorkItem]
_create_worker_context: Callable[[], _thread_impl.WorkerContext]
def __init__(self):
super().__init__()
self._lock = threading.Lock()
self._starting_thread_id = None
def result(self, timeout=None, enable_work_stealing=True):
"""Get the result of the future.
If enable_work_stealing is True, attempt to execute the task itself instead of waiting.
"""
if not enable_work_stealing:
return super().result(timeout)
if timeout is not None and timeout > 0:
raise NotImplementedError("Positive 'timeout' not supported with 'enable_work_stealing=True'")
try:
with self._lock:
if self._starting_thread_id:
should_start = False
elif timeout is not None and timeout <= 0:
raise TimeoutError
else:
should_start = True
self._starting_thread_id = threading.get_ident()
if should_start and (work_item := self._work_item_weakref()):
try:
worker_context = self._create_worker_context()
except BaseException as e:
self.set_exception(e)
raise
work_item.run(worker_context)
return super().result(timeout)
finally:
# Break a reference cycle with the exception in self._exception.
self = None
def set_running_or_notify_cancel(self):
with self._lock:
if self._starting_thread_id and (self._starting_thread_id != threading.get_ident()) or (self._state != "PENDING"):
return False
self._starting_thread_id = threading.get_ident()
return super().set_running_or_notify_cancel()
def _result_or_cancel[R](fut: WorkStealFuture[R], *, timeout: float | None, enable_work_stealing: bool) -> R:
try:
try:
return fut.result(timeout, enable_work_stealing)
finally:
fut.cancel()
finally:
# Break a reference cycle with the exception in self._exception
del fut
@jonashaag
Copy link
Copy Markdown
Author

Note: battle-tested with millions of executions over many months.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment