Skip to content

Instantly share code, notes, and snippets.

@smurfix
Last active February 14, 2024 18:15
Show Gist options
  • Save smurfix/0130817fa5ba6d3bb4a0f00e4d93cf86 to your computer and use it in GitHub Desktop.
Trio: results-gathering nursery wrapper
#!/usr/bin/python3
import trio
import outcome
from contextlib import asynccontextmanager
class StreamResultsNursery:
    """A Trio nursery wrapper that lets the caller iterate over task results.

    Every task started through :meth:`start_soon` or :meth:`start` has its
    return value sent into an in-memory channel.  Iterating over the wrapper
    (``async for result in wrapper``) yields those values in the order they
    are sent.  Iteration stops once every started task has finished and all
    queued results have been consumed.

    Results are delivered with ``await send(...)``, which waits while the
    channel buffer is full.  Consume them inside the ``async with`` block,
    otherwise producer tasks (and therefore the nursery) may never finish.
    """

    def __init__(self, max_buffer_size=1):
        """Create the wrapper.

        Args:
            max_buffer_size: capacity of the internal results channel, passed
                to ``trio.open_memory_channel``.
        """
        # The nursery *manager*; the actual Nursery object is stored in
        # self.nm by __aenter__.
        self.nursery = trio.open_nursery()
        self.max_buffer_size = max_buffer_size
        self.res_in, self.res_out = trio.open_memory_channel(self.max_buffer_size)
        # Number of outstanding result producers.  It starts at 1 on behalf
        # of the not-yet-started iteration; the first __aiter__ call
        # releases that reservation.
        self._waiting = 1
        self._loop = False

    @property
    def cancel_scope(self):
        """The cancel scope of the underlying nursery (valid once entered)."""
        # The manager in self.nursery has no cancel_scope attribute; the
        # Nursery returned by its __aenter__ (self.nm) does.
        return self.nm.cancel_scope

    async def __aenter__(self):
        """Open the underlying nursery and return this wrapper."""
        self.nm = await self.nursery.__aenter__()
        return self

    async def __aexit__(self, *exc):
        """Close the underlying nursery, waiting for its tasks to finish."""
        return await self.nursery.__aexit__(*exc)

    async def _wrap(self, p, a):
        """Run ``p(*a)`` and forward its return value to the results channel."""
        try:
            result = await p(*a)
        finally:
            # Decrement *before* sending.  If the decrement came after the
            # send, a race is possible: send() hands the value straight to
            # a waiting receiver and then yields.  The consumer could read
            # the value, see a stale non-zero count and block in receive()
            # for a result that never arrives.  Trio's memory-channel
            # send() queues the value (or registers this task as a blocked
            # sender, whose value receive_nowait() also collects) before it
            # first yields.  A consumer that sees a zero count therefore
            # still finds this result.
            self._waiting -= 1
        await self.res_in.send(result)

    async def _wrap_ts(self, p, a, task_status):
        """Like :meth:`_wrap`, but passes Trio's ``task_status`` through to ``p``."""
        try:
            result = await p(*a, task_status=task_status)
        finally:
            # Same ordering requirement as in _wrap.
            self._waiting -= 1
        await self.res_in.send(result)

    async def start(self, p, *a):
        """Start ``p(*a)`` like ``Nursery.start`` and collect its return value.

        Returns:
            Whatever ``p`` passed to ``task_status.started()``.
        """
        if self.res_in is None:
            # No results channel: behave like a plain nursery.  Nothing is
            # counted, because nothing will ever decrement the count.
            return await self.nm.start(p, *a)
        # Must be counted before start(): the task begins running inside it.
        self._waiting += 1
        return await self.nm.start(self._wrap_ts, p, a)

    def start_soon(self, p, *a):
        """Schedule ``p(*a)`` like ``Nursery.start_soon`` and collect its result."""
        if self.res_in is None:
            # No results channel: behave like a plain nursery.
            self.nm.start_soon(p, *a)
            return
        self.nm.start_soon(self._wrap, p, a)
        # start_soon() only schedules the task, so it cannot have run (and
        # decremented) yet.  Counting afterwards means a start_soon() that
        # raises does not leave a stale count behind.
        self._waiting += 1

    def __aiter__(self):
        """Return self; the first call releases the iteration's reservation."""
        if not self._loop:
            self._loop = True
            self._waiting -= 1
        return self

    async def __anext__(self):
        """Return the next result, or stop once every task has finished."""
        if self.res_out is None:
            raise StopAsyncIteration  # no results channel to read from
        try:
            if self._waiting:
                # Some task may still produce a result: wait for one.
                return await self.res_out.receive()
            # Every task has finished: only drain what is already queued.
            return self.res_out.receive_nowait()
        except (trio.WouldBlock, trio.EndOfChannel):
            raise StopAsyncIteration from None  # all results consumed
if __name__ == "__main__":
    # Demo: start n tasks that each sleep a random time and return a random
    # number, then print the results as they arrive.
    import random

    async def rand():
        """Sleep for a random time under one second, then return a random float."""
        await trio.sleep(random.random())
        return random.random()

    async def main(n):
        """Start *n* ``rand`` tasks and print each result as it arrives."""
        async with StreamResultsNursery() as N:
            for _ in range(n):
                N.start_soon(rand)
            # Consume inside the async-with block: producers wait on a full
            # results channel, so the nursery could not finish otherwise.
            async for rn in N:
                print(rn)

    trio.run(main, 10)
@smurfix
Copy link
Author

smurfix commented Feb 14, 2024

Thanks for the correction!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment