Last active
April 10, 2024 18:52
-
-
Save ededejr/62edae451366e6bfce78f6deb34b2079 to your computer and use it in GitHub Desktop.
a simple wait group written before discovering asyncio.TaskGroup
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import logging | |
from typing import Any, Callable, Coroutine, List, Optional | |
logger = logging.getLogger(__name__) | |
class Waitgroup:
    """Track a group of futures and wait for all of them together.

    Futures are either created directly with :meth:`create` (the caller is
    then responsible for resolving them) or indirectly with :meth:`schedule`,
    which runs a coroutine function in a background task and resolves a
    future with that coroutine's outcome.
    """

    def __init__(self, timeout: Optional[float] = None):
        """Create an empty group.

        Args:
            timeout: Default number of seconds :meth:`wait` waits before
                giving up. ``None`` means wait without a time limit.
        """
        self.timeout = timeout
        self.futures: List[asyncio.Future] = []
        # The event loop only keeps weak references to tasks, so strong
        # references are held here until each scheduled task finishes.
        self._tasks: set = set()

    def create(self) -> asyncio.Future:
        """Create a new future, register it with the group and return it."""
        future = asyncio.Future()
        self.futures.append(future)
        return future

    def cancel(self):
        """Cancels all futures managed by this Waitgroup.

        Futures that are already done are left untouched. Only the futures
        are cancelled; coroutines started via :meth:`schedule` keep running
        and their late outcome is discarded.
        """
        for fut in self.futures:
            if not fut.done():
                fut.cancel()

    async def wait(self, timeout: Optional[float] = None):
        """Waits for all saved futures to resolve in parallel, with an optional timeout.

        Args:
            timeout: Seconds to wait. When ``None`` the group's default
                timeout is used. An explicit ``0`` is honoured (it is not
                treated as "no value").

        Raises:
            BaseException: The exception stored on the first future (in
                creation order) that finished with an error. All futures
                that are still pending are cancelled before it is raised.
        """
        if not self.futures:
            return

        # Compare against None explicitly: `timeout or self.timeout` would
        # silently replace an explicit 0 with the group default.
        effective_timeout = self.timeout if timeout is None else timeout

        # Snapshot the list so the check below looks at exactly the futures
        # that were waited on, in creation order.
        tracked = list(self.futures)
        done, pending = await asyncio.wait(
            tracked, timeout=effective_timeout, return_when=asyncio.ALL_COMPLETED
        )

        # Log and re-raise the first failure. Walking `tracked` instead of
        # the `done` set keeps the choice deterministic.
        for fut in tracked:
            # Cancelled futures count as done, but calling exception() on
            # them raises CancelledError, so they are skipped.
            if fut not in done or fut.cancelled():
                continue
            exception = fut.exception()
            if exception is not None:
                logger.error("Future ended with exception: %s", exception)
                self.cancel()  # cancel all remaining futures
                raise exception

        # Anything still pending at this point means the timeout expired.
        for fut in pending:
            fut.cancel()

    def schedule(self, coro: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
        """Schedules the execution of a coroutine function which resolves a future.

        Args:
            coro: Coroutine function to call as ``coro(*args, **kwargs)``.
            *args: Positional arguments forwarded to ``coro``.
            **kwargs: Keyword arguments forwarded to ``coro``.

        The coroutine's return value or exception is stored on a future
        registered with this group. Must be called while an event loop is
        running, because it creates a task.
        """
        future = self.create()

        async def task_wrapper():
            try:
                result = await coro(*args, **kwargs)
            except asyncio.CancelledError:
                # Mirror task cancellation onto the future; otherwise the
                # future would stay pending forever and wait() would hang.
                if not future.done():
                    future.cancel()
                raise
            except Exception as e:
                # The future may already have been cancelled (cancel() or a
                # timeout); setting an outcome then raises InvalidStateError.
                if not future.done():
                    future.set_exception(e)
            else:
                if not future.done():
                    future.set_result(result)

        task = asyncio.create_task(task_wrapper())
        # Keep the task alive until it completes, then drop the reference.
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from unittest.mock import patch | |
import pytest | |
import Waitgroup, logger | |
# Test helper: queue a delayed coroutine on a wait group.
def schedule_run(wg: Waitgroup, delay, result=None, exception=None):
    """Schedule a coroutine on ``wg`` that sleeps ``delay`` seconds first.

    After sleeping, the coroutine raises ``exception`` when it is truthy;
    otherwise it returns ``result``.
    """

    async def _delayed_outcome(result=None, exception=None):
        await asyncio.sleep(delay)
        if not exception:
            return result
        raise exception

    wg.schedule(_delayed_outcome, result=result, exception=exception)
@pytest.mark.asyncio
async def test_waitgroup_initialization():
    """A new Waitgroup stores its timeout and starts without futures."""
    default_group = Waitgroup()
    assert default_group.timeout is None
    assert len(default_group.futures) == 0

    timed_group = Waitgroup(timeout=10)
    assert timed_group.timeout == 10
@pytest.mark.asyncio
async def test_future_creation():
    """create() returns an asyncio.Future and registers it with the group."""
    group = Waitgroup()
    created = group.create()

    assert isinstance(created, asyncio.Future)
    assert len(group.futures) == 1
@pytest.mark.asyncio
async def test_cancel_all_futures():
    """cancel() cancels every future created through the group."""
    group = Waitgroup()
    first, second = group.create(), group.create()

    group.cancel()

    assert first.cancelled()
    assert second.cancelled()
@pytest.mark.asyncio
async def test_wait_no_futures():
    """wait() on an empty group completes without raising."""
    empty_group = Waitgroup()
    # Should not raise any exceptions
    await empty_group.wait()
@pytest.mark.asyncio
async def test_wait_futures_complete_before_timeout():
    """Futures that finish within the timeout end up done and not cancelled."""
    group = Waitgroup()
    for _ in range(3):
        schedule_run(group, 0.1, result=True)

    await group.wait(timeout=1)

    for fut in group.futures:
        assert fut.done() and not fut.cancelled()
@pytest.mark.asyncio
async def test_wait_futures_timeout():
    """Futures that never resolve are cancelled once the group timeout expires."""
    group = Waitgroup(timeout=0.1)
    # Futures that won't be set
    for _ in range(3):
        group.create()

    await group.wait()

    # Check if futures are cancelled due to timeout
    for fut in group.futures:
        assert fut.cancelled()
@pytest.mark.asyncio
async def test_wait_exceptions_are_logged_and_raised():
    """wait() logs a failing future's exception once and re-raises it."""
    group = Waitgroup()
    with patch.object(logger, "error") as error_spy:
        schedule_run(group, 0.1, exception=Exception("Test Exception"))

        with pytest.raises(Exception) as raised:
            await group.wait()

        # These checks sit after the raises-block so they actually execute.
        assert str(raised.value) == "Test Exception"
        error_spy.assert_called_once()
@pytest.mark.asyncio
async def test_schedule_coroutine():
    """schedule() stores a coroutine's result, or its exception, on the future."""
    ok_group = Waitgroup()
    schedule_run(ok_group, 0.1, result="success")
    await ok_group.wait()
    for fut in ok_group.futures:
        assert fut.done() and fut.result() == "success"

    # Testing with an exception
    failing_group = Waitgroup()
    schedule_run(failing_group, 0.1, exception=Exception("Failure"))
    with pytest.raises(Exception) as raised:
        await failing_group.wait()
    assert "Failure" in str(raised.value)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment