Skip to content

Instantly share code, notes, and snippets.

@Klotzi111
Last active November 16, 2023 00:19
Show Gist options
  • Save Klotzi111/9ab06b0380702cd5f4044c7529bdc096 to your computer and use it in GitHub Desktop.
Save Klotzi111/9ab06b0380702cd5f4044c7529bdc096 to your computer and use it in GitHub Desktop.
Python CombinedFuture
import concurrent
from concurrent.futures import Future
from typing import Any
class CombinedFuture(Future[Future | None]):
"""
This class provides "waiting" mechanisms similar to concurrent.futures.wait(...) except that there is no blocking wait.
This class extends concurrent.futures.Future and thus it can be used like any other Future.
You can use the .result() and .done() (and other) methods and also use this class with the aforementioned concurrent.futures.wait function.
This class is especially useful when you want to combine multiple futures with and (&&) and or (||) logic.
Example:
Consider you have multiple parallel tasks (as futures) and a future that will be completed once your function should return (canellation_future).
You want to wait until all tasks finish normally or the canellation_future is completed.
With the standard python library this is not possible because concurrent.futures.wait(...) can either wait for all futures or one.
Using ALL_COMPLETED will never work. And FIRST_COMPLETED would return also if only one task_futures was completed.
The following code uses CombinedFuture to solve this problem.
.. code-block:: python
def create_task() -> Future:
# TODO add logic that completes this future
return Future()
# can be completed any time
cancellation_future = Future()
task_futures = [create_task(), create_task()]
task_combined_future = CombinedFuture(*task_futures, complete_when=concurrent.futures.ALL_COMPLETED)
done, not_done = concurrent.futures.wait([cancellation_future, task_combined_future], timeout=None, return_when=concurrent.futures.ALL_COMPLETED)
if cancellation_future in done:
print("cancellation_future was completed")
else:
print("task_combined_future was completed")
"""
def __init__(self, *futures: Future, complete_when : int = concurrent.futures.FIRST_COMPLETED) -> None:
self.complete_when = complete_when
self.futures = set(futures)
self.completed_futures = set()
super().__init__()
for future in self.futures:
future.add_done_callback(self._future_completed_callback)
def _set_result_safe(self, result: Any):
try:
self.set_result(result)
except:
# this might happen when the future had its result already set
# this can happen when:
# a second future completes or multiple at "the same time"
# or the user called set_result or changed the complete_when attribute. both is not supported
pass
def _future_completed_callback(self, future: Future) -> None:
self.completed_futures.add(future)
if self.complete_when == concurrent.futures.FIRST_COMPLETED:
# no count check required because we only need one and we just added our future
self._set_result_safe(future)
return
elif self.complete_when == concurrent.futures.FIRST_EXCEPTION:
if future.exception(timeout=0) is not None:
# future completed with exception
self._set_result_safe(future)
# else: should be concurrent.futures.ALL_COMPLETED
# but we also want this logic in the FIRST_EXCEPTION case
if self.completed_futures == self.futures:
self._set_result_safe(None)
"""Unit tests for CombinedFuture"""
import unittest
import concurrent
from concurrent.futures import CancelledError, Future
from combined_future import CombinedFuture
class TestCombinedFuture(unittest.TestCase):
    """Unit tests for CombinedFuture covering each complete_when mode."""

    def helper_check_combined_future_completed(self, combined_future: CombinedFuture, resulting_future: Future | None) -> None:
        """Assert that combined_future finished normally with resulting_future as its result."""
        # CombinedFuture will never be cancelled unless the user calls its cancel method
        self.assertFalse(combined_future.cancelled())
        self.assertTrue(combined_future.done())
        # CombinedFuture never sets an exception, even when a wrapped future completed with one
        self.assertIsNone(combined_future.exception(timeout=0))
        # Futures compare by identity, so check that exactly the expected future was reported
        self.assertIs(combined_future.result(timeout=0), resulting_future)

    def test_first_completed_single_future_result(self) -> None:
        future = Future()
        combined_future = CombinedFuture(future, complete_when=concurrent.futures.FIRST_COMPLETED)
        self.assertFalse(combined_future.done())
        future.set_result("done")
        self.helper_check_combined_future_completed(combined_future, future)

    def test_first_completed_single_future_cancel(self) -> None:
        future = Future()
        combined_future = CombinedFuture(future, complete_when=concurrent.futures.FIRST_COMPLETED)
        self.assertFalse(combined_future.done())
        self.assertTrue(future.cancel())
        self.helper_check_combined_future_completed(combined_future, future)
        self.assertTrue(future.cancelled())
        with self.assertRaises(CancelledError):
            future.result()

    def helper_single_future_exception(self, complete_when: str) -> None:
        """Check that a single future failing with an exception completes the CombinedFuture."""
        future = Future()
        combined_future = CombinedFuture(future, complete_when=complete_when)
        self.assertFalse(combined_future.done())
        future.set_exception(Exception("exception"))
        self.helper_check_combined_future_completed(combined_future, future)
        with self.assertRaises(Exception):
            future.result()

    def test_first_completed_single_future_exception(self) -> None:
        self.helper_single_future_exception(concurrent.futures.FIRST_COMPLETED)

    def test_first_exception_single_future_exception(self) -> None:
        self.helper_single_future_exception(concurrent.futures.FIRST_EXCEPTION)

    def helper_prepare_future_list(self, future_count: int = 3) -> list[Future]:
        """Return future_count fresh, pending futures."""
        return [Future() for _ in range(future_count)]

    def helper_prepare_combined_future_with_multiple_futures(self, complete_when: str) -> tuple[CombinedFuture, list[Future]]:
        """Create several pending futures and a CombinedFuture over them; assert it is not done yet."""
        futures = self.helper_prepare_future_list()
        combined_future = CombinedFuture(*futures, complete_when=complete_when)
        self.assertFalse(combined_future.done())
        return combined_future, futures

    def test_first_completed_multiple_futures_result(self) -> None:
        combined_future, futures = self.helper_prepare_combined_future_with_multiple_futures(concurrent.futures.FIRST_COMPLETED)
        future = futures[0]
        future.set_result("done")
        # one future is done. Check that CombinedFuture did complete
        self.helper_check_combined_future_completed(combined_future, future)

    def helper_check_completed_when_all_futures_complete(self, complete_when: str) -> None:
        """Complete futures one by one; the CombinedFuture must finish (with None) only after the last one."""
        combined_future, futures = self.helper_prepare_combined_future_with_multiple_futures(complete_when=complete_when)
        last_index = len(futures) - 1
        for index, future in enumerate(futures):
            future.set_result(f"done{index}")
            if index != last_index:
                # not all futures done yet. Check that CombinedFuture did not complete
                self.assertFalse(combined_future.done())
        # now all futures are done. Check that CombinedFuture did complete
        self.helper_check_combined_future_completed(combined_future, None)

    def test_all_completed_multiple_futures_result(self) -> None:
        self.helper_check_completed_when_all_futures_complete(complete_when=concurrent.futures.ALL_COMPLETED)

    def test_first_exception_multiple_futures_no_exception(self) -> None:
        self.helper_check_completed_when_all_futures_complete(complete_when=concurrent.futures.FIRST_EXCEPTION)
if __name__ == "__main__":
    # Run the test suite when this module is executed directly as a script.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment