Skip to content

Instantly share code, notes, and snippets.

@arthur-tacca
Forked from smurfix/wrap.py
Last active February 14, 2024 17:08
Show Gist options
  • Save arthur-tacca/6c676a21d0dcc0582edb50c9c2aa3e3c to your computer and use it in GitHub Desktop.
Save arthur-tacca/6c676a21d0dcc0582edb50c9c2aa3e3c to your computer and use it in GitHub Desktop.
Trio: results-gathering nursery wrapper
# Original idea by smurfix: https://gist.github.com/smurfix/0130817fa5ba6d3bb4a0f00e4d93cf86
# aioresult variant: https://gist.github.com/arthur-tacca/5c717ae68ac037e72ae45fd1e9ca1345
from collections import deque
import math
import trio
class StreamResultsNursery:
    """Nursery wrapper that streams task return values through ``async for``.

    Use it as an async context manager, launch tasks with :meth:`start_soon`
    or :meth:`start`, and iterate over the instance to receive each task's
    return value in the order the tasks finish.  A ``trio.CapacityLimiter``
    gates how many wrapped callables execute at the same time; tasks beyond
    that limit wait for a free token.

    Once iteration has ended with ``StopAsyncIteration`` no further tasks can
    be started on this instance.
    """

    def __init__(self, max_running_tasks=math.inf):
        """Set up bookkeeping; the nursery itself is entered in ``__aenter__``.

        Args:
            max_running_tasks: Token count handed to the capacity limiter,
                i.e. how many wrapped callables may run concurrently.
                Defaults to ``math.inf`` (no limit).
        """
        self._nursery = trio.open_nursery()  # nursery manager, not yet entered
        self._results = deque()  # finished results awaiting retrieval (FIFO)
        self._unfinished_tasks_count = 0  # Includes both running and waiting to run
        self._capacity_limiter = trio.CapacityLimiter(max_running_tasks)
        self._nm = None  # the entered nursery; stays None until __aenter__
        self._parking_lot = trio.lowlevel.ParkingLot()  # consumers waiting for a result
        self._loop_finished = False  # set once iteration raised StopAsyncIteration

    @property
    def cancel_scope(self):
        """Cancel scope of the underlying nursery (valid only after entering)."""
        return self._nm.cancel_scope

    @property
    def max_running_tasks(self):
        """Maximum number of wrapped callables allowed to run concurrently."""
        return self._capacity_limiter.total_tokens

    @max_running_tasks.setter
    def max_running_tasks(self, value):
        self._capacity_limiter.total_tokens = value

    @property
    def running_tasks_count(self):
        """Number of tasks currently holding a capacity-limiter token."""
        return self._capacity_limiter.borrowed_tokens

    async def __aenter__(self):
        """Enter the underlying nursery and return this wrapper."""
        self._nm = await self._nursery.__aenter__()
        return self

    def __aexit__(self, *exc):
        """Delegate exit to the underlying nursery manager.

        Returns the awaitable produced by the nursery's own ``__aexit__``.
        """
        return self._nursery.__aexit__(*exc)

    def _check_can_start(self):
        """Raise ``RuntimeError`` unless new tasks may currently be started."""
        if self._nm is None:
            raise RuntimeError("Enter context manager before starting tasks")
        if self._loop_finished:
            raise RuntimeError("Loop over results has already completed")

    def _task_finished(self):
        """Account for one task that is no longer pending and wake consumers."""
        self._unfinished_tasks_count -= 1
        if self._unfinished_tasks_count == 0:
            # Nothing more will arrive: every parked consumer must wake up so
            # it can either take a leftover result or end its iteration.
            self._parking_lot.unpark_all()
        else:
            # At most one new result was produced, so one consumer suffices.
            self._parking_lot.unpark()

    async def _wrap(self, p, a, task_status=trio.TASK_STATUS_IGNORED):
        """Run ``p(*a)`` under the capacity limiter and queue its return value.

        ``task_status.started()`` is reported once a limiter token has been
        acquired.  If ``p`` raises, nothing is queued and the exception is
        left to propagate; the bookkeeping in ``finally`` runs either way.
        """
        try:
            async with self._capacity_limiter:
                task_status.started()
                self._results.append(await p(*a))
        finally:
            self._task_finished()

    def start_soon(self, p, *a):
        """Schedule ``p(*a)`` in the nursery; its result is streamed later.

        Raises:
            RuntimeError: If the context manager has not been entered, or
                iteration over the results has already completed.
        """
        self._check_can_start()
        # Spawn before counting: if the nursery refuses the task, the counter
        # must stay untouched or iteration would wait for a task that never
        # existed.  The child cannot run until this synchronous call returns,
        # so it cannot finish before being counted.
        self._nm.start_soon(self._wrap, p, a)
        self._unfinished_tasks_count += 1

    async def start(self, p, *a):
        """Schedule ``p(*a)`` and wait until it has acquired a limiter token.

        Raises:
            RuntimeError: If the context manager has not been entered, or
                iteration over the results has already completed.
        """
        self._check_can_start()
        began = False

        async def runner(task_status=trio.TASK_STATUS_IGNORED):
            nonlocal began
            began = True  # from here on _wrap's ``finally`` owns the decrement
            await self._wrap(p, a, task_status=task_status)

        # Count first: the child may finish while this call is still waiting.
        self._unfinished_tasks_count += 1
        try:
            await self._nm.start(runner)
        except BaseException:
            if not began:
                # The nursery rejected the task before it ever ran, so no
                # ``finally`` in _wrap will undo the increment above.
                self._task_finished()
            raise

    def __aiter__(self):
        """Return ``self``; the wrapper is its own asynchronous iterator."""
        return self

    async def __anext__(self):
        """Return the next available result, waiting for one if necessary.

        Raises:
            StopAsyncIteration: When no tasks are pending and every queued
                result has been retrieved.
        """
        await trio.lowlevel.checkpoint()  # Ensure this function is always a checkpoint
        while not self._results and self._unfinished_tasks_count != 0:
            await self._parking_lot.park()  # Need to wait for a result to be produced
        if self._results:
            return self._results.popleft()
        self._loop_finished = True
        raise StopAsyncIteration  # All tasks done and all results retrieved
if __name__ == "__main__":
    # Demo: start `count` randomly-sleeping tasks (at most 3 running at once),
    # print each result as it arrives, and start one extra task from inside
    # the result loop once `count` results have been received.
    import random

    async def rand():
        """Sleep for a random fraction of a second and return that duration."""
        sleep_length = random.random()
        try:
            print(f"Starting: {sleep_length}")
            await trio.sleep(sleep_length)
            print(f"Finished: {sleep_length}")
            return sleep_length
        finally:
            print(f"Done: {sleep_length}")

    async def main(count):
        """Run the demo with ``count`` initial tasks plus one late addition."""
        # The outer nursery only wraps the results nursery; nothing is
        # started in it directly, so it is not bound to a name.
        async with trio.open_nursery():
            async with StreamResultsNursery(max_running_tasks=3) as nursery:
                for i in range(count):
                    print(f"Starting task {i}")
                    nursery.start_soon(rand)
                received = 0
                async for rn in nursery:
                    received += 1
                    print(f"Got {received}: {rn}\n")
                    if received == count:
                        # Starting a task here is allowed because the loop
                        # has not yet ended with StopAsyncIteration.
                        print("starting extra task")
                        nursery.start_soon(rand)

    trio.run(main, 10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment