Skip to content

Instantly share code, notes, and snippets.

@HoverHell
Last active December 4, 2022 12:23
Show Gist options
  • Save HoverHell/74008be8f98a806dfcdca4316267a296 to your computer and use it in GitHub Desktop.
Save HoverHell/74008be8f98a806dfcdca4316267a296 to your computer and use it in GitHub Desktop.
anyio combined (interleaved) receive stream
#!/usr/bin/env python3
"""
See:
https://stackoverflow.com/a/74661973/62821
"""
from __future__ import annotations
import contextlib
import dataclasses
from collections.abc import AsyncGenerator, Sequence
from types import TracebackType
from typing import Generic, TypeVar
import anyio
import anyio.abc
import anyio.streams.text
# Bash snippet run by the demo: five iterations, each printing the loop index and a
# nanosecond-precision ISO timestamp (`date -Ins`) to stdout with a 0.08 s pause,
# followed by a final "." line.
SCRIPT = r"""
for idx in $(seq 1 5); do
printf "%s " "$idx"
date -Ins
sleep 0.08
done
echo "."
"""
# `bash -x` traces each executed command to stderr, so the child process produces
# output on both stdout and stderr.
CMD = ["bash", "-x", "-c", SCRIPT]
def print_data(data: str, is_stderr: bool) -> None:
    """Print one chunk of output as `<flag>: <repr of data>`.

    `<flag>` is `int(is_stderr)`, i.e. `0` for stdout and `1` for stderr when a
    bool is passed.
    """
    line = f"{int(is_stderr)}: {data!r}"
    print(line)
T_ACMWrap = TypeVar("T_ACMWrap", bound="ACMWrap")


class ACMWrap:
    """
    Helper superclass that calls `self._enter_contexts` with a `contextlib.AsyncExitStack`
    to conveniently do assorted initialization.

    Whatever `_enter_contexts` registers on the stack is unrolled when the instance is
    exited, and also when `_enter_contexts` itself raises part-way through.

    Example Usage:

    >>> class Foo(ACMWrap):
    ...     async def _enter_contexts(self, acm: contextlib.AsyncExitStack) -> None:
    ...         await super()._enter_contexts(acm)
    ...         self._tg = await acm.enter_async_context(anyio.create_task_group())
    ...
    >>> async def amain() -> None:
    ...     async with Foo() as foo:
    ...         print(foo)
    ...         print(foo._acm)
    ...         print(foo._tg)
    ...
    >>> anyio.run(amain)
    <cmd_streamed_anyio.Foo object at 0x...>
    <contextlib.AsyncExitStack object at 0x...>
    <anyio...TaskGroup object at 0x...>
    """

    # The exit stack of the currently active `async with`; `None` while not entered.
    _acm_raw: contextlib.AsyncExitStack | None = None

    @property
    def _acm(self) -> contextlib.AsyncExitStack:
        """The active exit stack.

        Raises:
            RuntimeError: if the instance is not currently entered.
        """
        if self._acm_raw is None:
            raise RuntimeError("Expected to be `__aenter__`ed")
        return self._acm_raw

    async def _enter_contexts(self, acm: contextlib.AsyncExitStack) -> None:
        """Actual initialization place, for override in subclasses"""

    async def __aenter__(self: T_ACMWrap) -> T_ACMWrap:
        """Create the exit stack, run `_enter_contexts` on it and return `self`.

        Raises:
            RuntimeError: if the instance is already entered.
        """
        if self._acm_raw is not None:
            raise RuntimeError("Already entered")
        acm = contextlib.AsyncExitStack()
        # Doesn't actually do anything, but done for the sake of following the protocol.
        await acm.__aenter__()
        self._acm_raw = acm
        try:
            await self._enter_contexts(acm)
        except BaseException as exc:  # `BaseException` because e.g. timeouts do not absolve of resource-closing.
            # Make sure to unroll whatever was initialized.
            await self.__aexit__(type(exc), exc, exc.__traceback__)
            # Always re-raise: a half-initialized instance must not be handed to the caller.
            raise
        return self

    async def __aexit__(
        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
    ) -> bool:
        """Unroll the exit stack and mark the instance as not entered.

        Returns:
            Whatever the exit stack's `__aexit__` returns.

        Raises:
            RuntimeError: if the instance is not currently entered.
        """
        acm = self._acm_raw
        # Explicit "exiting without entering" check (same error as the `_acm` property).
        if acm is None:
            raise RuntimeError("Expected to be `__aenter__`ed")
        # Reset first, so the instance counts as exited even if unrolling raises.
        self._acm_raw = None
        return await acm.__aexit__(exc_type, exc_val, exc_tb)
T_Item = TypeVar("T_Item")  # TODO: covariant=True?


@dataclasses.dataclass(eq=False)
class CombinedReceiveStream(Generic[T_Item], ACMWrap):
    """Combines multiple streams into a single one, annotating each item with position index of the origin stream

    Must be used as an async context manager: entering starts one copier task per
    input stream, exiting cancels whatever copiers are still running.
    Items come out as `(idx, item)` tuples, where `idx` is the position of the
    origin stream in `streams`. Iteration ends once every input stream is exhausted.
    """

    # Input streams; the position in this sequence is the `idx` attached to each item.
    streams: Sequence[anyio.abc.ObjectReceiveStream[T_Item]]
    # Passed as `max_buffer_size` of the internal memory object stream.
    max_buffer_size_items: int = 32
    # Task group running the copiers; set while entered.
    _tg: anyio.abc.TaskGroup | None = None

    def __post_init__(self) -> None:
        """Create the internal queue and the set of not-yet-exhausted stream indexes."""
        self._queue_send, self._queue_receive = anyio.create_memory_object_stream(
            max_buffer_size=self.max_buffer_size_items,
            # Should be: `item_type=tuple[int, T_Item] | None`
        )
        self._pending = set(range(len(self.streams)))

    @contextlib.asynccontextmanager
    async def _manage_tasks(self, tg: anyio.abc.TaskGroup) -> AsyncGenerator[None, None]:
        """Start a copier per pending stream in `tg`; cancel the task group's scope on exit."""
        try:
            if not self._pending:
                # No copier will ever run, so nothing else would close the queue;
                # without this, `receive()` on an empty combination waits forever.
                await self._queue_send.aclose()
            for idx in self._pending:
                tg.start_soon(self._copier, idx)
            yield None
        finally:
            tg.cancel_scope.cancel()

    async def _enter_contexts(self, acm: contextlib.AsyncExitStack) -> None:
        """Open the task group and start the copier tasks inside it."""
        tg = await acm.enter_async_context(anyio.create_task_group())
        self._tg = tg
        # Entered after the task group, hence exited before it: the cancel in
        # `_manage_tasks` runs before the task group's own exit.
        await acm.enter_async_context(self._manage_tasks(tg))

    async def _copier(self, idx: int) -> None:
        """Forward every item of `streams[idx]` into the queue, tagged with `idx`."""
        stream = self.streams[idx]
        async for item in stream:
            await self._queue_send.send((idx, item))
        self._pending.remove(idx)
        # The last copier to finish closes the queue, which ends iteration for the consumer.
        if not self._pending:
            await self._queue_send.aclose()

    async def receive(self) -> tuple[int, T_Item]:
        """Return the next `(idx, item)` pair from the internal queue.

        Raises:
            anyio.EndOfStream: propagated from the internal queue (handled in `__anext__`).
        """
        return await self._queue_receive.receive()

    def __aiter__(self) -> CombinedReceiveStream[T_Item]:
        return self

    async def __anext__(self) -> tuple[int, T_Item]:
        try:
            return await self.receive()
        except anyio.EndOfStream:
            raise StopAsyncIteration() from None
async def amain(max_buffer_size_items: int = 32) -> None:
    """Run `CMD` and print its decoded stdout and stderr output, tagged by origin.

    `max_buffer_size_items` is forwarded to `CombinedReceiveStream`.
    """
    async with await anyio.open_process(CMD) as process:
        assert process.stdout is not None
        assert process.stderr is not None
        # Position in this list is the index reported alongside each item.
        text_streams = [
            anyio.streams.text.TextReceiveStream(byte_stream)
            for byte_stream in (process.stdout, process.stderr)
        ]
        stderr_flag_by_idx = {0: False, 1: True}  # just making it explicit
        combined = CombinedReceiveStream(text_streams, max_buffer_size_items=max_buffer_size_items)
        async with combined as outputs:
            async for idx, data in outputs:
                print_data(data, is_stderr=stderr_flag_by_idx[idx])
def main():
    """Synchronous entry point: run `amain` to completion via `anyio.run`."""
    entry_point = amain
    anyio.run(entry_point)
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
@smurfix
Copy link

smurfix commented Dec 4, 2022

Line 92, don't you mean acm_raw?

@HoverHell
Copy link
Author

HoverHell commented Dec 4, 2022

Line 92, don't you mean acm_raw?

No, using self._acm rather than self._acm_raw also does the check for "exiting without entering". After all, can't do acm.__aexit__ without some sort of a check for None, which is exactly what's done in the self._acm property.

Although I agree it would be more clear to do a separate explicit check in __aexit__.

@smurfix
Copy link

smurfix commented Dec 4, 2022

Ah, true, I missed that property handler.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment