Skip to content

Instantly share code, notes, and snippets.

@vxgmichel
Last active May 16, 2024 16:10
Show Gist options
  • Save vxgmichel/23469ee482aeba4a1c4d3cd66f1ac6c5 to your computer and use it in GitHub Desktop.
Save vxgmichel/23469ee482aeba4a1c4d3cd66f1ac6c5 to your computer and use it in GitHub Desktop.
Async map with task limit using anyio
import math
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterable, AsyncIterator, Awaitable, Callable
from anyio import create_memory_object_stream, create_task_group, abc
from anyio.streams.memory import MemoryObjectReceiveStream
@asynccontextmanager
async def amap[
A, B
](
source: AsyncIterable[A],
corofn: Callable[[A], Awaitable[B]],
task_limit: float = math.inf,
) -> AsyncIterator[MemoryObjectReceiveStream[B]]:
async def source_task(task_status: abc.TaskStatus) -> None:
async with send_item_stream:
task_status.started()
async for item in source:
await send_token_stream.send(None)
await task_group.start(item_task, item)
async def item_task(item: A, task_status: abc.TaskStatus) -> None:
try:
async with send_item_stream.clone() as cloned_stream:
task_status.started()
result = await corofn(item)
await cloned_stream.send(result)
finally:
await receive_token_stream.receive()
send_token_stream, receive_token_stream = create_memory_object_stream[None](
max_buffer_size=task_limit
)
send_item_stream, receive_item_stream = create_memory_object_stream[B]()
async with receive_item_stream:
async with create_task_group() as task_group:
await task_group.start(source_task)
yield receive_item_stream
task_group.cancel_scope.cancel()
async def main():
    """Demo ``amap``: unlimited concurrency, a task limit of 1, and an early stop."""

    async def input_gen() -> AsyncIterator[str]:
        # Produce one character every 0.1 s.
        for character in "abc123xyz789":
            await asyncio.sleep(0.1)
            yield character

    async def slow_task(item: str) -> str:
        # Simulate 0.5 s of work per input.
        await asyncio.sleep(0.5)
        return f"{item}_loaded"

    async def drain(label: str, stop_at: str | None = None, **amap_options) -> None:
        # Print the label, then print every mapped result. If `stop_at` is
        # given, stop after that result has been printed.
        print(label)
        async with amap(input_gen(), slow_task, **amap_options) as items:
            async for item in items:
                print(f"Received: {item}")
                if item == stop_at:
                    break

    await drain("Running without task limit")
    await drain("Running with task limit = 1", task_limit=1)
    await drain("Stopping after 3_loaded is received", stop_at="3_loaded")
if __name__ == "__main__":
    # Run the demo coroutine on a fresh asyncio event loop, which is closed
    # (with async generators shut down) when the runner exits.
    with asyncio.Runner() as runner:
        runner.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment