Skip to content

Instantly share code, notes, and snippets.

@ktbarrett
Last active April 10, 2023 19:48
Show Gist options
  • Save ktbarrett/178cfed4f9963642eaf9ad27a3e32e16 to your computer and use it in GitHub Desktop.
Save ktbarrett/178cfed4f9963642eaf9ad27a3e32e16 to your computer and use it in GitHub Desktop.
Join blocks
async def test(dut):
    """Example: run a stimulus and an analysis coroutine side by side.

    Both coroutines are started through a TaskManager, registered under their
    function names, and the test waits for both before leaving the ``with``
    block.
    """

    async def stimulate():
        pass  # stimulate an interface

    async def analyze():
        pass  # analyze an output

    with TaskManager() as manager:
        # Calling fork() directly is what the bare ``@manager.fork`` decorator
        # does: the coroutine function is called and the result is registered.
        manager.fork(stimulate)
        manager.fork(analyze)
        await manager.join_all()
async def test(dut):
    """Example: wait for ``dut.valid`` to rise, guarded by a clock-based timeout.

    ``join_any`` resumes once either awaitable has finished; positional
    awaitables are registered under their index, so key ``0`` is the
    rising-edge wait.
    """
    # NOTE(review): cocotb's ClockCycles usually takes a cycle count as a
    # second argument; none is given here -- confirm the intended timeout.
    valid_rose = RisingEdge(dut.valid)
    timeout = ClockCycles(dut.clk)
    joined = await join_any(valid_rose, timeout)
    assert joined[0].done()  # otherwise timeout
from contextlib import AbstractContextManager
from typing import (
    Any,
    Awaitable,
    Callable,
    Coroutine,
    Dict,
    Generator,
    Iterator,
    List,
    Mapping,
    Optional,
    TypeVar,
    overload,
)

from cocotbext.compat import Event, Task, fork
# Type of the IDs under which a TaskManager stores its Tasks.
K = TypeVar("K")
# Placeholder for the manager's own type, used to annotate __enter__ as
# returning the same object it was called on.
Self = TypeVar("Self")
class TaskManager(Mapping[K, Task], AbstractContextManager):
    """Read-only mapping of IDs to forked Tasks, with helpers to join or cancel them.

    Awaitables are registered with :meth:`add` (or the :meth:`fork`
    decorator), each under a unique ID.  The managed Tasks can then be waited
    on with :meth:`join_next`, :meth:`join_any` and :meth:`join_all`, or
    cancelled with :meth:`cancel_all`.

    Used as a context manager, leaving the ``with`` block while any managed
    Task is still running raises :class:`RuntimeError`.
    """

    def __init__(self) -> None:
        self._tasks: Dict[K, Task] = {}
        # One helper Task per managed Task; see _error_squasher().
        self._error_squashers: Dict[K, Task] = {}
        # Set by every waiter when it finishes, to wake up a pending join.
        self._event = Event()
        # Last exception recorded by a waiter.  Initialised here so the
        # attribute exists even before the first join.
        self._exception: Optional[Exception] = None

    async def _waiter(self, awaitable: Awaitable) -> Any:
        """Await *awaitable* and signal its completion to the manager.

        Returns the awaitable's result.  If it raises an ``Exception``, the
        exception is recorded on the manager (to be re-raised by a join) and
        ``None`` is returned instead.  The completion event is set in all cases.
        """
        try:
            return await awaitable
        except Exception as e:
            self._exception = e
        finally:
            self._event.set()

    async def _error_squasher(self, task: Task) -> None:
        """Await *task* and deliberately discard any ``Exception`` it raises.

        NOTE(review): presumably this keeps a failure from being reported as
        unhandled by the scheduler -- confirm against cocotbext.compat.
        """
        try:
            await task
        except Exception:
            pass

    def add(self, id: object, awaitable: Awaitable[Any]) -> Task:
        """Start *awaitable* as a managed Task registered under *id*.

        Returns the newly forked Task.

        Raises:
            ValueError: if *id* is already registered.  The check happens
                before anything is forked, so a rejected call does not leave
                an untracked Task running.
        """
        if id in self._tasks:
            raise ValueError("Duplicate IDs in TaskManager")
        task = fork(self._waiter(awaitable))
        self._tasks[id] = task
        self._error_squashers[id] = fork(self._error_squasher(task))
        return task

    @overload
    def fork(self, __coro: Callable[[], Coroutine[Any, Any, Any]]) -> Task:
        ...

    @overload
    def fork(
        self, __name: str
    ) -> Callable[[Callable[[], Coroutine[Any, Any, Any]]], Task]:
        ...

    def fork(self, __coro):  # type: ignore
        """Decorator that starts a coroutine function as a managed Task.

        Used bare (``@tm.fork``), the Task is registered under the function's
        ``__name__``.  Used with a string (``@tm.fork("name")``), it returns a
        decorator that registers the Task under that string.  Either way the
        coroutine function is called with no arguments and the decorated name
        ends up bound to the resulting Task.
        """
        if not isinstance(__coro, str):
            return self.add(__coro.__name__, __coro())

        name = __coro

        def decorator(coro: Callable[[], Coroutine[Any, Any, Any]]) -> Task:
            return self.add(name, coro())

        return decorator

    async def _join_next(self) -> None:
        """Wait until a managed Task finishes and re-raise the exception it recorded.

        NOTE(review): the reset below also discards an exception recorded by a
        Task that finished before this call -- confirm that is intended.
        """
        self._exception = None
        self._event.clear()
        await self._event.wait()
        # Identity check rather than truthiness: an exception type may define
        # __bool__/__len__ and evaluate as falsy.
        if self._exception is not None:
            raise self._exception

    async def join_next(self) -> None:
        """Wait for one more Task to finish, unless every Task is already done."""
        if not all(task.done() for task in self._tasks.values()):
            await self._join_next()

    async def join_any(self) -> None:
        """Return as soon as at least one managed Task is done."""
        if any(task.done() for task in self._tasks.values()):
            return
        await self._join_next()

    async def join_all(self) -> None:
        """Wait until every managed Task is done."""
        while not all(task.done() for task in self._tasks.values()):
            await self._join_next()

    def cancel_all(self) -> None:
        """Cancel every Task that is not done yet, error squashers first."""
        for group in (self._error_squashers, self._tasks):
            for task in group.values():
                if not task.done():
                    task.cancel()

    def __getitem__(self, id: K) -> Task:
        """Return the Task registered under *id*."""
        return self._tasks[id]

    def __iter__(self) -> Iterator[K]:
        """Iterate over the registered IDs."""
        return iter(self._tasks)

    def __len__(self) -> int:
        """Return the number of registered Tasks."""
        return len(self._tasks)

    def __enter__(self: Self) -> Self:
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """Raise ``RuntimeError`` if any managed Task is still running on exit."""
        running_ids: List[K] = [
            id for id, task in self._tasks.items() if not task.done()
        ]
        if running_ids:
            running_ids_str = ", ".join(repr(id) for id in running_ids)
            raise RuntimeError(
                f"TaskManager exited with still running Tasks: {running_ids_str}"
            )
class _join_base(TaskManager):
    """TaskManager pre-populated from constructor arguments.

    Positional awaitables are registered under their position (``0``, ``1``,
    ...) and keyword awaitables under their keyword name.
    """

    def __init__(self, *args: Awaitable[Any], **kwargs: Awaitable[Any]) -> None:
        super().__init__()
        for i, arg in enumerate(args):
            super().add(i, arg)
        for name, arg in kwargs.items():
            super().add(name, arg)
class join_any(_join_base):
    """Awaitable that resumes once at least one of the given awaitables is done.

    Awaiting the object returns the object itself, so the individual Tasks can
    be looked up by index or keyword afterwards.
    """

    def __await__(self) -> Generator[Any, Any, "join_any"]:
        yield from super().join_any().__await__()
        return self
class join_all(_join_base):
    """Awaitable that resumes once all of the given awaitables are done.

    Awaiting the object returns the object itself, so the individual Tasks can
    be looked up by index or keyword afterwards.
    """

    def __await__(self) -> Generator[Any, Any, "join_all"]:
        yield from super().join_all().__await__()
        return self
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment