# Wrapping concurrent.futures.Future to support .map() method chaining.
#
# Source: GitHub gist thegamecracks/395dfa7e12ccd1e96982630bca85ca65
# (last active May 17, 2024 13:23).
#
# NOTE: The gist page warned that this file may contain bidirectional Unicode
# text that could be interpreted or compiled differently than it appears.
# Review it in an editor that reveals hidden Unicode characters.
from __future__ import annotations

from concurrent.futures import Future, InvalidStateError, ThreadPoolExecutor
from contextlib import suppress
from typing import Any, Callable, TypeVar
# Value type held by an input (source) future.
T = TypeVar("T")
# Value type produced by a mapping function applied to a ``T``. Declared
# covariant, as the conventional ``_co`` suffix signals.
R_co = TypeVar("R_co", covariant=True)
def main() -> None:
    """Demonstrate ``.map()`` chaining on a future produced by an executor.

    Submits a job returning ``123``, then wraps it successively in a tuple,
    a set and a list. Prints ``[{(123,)}]``.
    """
    with ThreadPoolExecutor(max_workers=1) as pool:
        base = f_wrap(pool.submit(lambda: 123))
        as_tuple = base.map(lambda value: (value,))
        as_set = as_tuple.map(lambda value: {value})
        as_list = as_set.map(lambda value: [value])
        # Expected output: [{(123,)}]
        print(as_list.result())
def f_wrap(fut: Future[T]) -> MapFuture[T]:
    """Adapt a plain ``Future`` into a :class:`MapFuture` that supports ``.map()``.

    The returned future completes with the same result or exception as
    *fut*, and is cancelled if *fut* is cancelled. Cancelling the returned
    future also calls ``fut.cancel()``.

    Args:
        fut: Any ``concurrent.futures.Future``, e.g. from ``executor.submit``.

    Returns:
        A new :class:`MapFuture` that mirrors *fut*.
    """
    adapter = MapFuture[T]()
    # An identity mapping makes the adapter a faithful mirror of ``fut``.
    _chain_future(lambda value: value, fut, adapter)
    return adapter
class MapFuture(Future[T]):
    """A ``concurrent.futures.Future`` that can be transformed with ``.map()``.

    In this module, instances are created by ``f_wrap`` and by ``map()``
    itself; each one is resolved by chaining it to an upstream future.
    """

    def map(self, func: Callable[[T], R_co]) -> MapFuture[R_co]:
        """Return a new future that resolves to ``func(self.result())``.

        If this future fails, the new future fails with the same exception;
        if ``func`` raises, the new future fails with that exception. If this
        future is cancelled, the new one is cancelled too, and cancelling the
        new one calls ``cancel()`` on this one.

        ``func`` is invoked from a done-callback registered on this future
        (see ``Future.add_done_callback``).

        Args:
            func: Transformation applied to this future's successful result.

        Returns:
            A new :class:`MapFuture` holding the transformed value.
        """
        child = MapFuture[R_co]()
        _chain_future(func, self, child)
        return child
def _chain_future( | |
func: Callable[[T], R_co], | |
source: Future[T], | |
dest: Future[R_co], | |
) -> None: | |
def check_source_result(source: Future[T]) -> None: | |
if source.cancelled(): | |
dest.cancel() | |
return | |
if (exc := source.exception()) is not None: | |
_maybe_set_exception(dest, exc) | |
return | |
try: | |
result = func(source.result()) | |
_maybe_set_result(dest, result) | |
except Exception as e: | |
_maybe_set_exception(dest, e) | |
def check_dest_cancelled(dest: Future[R_co]) -> None: | |
if dest.cancelled(): | |
source.cancel() | |
return | |
source.add_done_callback(check_source_result) | |
dest.add_done_callback(check_dest_cancelled) | |
def _maybe_set_exception(fut: Future[Any], exc: BaseException) -> None: | |
with suppress(InvalidStateError): | |
fut.set_exception(exc) | |
def _maybe_set_result(fut: Future[T], result: T) -> None: | |
with suppress(InvalidStateError): | |
fut.set_result(result) | |
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
# References:
# - Rust `futures` crate, FutureExt::map:
#   https://docs.rs/futures/latest/futures/future/trait.FutureExt.html#method.map
# - CPython asyncio/futures.py (v3.12.3):
#   https://github.com/python/cpython/blob/v3.12.3/Lib/asyncio/futures.py#L365-L406
# - more-executors map implementation (v2.11.4):
#   https://github.com/rohanpm/more-executors/blob/v2.11.4/more_executors/_impl/map.py#L18-L114