Skip to content

Instantly share code, notes, and snippets.

@valsteen
Last active July 25, 2022 09:48
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save valsteen/ea51e3259e65295890bed6813161bbf4 to your computer and use it in GitHub Desktop.
Save valsteen/ea51e3259e65295890bed6813161bbf4 to your computer and use it in GitHub Desktop.
How to limit concurrency with Python asyncio?
import asyncio
from typing import Awaitable, Callable, Coroutine, Iterator
from asyncio_pool import AioPool
import pytest as pytest
from more_itertools import peekable
"""
Different approaches to "How to limit concurrency with Python asyncio?"
https://stackoverflow.com/questions/48483348/how-to-limit-concurrency-with-python-asyncio/48484593#48484593
test_gather_with_concurrency checks whether these methods work as expected
Problem statement:
Define a function whose signature is:
```
async def gather_with_concurrency(
concurrency: int, coroutines: Iterator[Coroutine]
):
```
and which fulfills these invariants:
- when the function returns, all coroutines have completed
- the coroutines are executed concurrently, with at most `concurrency` of them running at once
- a slow coroutine does not prevent the coroutines that follow it from being scheduled, as long as the
maximum concurrency is not reached
"""
async def gather_with_concurrency_adam(
    concurrency: int, coroutines: Iterator[Coroutine]
):
    """Run ``coroutines`` as tasks with at most ``concurrency`` of them in flight.

    A semaphore slot is acquired *before* each task is created and released
    only when that task finishes. The producer loop therefore blocks while
    ``concurrency`` tasks are still running. Holding the semaphore only
    around ``create_task`` would release it immediately and would not limit
    concurrency at all.

    Returns None once every task has completed. If a coroutine raises, the
    exception propagates through ``asyncio.gather``.
    """
    semaphore = asyncio.Semaphore(concurrency)
    tasks = []
    for coroutine in coroutines:
        # Waits here while `concurrency` tasks are still running.
        await semaphore.acquire()
        task = asyncio.create_task(coroutine)
        # Done-callbacks fire on success, failure or cancellation, so the
        # slot is always given back once the task is finished.
        task.add_done_callback(lambda _finished: semaphore.release())
        tasks.append(task)
    await asyncio.gather(*tasks)
async def gather_with_concurrency_emin(concurrency: int, coros: Iterator[Coroutine]):
    """Await every coroutine in ``coros`` with at most ``concurrency`` running at once.

    Each coroutine is wrapped so that its body only starts once it holds a
    slot of a shared semaphore. All wrappers are then handed to
    ``asyncio.gather``, whose list of results (in input order) is returned.
    """
    gate = asyncio.Semaphore(concurrency)

    async def run_gated(coro: Coroutine):
        # The wrapped coroutine only begins executing once inside the gate.
        async with gate:
            result = await coro
        return result

    wrapped = [run_gated(coro) for coro in coros]
    return await asyncio.gather(*wrapped)
async def gather_with_concurrency_aiopool(concurrency: int, coros: Iterator[Coroutine]):
    """Await every coroutine in ``coros`` through an ``AioPool`` of size ``concurrency``.

    Adapted from https://stackoverflow.com/a/57381896/34871. ``pool.map`` is
    only invoked when the input yields a truthy first item, so an empty
    iterator makes this a no-op. The pool's results are not returned; the
    function returns None.
    """
    pool = AioPool(size=concurrency)
    stream = peekable(coros)
    if not stream.peek(None):
        # Nothing to schedule: skip the pool entirely.
        return
    # Identity mapper: each item already is the awaitable the pool should run.
    await pool.map(lambda coro: coro, stream)
async def gather_so_anwser(concurrency: int, coroutines: Iterator[Coroutine]):
    """Run ``coroutines`` as tasks, keeping at most ``concurrency`` pending at once.

    Adapted from https://stackoverflow.com/a/48484593/34871 (part 1). Before
    scheduling a new task while ``concurrency`` are pending, wait for at least
    one of them to finish. Once the input is exhausted, wait for the rest.

    Every coroutine is run to completion even if some fail. Afterwards, the
    first failure observed is re-raised (cancelled tasks are ignored), so
    errors are not silently dropped. Otherwise returns None.
    """
    pending = set()
    first_error = None

    def note_failures(finished):
        # Retrieve every finished task's exception so none is reported as
        # "never retrieved", and remember the first one to re-raise later.
        nonlocal first_error
        for task in finished:
            if task.cancelled():
                continue
            error = task.exception()
            if error is not None and first_error is None:
                first_error = error

    for coroutine in coroutines:
        if len(pending) >= concurrency:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            note_failures(done)
        pending.add(asyncio.create_task(coroutine))
    # asyncio.wait rejects an empty collection, hence the guard.
    if pending:
        done, _ = await asyncio.wait(pending)
        note_failures(done)
    if first_error is not None:
        raise first_error
async def gather_so_anwser_part2(concurrency: int, coros: Iterator[Coroutine]):
    """Run ``coros`` on ``concurrency`` worker tasks fed from a shared queue.

    Adapted from https://stackoverflow.com/a/48484593/34871 (part 2). Each
    worker repeatedly takes a coroutine from the queue and awaits it, so at
    most ``concurrency`` coroutines run at once. Every coroutine is run to
    completion. If any raised, the first such exception (in completion order)
    is re-raised after all have finished. Worker tasks are always cancelled
    and awaited before returning.

    NOTE: assumes ``concurrency >= 1``; with no workers and a non-empty input,
    ``queue.join()`` would never return.
    """
    queue = asyncio.Queue()
    errors = []

    async def worker():
        while True:
            coro = await queue.get()
            try:
                await coro
            except Exception as exc:
                # Record the failure instead of letting it kill the worker:
                # a dead worker never calls task_done(), so queue.join()
                # below would wait forever.
                errors.append(exc)
            finally:
                queue.task_done()

    workers = [asyncio.create_task(worker()) for _ in range(concurrency)]
    try:
        for coro in coros:
            await queue.put(coro)
        await queue.join()  # wait for all queued coroutines to be processed
    finally:
        # Workers loop forever; stop them even if the input iterator raised
        # or this coroutine was cancelled.
        for worker_task in workers:
            worker_task.cancel()
        await asyncio.gather(*workers, return_exceptions=True)
    if errors:
        raise errors[0]
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "concurrency,size", ((1, 1), (10, 1), (10, 101), (10, 0), (10, 10))
)
@pytest.mark.parametrize(
    "method",
    (
        gather_with_concurrency_adam,
        gather_with_concurrency_emin,
        gather_with_concurrency_aiopool,
        gather_so_anwser,
        gather_so_anwser_part2,
    ),
)
async def test_gather_with_concurrency(
    concurrency: int,
    size: int,
    method: Callable[[int, Iterator[Coroutine]], Awaitable[None]],
):
    """Check one gather-with-concurrency implementation against the invariants.

    Verifies three things:
    - every coroutine ran to completion exactly once;
    - the peak number of simultaneously running coroutines is exactly
      ``min(concurrency, size)``;
    - when more than one coroutine may run at once, completion order differs
      from submission order, i.e. the coroutines really overlapped.
    """
    done = []
    pending = set()
    max_concurrency = 0

    async def getter(i):
        nonlocal max_concurrency
        pending.add(i)
        max_concurrency = max(len(pending), max_concurrency)
        # Reverse-order completion, to assess that concurrency is happening:
        # earlier items sleep longer. For i >= 10 the delay is zero or
        # negative, so those items do not wait.
        await asyncio.sleep(1 - i / 10.0)
        pending.remove(i)
        done.append(i)

    await method(concurrency, (getter(i) for i in range(size)))

    assert len(pending) == 0
    # An empty `pending` alone would also hold if coroutines were never run.
    assert sorted(done) == list(range(size)), "not every coroutine completed"
    expected_concurrency = min(concurrency, size)
    assert (
        max_concurrency == expected_concurrency
    ), "expected maximum concurrency {}, got {} instead".format(
        expected_concurrency, max_concurrency
    )
    # With at most one coroutine running at a time, completion is
    # necessarily in submission order, so only check overlap otherwise.
    if expected_concurrency > 1:
        assert done != sorted(done)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment