@smurfix
Last active February 14, 2024 18:15
Trio: results-gathering nursery wrapper
#!/usr/bin/python3
import trio
import outcome
from contextlib import asynccontextmanager


class StreamResultsNursery:
    def __init__(self, max_buffer_size=1):
        self.nursery = trio.open_nursery()
        self.max_buffer_size = max_buffer_size
        self.res_in, self.res_out = trio.open_memory_channel(self.max_buffer_size)
        self._waiting = 1
        self._loop = False

    @property
    def cancel_scope(self):
        return self.nursery.cancel_scope

    async def __aenter__(self):
        self.nm = await self.nursery.__aenter__()
        return self

    def __aexit__(self, *exc):
        return self.nursery.__aexit__(*exc)
    async def _wrap(self, p, a):
        try:
            await self.res_in.send(await p(*a))
        finally:
            self._waiting -= 1

    async def _wrap_ts(self, p, a, task_status):
        try:
            await self.res_in.send(await p(*a, task_status=task_status))
        finally:
            self._waiting -= 1

    async def start(self, p, *a):
        self._waiting += 1
        if self.res_in is None:
            await self.nm.start(p, *a)
        else:
            await self.nm.start(self._wrap_ts, p, a)

    def start_soon(self, p, *a):
        self._waiting += 1
        if self.res_in is None:
            self.nm.start_soon(p, *a)
        else:
            self.nm.start_soon(self._wrap, p, a)
    def __aiter__(self):
        if not self._loop:
            self._loop = True
            self._waiting -= 1  # the loop itself no longer counts as a waiter
        return self

    async def __anext__(self):
        if self.res_out is None:
            raise StopAsyncIteration  # never started
        try:
            if self._waiting:
                return await self.res_out.receive()
            else:
                return self.res_out.receive_nowait()
        except (trio.WouldBlock, trio.EndOfChannel):
            raise StopAsyncIteration  # no tasks left and nothing buffered
if __name__ == "__main__":
    # test code
    import random

    async def rand():
        await trio.sleep(random.random())
        return random.random()

    async def main(n):
        async with StreamResultsNursery() as N:
            for _ in range(n):
                N.start_soon(rand)
            async for rn in N:
                print(rn)

    trio.run(main, 10)
@arthur-tacca

This has a race condition which means the demo code will sometimes hang once all the tasks are done (about 1 in 4 runs on my computer).

The problem is that the _wrap() function has three steps: (1) await p(*a), (2) await self.res_in.send(...), (3) self._waiting -= 1.

Step (2) can be pre-empted, because it's a checkpoint, but the checkpoint happens after sending the value (if the channel is not full). The problem occurs when, in the final task to finish, the next thing to run is the final loop iteration rather than step (3) for that task. The loop initially works OK: it pulls the new value off successfully. But when it comes back around, it still sees self._waiting as truthy, so it waits on await self.res_out.receive() rather than taking the sync path and hitting trio.WouldBlock. There is a handler for trio.EndOfChannel, but the channel is never closed, so the loop waits there forever, even after the task is woken to finally complete step (3).
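
To make the interleaving concrete, here is the losing schedule written out (my reconstruction of the scenario above, with T the last worker task and L the results loop):

    # T: step (1) done; step (2) puts its value into the channel, then is
    #    pre-empted at the checkpoint inside send(), before step (3) runs
    # L: receives that value, prints it, loops back around
    # L: self._waiting is still truthy (step (3) hasn't run), so it awaits
    #    self.res_out.receive() instead of trying receive_nowait()
    # T: resumes, runs step (3), exits
    # L: nothing will ever send again and the channel is never closed,
    #    so receive() blocks forever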

(I had initially wondered why you have any of that sync stuff in the first place instead of just closing the channel when the tasks are done, but I realised this is to allow more tasks to be started during the loop after all the existing ones have finished, which wouldn't be possible if the channel were closed. It's a pity there's no "unclose"!)

The smallest change to fix this is to swap steps (2) and (3):

    async def _wrap(self, p, a):
        try:
            result = await p(*a)
        finally:
            self._waiting -= 1
        await self.res_in.send(result)

Even then, I don't feel really confident about it. What if the async for is happening in a separate outer nursery, and this inner nursery is cancelled? Then self._waiting will drop to 0 but nothing will be sent to the channel, so the loop will still hang. In any case, this feels more complicated than just using the Trio primitives directly, rather than trying to force a memory channel to deliver this behaviour.
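
For comparison, here is a minimal sketch (mine, not from either gist) of the direct-primitives pattern: close the send side once all tasks are done, so the loop ends cleanly via EndOfChannel. The trade-off noted above applies: once the channel is closed, you can't start more tasks.

    import random
    import trio

    async def rand():
        await trio.sleep(random.random())
        return random.random()

    async def main():
        send, recv = trio.open_memory_channel(0)

        async def runner():
            await send.send(await rand())

        async def produce():
            async with send:  # channel closes once every task has finished
                async with trio.open_nursery() as inner:
                    for _ in range(10):
                        inner.start_soon(runner)

        async with trio.open_nursery() as outer:
            outer.start_soon(produce)
            async for value in recv:  # ends via EndOfChannel, no counter needed
                print(value)

    trio.run(main)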

Debatably, there is a second bug in this code: the max_buffer_size effectively does nothing. It only blocks step (2) above (sending the result), not the running of the tasks themselves in step (1). Essentially, the queue is always unbounded regardless of this parameter; it's just that part of the queue takes the form of completed-but-still-waiting tasks rather than entries in the memory channel.

I have posted a gist that fixes these problems: it uses a deque of results and a parking lot rather than a memory channel, and uses a trio.CapacityLimiter to limit how many tasks are running. (Each waiting task still consumes a Trio task rather than just an entry in a container, which is a bit wasteful, but this is by far the simplest implementation.) It also allows waiting for a task to start with StreamResultsNursery.start(), which provides backpressure to task starters.

Here's my attempted fix: https://gist.github.com/arthur-tacca/6c676a21d0dcc0582edb50c9c2aa3e3c
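
Roughly the shape of that approach, as a self-contained sketch (my illustration of the ingredients named above, not the linked gist's actual code):

    import collections
    import random
    import trio

    async def rand():
        await trio.sleep(random.random())
        return random.random()

    async def main():
        limiter = trio.CapacityLimiter(3)   # at most 3 tasks actually running
        results = collections.deque()
        lot = trio.lowlevel.ParkingLot()
        remaining = 10

        async def runner():
            nonlocal remaining
            async with limiter:             # backpressure on execution itself
                value = await rand()
            results.append(value)
            remaining -= 1
            lot.unpark_all()                # wake the consuming loop

        async with trio.open_nursery() as nursery:
            for _ in range(10):
                nursery.start_soon(runner)
            while remaining or results:
                if results:
                    print(results.popleft())
                else:
                    # no lost wakeup: there is no checkpoint between the
                    # check above and parking here
                    await lot.park()

    trio.run(main)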

I have also made a variant that uses aioresult.ResultCapture for the values sent to the loop, rather than the raw return values. That lets you get a little information about the task (the routine and args are properties) and can express that the task finished with an exception rather than a regular return (or even wasn't run at all). I'm not sure it's useful, but it was straightforward to do, and having got this far it seemed a pity not to go the final step.

Here's the aioresult variant: https://gist.github.com/arthur-tacca/5c717ae68ac037e72ae45fd1e9ca1345
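
For flavour, usage looks roughly like this (a sketch based on aioresult's documented ResultCapture API, not the variant gist's actual code):

    import random
    import trio
    from aioresult import ResultCapture

    async def rand():
        await trio.sleep(random.random())
        return random.random()

    async def main():
        async with trio.open_nursery() as nursery:
            captures = [ResultCapture.start_soon(nursery, rand) for _ in range(10)]
        # The nursery has exited, so every capture is finished one way or another.
        for rc in captures:
            print(rc.routine, rc.args, rc.result())  # result() re-raises on failure

    trio.run(main)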

@smurfix (Author) commented Feb 14, 2024

Thanks for the correction!
