Skip to content

Instantly share code, notes, and snippets.

@vxgmichel
Created May 16, 2024 15:54
Show Gist options
  • Save vxgmichel/e3203e514110994bcf3658dc3d5e5838 to your computer and use it in GitHub Desktop.
Save vxgmichel/e3203e514110994bcf3658dc3d5e5838 to your computer and use it in GitHub Desktop.
Merge of async generators using anyio
import asyncio
import random
from contextlib import asynccontextmanager
from typing import AsyncIterable, AsyncIterator, TypeVar

from anyio import CancelScope, create_memory_object_stream, create_task_group
from anyio.abc import ObjectReceiveStream, TaskStatus
from anyio.streams.memory import MemoryObjectSendStream
T = TypeVar("T")


@asynccontextmanager
async def amerge(
    *sources: AsyncIterable[T],
    max_buffer_size: float = 0,
) -> AsyncIterator[ObjectReceiveStream[T]]:
    """Merge several async iterables into a single receive stream.

    Each source is drained by its own task in an anyio task group. Every item
    a source produces is forwarded into a shared memory object stream, and
    the receiving end of that stream is yielded to the caller. Items arrive
    in whatever order the sources produce them.

    Iterating the yielded stream ends once every source is exhausted.
    Leaving the ``async with`` block early (``break``/``return``) cancels the
    forwarding tasks that are still running. An exception leaving the block
    also aborts the task group.
    NOTE(review): with anyio 4 the task group may re-raise such an exception
    wrapped in an ExceptionGroup; confirm this is acceptable for callers.

    Args:
        *sources: The async iterables to merge. Each one is closed with
            ``aclose()`` (when it has that method) once its task finishes,
            whether it was exhausted or cancelled.
        max_buffer_size: Capacity of the underlying memory object stream.
            With the default ``0``, a source's send waits until the consumer
            receives the item.

    Yields:
        The receive side of the merged stream.
    """

    async def task(
        sender: MemoryObjectSendStream[T],
        source: AsyncIterable[T],
        task_status: TaskStatus,
    ) -> None:
        # Forward every item of ``source`` into the stream, then close ``source``.
        try:
            # Each task works on its own clone of the send stream. The
            # receive side only reports end-of-stream once all send clones
            # (and the original) have been closed.
            async with sender.clone() as task_sender:
                # Report "started" only after the clone exists, so the caller
                # can close the original sender once every task has started.
                task_status.started()
                async for item in source:
                    await task_sender.send(item)
        finally:
            if hasattr(source, "aclose"):
                # Shield the close. When this task is being cancelled, an
                # unshielded await here would be cancelled as well, which
                # would cut short any asynchronous cleanup in the source
                # (e.g. awaits inside an async generator's ``finally``).
                # Trade-off: a source whose close never finishes cannot be
                # interrupted by cancellation.
                with CancelScope(shield=True):
                    await source.aclose()

    sender, receiver = create_memory_object_stream[T](max_buffer_size)
    async with receiver:
        async with create_task_group() as task_group:
            # Close the original sender as soon as every task holds its own
            # clone, so the stream ends when the last task finishes.
            async with sender:
                for source in sources:
                    await task_group.start(task, sender, source)
            yield receiver
            # The consumer left the block: stop any sources still producing.
            task_group.cancel_scope.cancel()
async def main():
    """Demo: merge five slow sources and stop consuming early.

    Each source yields ``(source_id, value)`` pairs for ``value`` in
    ``range(10)``, sleeping a random delay before each item, and prints a
    message when it is closed. The consumer prints every merged item. On
    the first item whose value is 8, it sleeps one second and then leaves
    the ``amerge`` block, which cancels and closes the remaining sources.

    NOTE(review): nothing visible in this file calls ``main()`` (for example
    ``asyncio.run(main())`` under an ``if __name__ == "__main__":`` guard);
    confirm an entry point exists.
    """

    async def source(i: int):
        # Yield (i, 0) .. (i, 9), sleeping random.random() seconds before each.
        try:
            for x in range(10):
                await asyncio.sleep(random.random())
                yield (i, x)
        finally:
            # Runs both when the generator is exhausted and when amerge
            # closes it early.
            print(f"[Source {i}] Done")

    sources = [source(i) for i in range(5)]
    async with amerge(*sources) as receiver:
        # ``source_id`` rather than ``id`` so the builtin is not shadowed.
        async for source_id, value in receiver:
            print(f"[Source {source_id}] Item {value}")
            if value == 8:
                # While we sleep, the other sources may keep producing until
                # their send blocks (amerge's default buffer size is 0).
                await asyncio.sleep(1)
                break
    print("Done")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment