Last active
November 16, 2023 00:19
-
-
Save Klotzi111/9ab06b0380702cd5f4044c7529bdc096 to your computer and use it in GitHub Desktop.
Python CombinedFuture
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import concurrent | |
from concurrent.futures import Future | |
from typing import Any | |
class CombinedFuture(Future[Future | None]): | |
""" | |
This class provides "waiting" mechanisms similar to concurrent.futures.wait(...) except that there is no blocking wait. | |
This class extends concurrent.futures.Future and thus it can be used like any other Future. | |
You can use the .result() and .done() (and other) methods and also use this class with the aforementioned concurrent.futures.wait function. | |
This class is especially useful when you want to combine multiple futures with and (&&) and or (||) logic. | |
Example: | |
Consider you have multiple parallel tasks (as futures) and a future that will be completed once your function should return (canellation_future). | |
You want to wait until all tasks finish normally or the canellation_future is completed. | |
With the standard python library this is not possible because concurrent.futures.wait(...) can either wait for all futures or one. | |
Using ALL_COMPLETED will never work. And FIRST_COMPLETED would return also if only one task_futures was completed. | |
The following code uses CombinedFuture to solve this problem. | |
.. code-block:: python | |
def create_task() -> Future: | |
# TODO add logic that completes this future | |
return Future() | |
# can be completed any time | |
cancellation_future = Future() | |
task_futures = [create_task(), create_task()] | |
task_combined_future = CombinedFuture(*task_futures, complete_when=concurrent.futures.ALL_COMPLETED) | |
done, not_done = concurrent.futures.wait([cancellation_future, task_combined_future], timeout=None, return_when=concurrent.futures.ALL_COMPLETED) | |
if cancellation_future in done: | |
print("cancellation_future was completed") | |
else: | |
print("task_combined_future was completed") | |
""" | |
def __init__(self, *futures: Future, complete_when : int = concurrent.futures.FIRST_COMPLETED) -> None: | |
self.complete_when = complete_when | |
self.futures = set(futures) | |
self.completed_futures = set() | |
super().__init__() | |
for future in self.futures: | |
future.add_done_callback(self._future_completed_callback) | |
def _set_result_safe(self, result: Any): | |
try: | |
self.set_result(result) | |
except: | |
# this might happen when the future had its result already set | |
# this can happen when: | |
# a second future completes or multiple at "the same time" | |
# or the user called set_result or changed the complete_when attribute. both is not supported | |
pass | |
def _future_completed_callback(self, future: Future) -> None: | |
self.completed_futures.add(future) | |
if self.complete_when == concurrent.futures.FIRST_COMPLETED: | |
# no count check required because we only need one and we just added our future | |
self._set_result_safe(future) | |
return | |
elif self.complete_when == concurrent.futures.FIRST_EXCEPTION: | |
if future.exception(timeout=0) is not None: | |
# future completed with exception | |
self._set_result_safe(future) | |
# else: should be concurrent.futures.ALL_COMPLETED | |
# but we also want this logic in the FIRST_EXCEPTION case | |
if self.completed_futures == self.futures: | |
self._set_result_safe(None) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Unit tests for CombinedFuture""" | |
import unittest | |
import concurrent | |
from concurrent.futures import CancelledError, Future | |
from combined_future import CombinedFuture | |
class TestCombinedFuture(unittest.TestCase):
    """Unit tests for the completion rules of CombinedFuture."""

    def helper_check_combined_future_completed(self, combined_future: CombinedFuture, resulting_future: Future | None) -> None:
        """Assert that combined_future completed normally and that its result is resulting_future."""
        # CombinedFuture will never be cancelled unless the user calls cancel()
        self.assertFalse(combined_future.cancelled())
        self.assertTrue(combined_future.done())
        # CombinedFuture never sets an exception, even when a wrapped future completes with one
        self.assertIsNone(combined_future.exception(timeout=0))
        completed_future = combined_future.result(timeout=0)
        # the result must be the very same future object (or None)
        self.assertIs(completed_future, resulting_future)

    def test_first_completed_single_future_result(self) -> None:
        future = Future()
        combined_future = CombinedFuture(future, complete_when=concurrent.futures.FIRST_COMPLETED)
        self.assertFalse(combined_future.done())
        future.set_result("done")
        self.helper_check_combined_future_completed(combined_future, future)

    def test_first_completed_single_future_cancel(self) -> None:
        future = Future()
        combined_future = CombinedFuture(future, complete_when=concurrent.futures.FIRST_COMPLETED)
        self.assertFalse(combined_future.done())
        self.assertTrue(future.cancel())
        self.helper_check_combined_future_completed(combined_future, future)
        self.assertTrue(future.cancelled())
        with self.assertRaises(CancelledError):
            future.result()

    def helper_single_future_exception(self, complete_when: str) -> None:
        """Complete a single wrapped future with an exception and check the combined future."""
        future = Future()
        combined_future = CombinedFuture(future, complete_when=complete_when)
        self.assertFalse(combined_future.done())
        error = Exception("exception")
        future.set_exception(error)
        self.helper_check_combined_future_completed(combined_future, future)
        with self.assertRaises(Exception) as context:
            future.result()
        # make sure it is our exception and not some unrelated error
        self.assertIs(context.exception, error)

    def test_first_completed_single_future_exception(self) -> None:
        self.helper_single_future_exception(concurrent.futures.FIRST_COMPLETED)

    def test_first_exception_single_future_exception(self) -> None:
        self.helper_single_future_exception(concurrent.futures.FIRST_EXCEPTION)

    def helper_prepare_future_list(self) -> list[Future]:
        """Return a list of fresh, pending futures."""
        future_count = 3
        return [Future() for _ in range(future_count)]

    def helper_prepare_combined_future_with_multiple_futures(self, complete_when: str) -> tuple[CombinedFuture, list[Future]]:
        """Wrap several pending futures in a CombinedFuture and check it starts out not done."""
        futures = self.helper_prepare_future_list()
        combined_future = CombinedFuture(*futures, complete_when=complete_when)
        self.assertFalse(combined_future.done())
        return (combined_future, futures)

    def test_first_completed_multiple_futures_result(self) -> None:
        combined_future, futures = self.helper_prepare_combined_future_with_multiple_futures(concurrent.futures.FIRST_COMPLETED)
        future = futures[0]
        future.set_result("done")
        # one future is done. Check that CombinedFuture did complete
        self.helper_check_combined_future_completed(combined_future, future)

    def helper_check_completed_when_all_futures_complete(self, complete_when: str) -> None:
        """Complete futures one by one; the combined future must complete only after the last one."""
        combined_future, futures = self.helper_prepare_combined_future_with_multiple_futures(complete_when=complete_when)
        last_index = len(futures) - 1
        for future_index, future in enumerate(futures):
            future.set_result(f"done{future_index}")
            if future_index != last_index:
                # not all futures done yet. Check that CombinedFuture did not complete
                self.assertFalse(combined_future.done())
        # now all futures are done. Check that CombinedFuture did complete
        self.helper_check_combined_future_completed(combined_future, None)

    def test_all_completed_multiple_futures_result(self) -> None:
        self.helper_check_completed_when_all_futures_complete(complete_when=concurrent.futures.ALL_COMPLETED)

    def test_first_exception_multiple_futures_no_exception(self) -> None:
        self.helper_check_completed_when_all_futures_complete(complete_when=concurrent.futures.FIRST_EXCEPTION)
if __name__ == "__main__":
    # Executed directly (not imported): discover and run the tests in this module.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment