Skip to content

Instantly share code, notes, and snippets.

@ededejr
Last active April 10, 2024 18:52
Show Gist options
  • Save ededejr/62edae451366e6bfce78f6deb34b2079 to your computer and use it in GitHub Desktop.
Save ededejr/62edae451366e6bfce78f6deb34b2079 to your computer and use it in GitHub Desktop.
A simple wait group, written before discovering asyncio.TaskGroup.
import asyncio
import logging
from typing import Any, Callable, Coroutine, List, Optional, Set
# Module-level logger; Waitgroup.wait() reports futures that ended with an exception through it.
logger = logging.getLogger(__name__)
class Waitgroup:
    """Track a group of futures and wait for all of them to resolve.

    A lightweight stand-in for ``asyncio.TaskGroup``: futures are either
    created with :meth:`create` and resolved by the caller, or created by
    :meth:`schedule`, which runs a coroutine function in a background task
    and resolves the future with the coroutine's outcome.

    Args:
        timeout: Default number of seconds :meth:`wait` waits when it is
            called without an explicit timeout. ``None`` waits indefinitely.
    """

    def __init__(self, timeout: Optional[float] = None):
        self.timeout = timeout
        self.futures: List[asyncio.Future] = []
        # Strong references to the tasks started by schedule(). asyncio keeps
        # only weak references to tasks, so an otherwise unreferenced task can
        # be garbage-collected before it finishes.
        self._tasks: Set[asyncio.Task] = set()

    def create(self) -> asyncio.Future:
        """Create a new pending future, track it in this group, and return it.

        The caller is responsible for resolving (or cancelling) the future.
        """
        future = asyncio.Future()
        self.futures.append(future)
        return future

    def cancel(self):
        """Cancel every future in this group that has not resolved yet.

        For futures created by :meth:`schedule`, cancelling the future also
        cancels the task that runs its coroutine.
        """
        for fut in self.futures:
            if not fut.done():
                fut.cancel()

    async def wait(self, timeout: Optional[float] = None):
        """Wait for all tracked futures to resolve, up to a timeout.

        Args:
            timeout: Seconds to wait. ``None`` falls back to the instance's
                default timeout; an explicit ``0`` is honored as-is.

        Raises:
            Exception: The exception of the earliest-created future that
                failed. Every still-unresolved future is cancelled before
                the exception is re-raised.

        Futures still pending when the timeout expires are cancelled and
        the method returns normally. Cancelled futures are not treated as
        failures.
        """
        if not self.futures:
            return

        # `timeout or self.timeout` would silently replace an explicit 0.
        effective_timeout = self.timeout if timeout is None else timeout
        done, pending = await asyncio.wait(
            self.futures, timeout=effective_timeout, return_when=asyncio.ALL_COMPLETED
        )

        # Walk the futures in creation order (not the unordered `done` set)
        # so the exception that gets re-raised is deterministic.
        for fut in self.futures:
            # Future.exception() raises CancelledError on a cancelled future;
            # a deliberate cancellation is not a failure, so skip it.
            if fut not in done or fut.cancelled():
                continue
            exception = fut.exception()
            if exception is not None:
                logger.error("Future ended with exception: %s", exception)
                self.cancel()  # cancel all remaining futures
                raise exception

        # Cancel any futures still pending because the timeout expired.
        for fut in pending:
            fut.cancel()

    def schedule(self, coro: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
        """Run ``coro(*args, **kwargs)`` in a background task tied to a new future.

        The future receives the coroutine's return value or exception.
        Cancellation propagates both ways. If the coroutine is cancelled, the
        future is cancelled, so :meth:`wait` does not block forever. If the
        future is cancelled (via :meth:`cancel` or a :meth:`wait` timeout),
        the task running the coroutine is cancelled too.
        """
        future = self.create()

        async def task_wrapper():
            try:
                result = await coro(*args, **kwargs)
            except asyncio.CancelledError:
                # CancelledError is not an Exception subclass (3.8+), so it
                # needs its own handler to keep the future from hanging.
                future.cancel()
                raise
            except Exception as e:
                # The future may already be cancelled; setting it again would
                # raise InvalidStateError inside this task.
                if not future.done():
                    future.set_exception(e)
            else:
                if not future.done():
                    future.set_result(result)

        task = asyncio.create_task(task_wrapper())
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

        def _propagate_cancel(fut: asyncio.Future):
            # Stop the underlying work once nobody is waiting for its result.
            if fut.cancelled():
                task.cancel()

        future.add_done_callback(_propagate_cancel)
import asyncio
from unittest.mock import patch
import pytest
import Waitgroup, logger
# Test helper: schedules a delayed coroutine on a Waitgroup.
def schedule_run(wg: Waitgroup, delay, result=None, exception=None):
    """Schedule on *wg* a coroutine that sleeps for *delay* seconds and then
    raises *exception* if one is given, otherwise returns *result*."""

    async def sleep_then_settle(result=None, exception=None):
        await asyncio.sleep(delay)
        if not exception:
            return result
        raise exception

    wg.schedule(sleep_then_settle, result=result, exception=exception)
@pytest.mark.asyncio
async def test_waitgroup_initialization():
    """A fresh Waitgroup tracks no futures and stores the timeout it was given."""
    default_group = Waitgroup()
    assert default_group.timeout is None
    assert not default_group.futures

    timed_group = Waitgroup(timeout=10)
    assert timed_group.timeout == 10
@pytest.mark.asyncio
async def test_future_creation():
    """create() returns a fresh pending future and registers it with the group."""
    wg = Waitgroup()
    future = wg.create()
    assert isinstance(future, asyncio.Future)
    assert not future.done()
    # Check identity, not just the count: the returned future must be the tracked one.
    assert len(wg.futures) == 1
    assert wg.futures[0] is future
@pytest.mark.asyncio
async def test_cancel_all_futures():
    """cancel() marks every unresolved future in the group as cancelled."""
    group = Waitgroup()
    created = [group.create(), group.create()]
    group.cancel()
    for fut in created:
        assert fut.cancelled()
@pytest.mark.asyncio
async def test_wait_no_futures():
    """Waiting on an empty group returns without raising."""
    empty_group = Waitgroup()
    # Any exception raised by wait() fails the test.
    await empty_group.wait()
@pytest.mark.asyncio
async def test_wait_futures_complete_before_timeout():
    """All scheduled coroutines finish within the timeout and keep their results."""
    wg = Waitgroup()
    for _ in range(3):
        schedule_run(wg, 0.1, result=True)
    await wg.wait(timeout=1)
    # Pin the count first: all() over an empty list would pass vacuously.
    assert len(wg.futures) == 3
    assert all(f.done() and not f.cancelled() for f in wg.futures)
    assert [f.result() for f in wg.futures] == [True, True, True]
@pytest.mark.asyncio
async def test_wait_futures_timeout():
    """Futures still pending when the default timeout expires are cancelled."""
    group = Waitgroup(timeout=0.1)
    # These futures are never resolved, so wait() can only end by timing out.
    for _ in range(3):
        group.create()
    await group.wait()
    not_cancelled = [f for f in group.futures if not f.cancelled()]
    assert not_cancelled == []
@pytest.mark.asyncio
async def test_wait_exceptions_are_logged_and_raised():
    """wait() logs a failed future once and re-raises its exception."""
    group = Waitgroup()
    with patch.object(logger, "error") as error_log:
        schedule_run(group, 0.1, exception=Exception("Test Exception"))
        with pytest.raises(Exception) as caught:
            await group.wait()
        error_log.assert_called_once()
        assert str(caught.value) == "Test Exception"
@pytest.mark.asyncio
async def test_schedule_coroutine():
    """schedule() resolves its future with the coroutine's return value or exception."""
    wg = Waitgroup()
    schedule_run(wg, 0.1, result="success")
    await wg.wait()
    # Compare concrete results so an empty futures list cannot pass vacuously.
    assert [f.result() for f in wg.futures] == ["success"]

    # Testing with an exception
    wg = Waitgroup()
    schedule_run(wg, 0.1, exception=Exception("Failure"))
    with pytest.raises(Exception) as exc_info:
        await wg.wait()
    assert "Failure" in str(exc_info.value)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment