Last active: December 4, 2022 12:23
-
-
Save HoverHell/74008be8f98a806dfcdca4316267a296 to your computer and use it in GitHub Desktop.
anyio combined (interleaved) receive stream
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
See: | |
https://stackoverflow.com/a/74661973/62821 | |
""" | |
from __future__ import annotations | |
import contextlib | |
import dataclasses | |
from collections.abc import AsyncGenerator, Sequence | |
from types import TracebackType | |
from typing import Generic, TypeVar | |
import anyio | |
import anyio.abc | |
import anyio.streams.text | |
# Demo shell script: for idx 1..5 it prints the number followed by a GNU `date -Ins`
# (ISO-8601, nanosecond) timestamp, pausing 0.08 s between lines, then prints ".".
SCRIPT = r"""
for idx in $(seq 1 5); do
printf "%s " "$idx"
date -Ins
sleep 0.08
done
echo "."
"""
# `bash -x` writes a trace of every executed command to stderr, so the demo produces
# two output streams (stdout: the script's output, stderr: the trace) to interleave.
CMD = "bash -x -c".split() + [SCRIPT]
def print_data(data: str, is_stderr: bool) -> None:
    """Print one chunk of process output with its origin marker.

    The line is ``"<marker>: <repr(data)>"``, where the marker is ``int(is_stderr)``:
    ``0`` for stdout and ``1`` for stderr.
    """
    marker = int(is_stderr)
    print("{}: {!r}".format(marker, data))
T_ACMWrap = TypeVar("T_ACMWrap", bound="ACMWrap") | |
class ACMWrap: | |
""" | |
Helper superclass that calls `self._enter_contexts` with a `contextlib.AsyncExitStack` | |
to conveniently do assorted initialization. | |
Example Usage: | |
>>> class Foo(ACMWrap): | |
... async def _enter_contexts(self, acm: contextlib.AsyncExitStack) -> Foo: | |
... await super()._enter_contexts(acm) | |
... self._tg = await acm.enter_async_context(anyio.create_task_group()) | |
... | |
>>> async def amain() -> None: | |
... async with Foo() as foo: | |
... print(foo) | |
... print(foo._acm) | |
... print(foo._tg) | |
... | |
>>> anyio.run(amain) | |
<cmd_streamed_anyio.Foo object at 0x...> | |
<contextlib.AsyncExitStack object at 0x...> | |
<anyio...TaskGroup object at 0x...> | |
""" | |
_acm_raw: contextlib.AsyncExitStack | None = None | |
@property | |
def _acm(self) -> contextlib.AsyncExitStack: | |
if self._acm_raw is None: | |
raise RuntimeError("Expected to be `__aenter__`ed") | |
return self._acm_raw | |
async def _enter_contexts(self, acm: contextlib.AsyncExitStack) -> None: | |
"""Actual initialization place, for override in subclasses""" | |
async def __aenter__(self: T_ACMWrap) -> T_ACMWrap: | |
if self._acm_raw is not None: | |
raise RuntimeError("Already entered") | |
acm = contextlib.AsyncExitStack() | |
# Doesn't actually do anything, but done for the sake of following the protocol. | |
await acm.__aenter__() | |
self._acm_raw = acm | |
try: | |
await self._enter_contexts(acm) | |
except BaseException as exc: # `BaseException` because e.g. timeouts do not absolve of resource-closing. | |
# Make sure to unroll whatever was initialized. | |
await self.__aexit__(type(exc), exc, exc.__traceback__) | |
raise | |
return self | |
async def __aexit__( | |
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None | |
) -> bool: | |
acm = self._acm | |
self._acm_raw = None | |
return await acm.__aexit__(exc_type, exc_val, exc_tb) | |
T_Item = TypeVar("T_Item")  # TODO: covariant=True?


@dataclasses.dataclass(eq=False)
class CombinedReceiveStream(Generic[T_Item], ACMWrap):
    """Combines multiple streams into a single one, annotating each item with position index of the origin stream.

    Use it as an async context manager. Entering starts one copier task per source
    stream; each received item is ``(index_in_streams, item)``. Receiving ends with
    ``anyio.EndOfStream`` (``StopAsyncIteration`` when iterating) once every source is
    exhausted; this also holds for an empty ``streams`` sequence. Exiting cancels any
    copiers still running and then closes the internal queue.
    """

    streams: Sequence[anyio.abc.ObjectReceiveStream[T_Item]]
    max_buffer_size_items: int = 32
    # Task group running the copiers; set on entry.
    _tg: anyio.abc.TaskGroup | None = None

    def __post_init__(self) -> None:
        self._queue_send, self._queue_receive = anyio.create_memory_object_stream(
            max_buffer_size=self.max_buffer_size_items,
            # Should be: `item_type=tuple[int, T_Item] | None`
        )
        # Indices of source streams not yet exhausted; the last copier to finish closes the queue.
        self._pending = set(range(len(self.streams)))

    @contextlib.asynccontextmanager
    async def _manage_tasks(self, tg: anyio.abc.TaskGroup) -> AsyncGenerator[None, None]:
        """Start a copier per pending stream in `tg`; cancel them all when the context exits."""
        try:
            if not self._pending:
                # No sources: no copier would ever close the queue, so close it now;
                # otherwise `receive()` would block forever instead of ending.
                await self._queue_send.aclose()
            # Iterate a snapshot, since copiers remove their index from `_pending`.
            for idx in sorted(self._pending):
                tg.start_soon(self._copier, idx)
            yield None
        finally:
            tg.cancel_scope.cancel()

    async def _enter_contexts(self, acm: contextlib.AsyncExitStack) -> None:
        """Actual initialization place, for override in subclasses"""
        await super()._enter_contexts(acm)
        # Registered before the task group, so (LIFO) they run only after it has exited,
        # i.e. after every copier has stopped. Closing is idempotent for the send side,
        # which a finished copier may already have closed.
        acm.push_async_callback(self._queue_receive.aclose)
        acm.push_async_callback(self._queue_send.aclose)
        tg = await acm.enter_async_context(anyio.create_task_group())
        self._tg = tg
        await acm.enter_async_context(self._manage_tasks(tg))

    async def _copier(self, idx: int) -> None:
        """Forward every item of `self.streams[idx]` into the shared queue, tagged with `idx`."""
        stream = self.streams[idx]
        async for item in stream:
            await self._queue_send.send((idx, item))
        self._pending.remove(idx)
        if not self._pending:
            # Last source exhausted: signal end-of-stream to the receiver.
            await self._queue_send.aclose()

    async def receive(self) -> tuple[int, T_Item]:
        """Return the next ``(origin_index, item)`` pair.

        Raises ``anyio.EndOfStream`` once all source streams are exhausted.
        """
        return await self._queue_receive.receive()

    def __aiter__(self) -> "CombinedReceiveStream[T_Item]":
        return self

    async def __anext__(self) -> tuple[int, T_Item]:
        try:
            return await self.receive()
        except anyio.EndOfStream:
            raise StopAsyncIteration() from None
async def amain(max_buffer_size_items: int = 32) -> None:
    """Run ``CMD`` and print its stdout/stderr chunks in the order they arrive.

    Each chunk is decoded with ``TextReceiveStream`` and printed via ``print_data``,
    which marks it ``0`` (stdout) or ``1`` (stderr).

    :param max_buffer_size_items: queue size passed through to ``CombinedReceiveStream``.
    """
    async with await anyio.open_process(CMD) as proc:
        assert proc.stdout is not None
        assert proc.stderr is not None
        # A stream's position in the list below is the index CombinedReceiveStream
        # attaches to its items; map it back explicitly.
        is_stderr_by_index = {0: False, 1: True}
        decoded = [
            anyio.streams.text.TextReceiveStream(raw)
            for raw in (proc.stdout, proc.stderr)
        ]
        combined = CombinedReceiveStream(decoded, max_buffer_size_items=max_buffer_size_items)
        async with combined as outputs:
            async for origin, chunk in outputs:
                print_data(chunk, is_stderr=is_stderr_by_index[origin])
def main() -> None:
    """Script entry point: run ``amain`` (with its default arguments) under ``anyio.run``."""
    entry = amain
    anyio.run(entry)


if __name__ == "__main__":
    main()
Line 92: don't you mean `acm_raw`?
No. Using `self._acm` rather than `self._acm_raw` also performs the check for "exiting without entering". After all, you can't call `acm.__aexit__` without some sort of check for `None`, and that check is exactly what the `self._acm` property does. That said, I agree it would be clearer to do a separate, explicit check in `__aexit__`.
Ah, true, I missed that property handler.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Line 92: don't you mean `acm_raw`?