Skip to content

Instantly share code, notes, and snippets.

@vxgmichel
Created April 19, 2024 11:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vxgmichel/23469ee482aeba4a1c4d3cd66f1ac6c5 to your computer and use it in GitHub Desktop.
Save vxgmichel/23469ee482aeba4a1c4d3cd66f1ac6c5 to your computer and use it in GitHub Desktop.
Async map with task limit using anyio
import math
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterable, AsyncIterator, Awaitable, Callable
from anyio import create_memory_object_stream, create_task_group, abc
from anyio.streams.memory import MemoryObjectReceiveStream
@asynccontextmanager
async def amap[
A, B
](
source: AsyncIterable[A],
corofn: Callable[[A], Awaitable[B]],
task_limit: float = math.inf,
) -> AsyncIterator[MemoryObjectReceiveStream[B]]:
async def source_task(task_status: abc.TaskStatus) -> None:
async with send_item_stream:
task_status.started()
async for item in source:
await send_token_stream.send(None)
await task_group.start(item_task, item)
async def item_task(item: A, task_status: abc.TaskStatus) -> None:
try:
async with send_item_stream.clone() as cloned_stream:
task_status.started()
await cloned_stream.send(await corofn(item))
finally:
await receive_token_stream.receive()
async with create_task_group() as task_group:
send_token_stream, receive_token_stream = create_memory_object_stream[None](
max_buffer_size=task_limit
)
send_item_stream, receive_item_stream = create_memory_object_stream[B]()
async with receive_item_stream:
await task_group.start(source_task)
yield receive_item_stream
async def main():
    """Demo `amap` on a short character feed.

    The feed is processed twice: first with no concurrency limit, then with
    at most one call in flight at a time.
    """

    async def input_gen() -> AsyncIterator[str]:
        # Emit one character every 100 ms.
        for character in "abc123xyz789":
            await asyncio.sleep(0.1)
            yield character

    async def slow_task(item: str) -> str:
        # Simulate half a second of work per item.
        await asyncio.sleep(0.5)
        return f"{item}_loaded"

    # (banner, extra keyword arguments passed to amap)
    scenarios = (
        ("Running without task limit", {}),
        ("Running with task limit = 1", {"task_limit": 1}),
    )
    for banner, options in scenarios:
        print(banner)
        async with amap(input_gen(), slow_task, **options) as results:
            async for result in results:
                print(f"Received: {result}")
if __name__ == "__main__":
    # Drive the demo on a fresh asyncio event loop. The anyio primitives
    # used by `amap` run on top of this loop.
    with asyncio.Runner() as runner:
        runner.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment