@arthur-tacca
Last active February 12, 2024 13:31
ResultCapture class for Trio - now adapted into aioresult: https://github.com/arthur-tacca/aioresult
#
# -- An improved version of this code is now in the aioresult library --
#
# https://github.com/arthur-tacca/aioresult
#
#
# - aioresult has a ResultCapture class similar to the one below
# - It also has a Future class that allows manually setting the result, which shares a base class with ResultCapture
# - There is a utility function with a similar effect to StartableResultCapture below (but much simpler)
# - There are utility functions wait_any(), wait_all() and to_channel() (which also work with Future instances)
#
#
# It is recommended that you use aioresult instead. The gist below is just kept to avoid breaking links to it.
#
import trio


class TaskNotDoneException(Exception):
    pass


class TaskWrappedException(Exception):
    pass


class ResultCapture:
    """Captures the result of a task for later access.

    If you are directly awaiting a task then there is no need to use this class, you can just use
    the return value:

        result1 = await foo1()
        result2 = await foo2()
        print("results:", result1, result2)

    If you want to run your tasks in parallel then you would typically use a nursery, but then it's
    harder to get hold of the results:

        async with trio.open_nursery() as n:
            n.start_soon(foo1)
            n.start_soon(foo2)
        # At this point the tasks have completed, but the results are lost
        print("results: ??")

    To get access to the results, the routines need to stash their results somewhere for you to
    access later. ResultCapture is a simple helper to do this.

        async with trio.open_nursery() as n:
            r1 = ResultCapture.start_soon(n, foo1)
            r2 = ResultCapture.start_soon(n, foo2)
        # At this point the tasks have completed, and results are stashed in ResultCapture objects
        print("results", r1.result, r2.result)

    You can get a very similar effect to asyncio's gather() function by using a nursery and a list
    of ResultCapture objects:

        async with trio.open_nursery() as n:
            rs = [
                ResultCapture.start_soon(n, foo, i)
                for i in range(10)
            ]
        print("results:", *[r.result for r in rs])

    But ResultCapture is more flexible than gather, e.g. you can use a dictionary with suitable
    keys rather than a list. You also benefit from the safer behaviour of Trio nurseries compared
    to asyncio's gather.

    Any exception thrown by the task will propagate out as usual, typically to the enclosing
    nursery. Accessing the result attribute will then raise TaskWrappedException, with the
    original exception available as the __cause__ attribute (because it is raised using
    raise ... from syntax).
    """

    @classmethod
    def start_soon(cls: type, nursery: trio.Nursery, routine, *args):
        """Runs the task in the given nursery and captures its result."""
        task = cls(routine, *args)
        nursery.start_soon(task.run)
        return task

    def __init__(self, routine, *args):
        self._routine = routine
        self._args = args
        self._has_run_been_called = False
        self._done_event = trio.Event()
        self._result = None
        self._exception = None

    async def run(self):
        """Runs the routine and captures its result.

        Typically you would use the start_soon() class method, which constructs the ResultCapture
        and arranges for the run() method to be run in the given nursery. But it is possible to
        manually construct the object and call the run() method in situations where that extra
        control is useful.
        """
        if self._has_run_been_called:
            raise RuntimeError("ResultCapture.run() called multiple times")
        self._has_run_been_called = True
        try:
            self._result = await self._routine(*self._args)
        except BaseException as e:
            self._exception = e
            raise  # Note the exception is allowed to propagate into user nursery
        finally:
            self._done_event.set()

    @property
    def result(self):
        """Returns the captured result of the task."""
        if not self._done_event.is_set():
            raise TaskNotDoneException(self)
        if self._exception is not None:
            raise TaskWrappedException(self) from self._exception
        return self._result

    @property
    def exception(self):
        """Returns the exception raised by the task.

        If the task completed by returning rather than raising an exception then this returns None.
        If the task is not done yet then this raises TaskNotDoneException.

        This property returns the original unmodified exception. That is unlike the result property,
        which raises a TaskWrappedException instead, with the __cause__ attribute set to the
        original exception.

        It is usually better design to use the result property and catch the exception it raises.
        However, this property can be useful in some situations e.g. filtering a list of
        ResultCapture objects.
        """
        if not self._done_event.is_set():
            raise TaskNotDoneException(self)
        return self._exception

    @property
    def done(self):
        """Returns whether the task is done i.e. the result (or an exception) is captured."""
        return self._done_event.is_set()

    async def wait_done(self):
        """Waits until the task is done.

        There are specialised situations where it may be useful to use this method to wait until
        the task is done, typically where you are writing library code and you want to start a
        routine in a user supplied nursery but wait for it in some other context.

        Typically, though, it is much better design to wait for the task's nursery to complete.
        Consider a nursery-based approach before using this method.
        """
        await self._done_event.wait()


class _TaskStatus:
    """Helper class to allow StartableResultCapture to satisfy the Trio start() protocol.

    An instance of this class is passed to the task_status parameter of the routine passed to
    StartableResultCapture, allowing it to call task_status.started() as usual when it is ready.

    This class is effectively a "friend class" (in C++ lingo) of StartableResultCapture, so it
    directly accesses protected members of that class.
    """

    def __init__(self, task, outer_task_status):
        assert isinstance(task, StartableResultCapture)
        self._task = task
        self._outer_task_status = outer_task_status

    def started(self, start_result=None):
        if self._task._started_event.is_set():
            raise RuntimeError("task_status.started() called multiple times")
        self._task._start_result = start_result
        self._task._started_event.set()
        self._outer_task_status.started(start_result)


# TODO I'm not sure if this class would be useful, but it's a fun proof of concept.
class StartableResultCapture(ResultCapture):
    """Captures result of a task and allows waiting until it has finished starting.

    This class allows running routines that were designed to be compatible with the Trio
    Nursery.start() protocol. The regular ResultCapture class can run those routines but does not
    give any information about when the task has started, whereas this class allows waiting for the
    task startup to complete and allows fetching the start result (if the routine sets it).
    """

    # TODO this function needs a much better name
    @classmethod
    def start_soon_nurseries(
        cls, routine, *args, start_nursery: trio.Nursery, run_nursery: trio.Nursery
    ):
        """Starts a task in one nursery and then runs it in another.

        It is perfectly possible to use StartableResultCapture.start_soon() to run a task and
        capture its start result and overall result. However, that runs the whole task, including
        its starting code, in a single nursery. This method allows running the startup code in one
        nursery and the rest of the code in another. This can be useful for waiting until multiple
        tasks have all finished their startup code, like in the snippet below.

            async with trio.open_nursery() as run_nursery:
                async with trio.open_nursery() as start_nursery:
                    rcs = [
                        StartableResultCapture.start_soon_nurseries(
                            foo, i, start_nursery=start_nursery, run_nursery=run_nursery
                        ) for i in range(10)
                    ]
                print("Now all tasks have started; start values:")
                print(*[rc.start_result for rc in rcs])
            print("Overall results:", *[rc.result for rc in rcs])
        """
        task = cls(routine, *args)
        start_nursery.start_soon(run_nursery.start, task.run)
        return task

    def __init__(self, routine, *args):
        super().__init__(routine, *args)
        self._start_result = None
        self._started_event = trio.Event()

    async def run(self, task_status=trio.TASK_STATUS_IGNORED):
        if self._has_run_been_called:
            raise RuntimeError("StartableResultCapture.run() called multiple times")
        self._has_run_been_called = True
        try:
            self._result = await self._routine(
                *self._args, task_status=_TaskStatus(self, task_status)
            )
        except BaseException as e:
            self._exception = e
            raise  # Note the exception is allowed to propagate into user nursery
        finally:
            self._done_event.set()

    @property
    def start_result(self):
        """Returns the start result i.e. the value passed to task_status.started()."""
        if not self._started_event.is_set():
            raise TaskNotDoneException(self)
        return self._start_result

    @property
    def started(self):
        """Returns whether the task has started i.e. whether the task called task_status.started().

        This can be False even if the done property is True. This happens if the task returns or
        raises an exception without calling task_status.started() (most commonly because an error
        was encountered while starting).
        """
        return self._started_event.is_set()

    async def wait_started(self):
        """Waits for the task to start i.e. waits for it to call task_status.started()."""
        await self._started_event.wait()


async def run_a_bit(time_to_run, should_raise=False, task_status=trio.TASK_STATUS_IGNORED):
    """Waits a little while and prints some trace; used in the tests."""
    print(f"Running for {time_to_run}s")
    await trio.sleep(time_to_run/2)
    task_status.started(time_to_run/2)
    await trio.sleep(time_to_run/2)
    print(f"Finished running for {time_to_run}s" + ("; raising..." if should_raise else ""))
    if should_raise:
        raise RuntimeError(time_to_run)
    return time_to_run


async def test():
    try:
        async with trio.open_nursery() as n:
            r1 = ResultCapture.start_soon(n, run_a_bit, 1)
            assert not r1.done
            try:
                print("r1 result:", r1.result)
                assert False, "expected r1.result to raise an exception"
            except TaskNotDoneException as e:
                print("Got not completed exception:", e.args[0]._args)

            r2 = StartableResultCapture.start_soon(n, run_a_bit, 2)
            assert not r2.started
            await r2.wait_started()
            assert r2.started and r2.start_result == 1

            r3 = ResultCapture.start_soon(n, run_a_bit, 3, True)

            async with trio.open_nursery() as inner:
                r4 = StartableResultCapture.start_soon_nurseries(
                    run_a_bit, 2.5, start_nursery=inner, run_nursery=n
                )
                assert not r4.started
                r5 = StartableResultCapture(run_a_bit, 1.5)
                await n.start(r5.run)
                assert r5.started and r5.start_result == 0.75
            assert r4.started
            print("startup nursery finished; start values:", r4.start_result, r5.start_result)
            assert r4.start_result == 1.25 and not r4.done

            assert r2.started and r2.result == 2
            print("r2 started with value", r2.start_result)

            assert r1.done
            await r1.wait_done()
            assert r1.done
            print("r1 finished:", r1.result)

            await r4.wait_done()
            assert r4.done and r4.result == 2.5
    except RuntimeError as e:
        print("Caught RuntimeError:", e)
    else:
        assert False, "Expected r3 to raise an exception"

    try:
        print("r3 =", r3.result)
        assert False, "Expected r3 to raise an exception"
    except TaskWrappedException as e:
        print("Got wrapped exception:", e)
        inner = e.__cause__
        print(f"With inner exception {type(inner).__name__}: {inner}")


if __name__ == "__main__":
    trio.run(test)

indigoviolet commented Oct 3, 2022

Thanks for this! I made a couple of minor tweaks to it when I used it today, that you may consider if you package this up.

  1. I made it impossible to call run directly, since I didn't have a need for it, eliminating the "called run twice" checks
  2. Inspired by https://gist.github.com/smurfix/0130817fa5ba6d3bb4a0f00e4d93cf86, I added a small NurseryManager wrapper that I can use in place of open_nursery, eliminating the need to pass the nursery around.
  3. I made the ResultCapture class awaitable
@dataclass
class ResultCapture(Awaitable[Any]):
    nursery: trio.Nursery
    f: AsyncFnType
    args: Iterable[ArgType]

    _done_event: trio.Event = field(init=False, default_factory=trio.Event)
    _result: Any = field(init=False)
    _exception: Optional[BaseException] = field(init=False, default=None)

    def __post_init__(self):
        self.nursery.start_soon(self._run)

    async def _run(self):
        try:
            self._result = await self.f(*self.args)
        except BaseException as e:
            self._exception = e
            raise
        finally:
            self._done_event.set()

    @property
    def result(self):
        if not self._done_event.is_set():
            raise TaskNotDoneException(self)
        if self._exception is not None:
            raise TaskFailedException(self) from self._exception

        return self._result

    def __await__(self) -> Generator[Any, None, Any]:
        yield from self._done_event.wait().__await__()
        return self.result



@dataclass
class ResultCaptureNursery:
    nursery: trio.Nursery

    def start_soon(self, f: AsyncFnType, *args: ArgType):
        return ResultCapture(self.nursery, f, args)


@asynccontextmanager
async def open_capturing_nursery():
    async with trio.open_nursery() as N:
        yield ResultCaptureNursery(N)

@arthur-tacca
Author

@indigoviolet Thanks for the feedback!

These are all things I considered and rejected when I wrote this snippet, so my responses below are going to seem like I'm being negative. But actually I really value this discussion and I'm grateful for your comments, and these are all subjective so I understand that everything below is just my opinion, and yours is just as valid.

  • Nursery wrapper (your point 2): This makes the application code shorter but I don't think it makes it simpler (easier to read). Using normal nurseries is more composable and gives a more honest impression of how the implementation works – explicit is better than implicit. (At one point I didn't even have ResultWrapper.start_soon(nursery, foo) – the official API was that users would have to separately write out the two lines rw = ResultWrapper(foo) and nursery.start_soon(rw.run). In some ways, I still think that would be better, but even I admit it would be annoying.) That argument doesn't apply so much to the results-iterating wrapper you linked to because it really is doing a bit of work spanning multiple tasks, but even that would be simplified by attaching to an existing nursery IMO.
  • Awaitable result (part of your point 3): I'm strongly against any method that both waits and returns the result. That gives the impression that it's the best way to use the class, but 99% of the time it's best to just let the enclosing nursery finish and then access the result synchronously.
  • Whole object awaitable (other part of your point 3): this amplifies the previous point because it looks like the default way to use the class. Combined with the nursery wrapper you also get the danger that await nursery.start_soon(...) looks like it waits for the task to start when it actually waits for it to finish (see dabeaz/curio#342). (I'm glad you posted the code though - I found the implementation interesting!)
  • No run() method (your point 1): I see where you're coming from here, it does simplify the implementation a bit. But I like exposing it because it makes the class feel a bit less magical: the docs can say that res.start_soon() is equivalent to instantiating the class and calling nursery.start_soon(res.run), and run() just saves the result as a member. That feels very easy to understand to me. Also, in principle it's possible to use it in cleverer ways (e.g. it could be passed as a callback to some other API), so again there's some composability here.
  • Use of dataclass: This is interesting, it simplifies the class a bit, and I like that it makes args public (I had considered this before) but I have mixed feelings since this class is more than just data, it has side effects. At the very least I think I'd want to turn off the generated equality and hash methods.
  • Type hints: In general I like type hints but I don't understand AsyncFnType or ArgType hints. Also, I'm planning on generalising this to allow either anyio or raw Trio (with no hard dependency on either - just on sniffio) which would complicate things slightly.
  • Name of exception: This isn't something you asked about but I do like what you did. I was never happy with TaskWrappedException, it says how it works rather than the error situation it represents (I half-jokingly suggested TaskFinishedWithExceptionException before which says what the problem is!). TaskFailedException is a much better name all round.
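
To make the dataclass point above concrete, here is a small standard-library sketch of what turning off the generated equality and hash methods changes. DefaultCapture and IdentityCapture are hypothetical stand-ins with a single field, not the classes from this thread. With the default eq=True, instances compare field by field and become unhashable; with eq=False they keep identity-based comparison and hashing, which is arguably the better fit for an object that represents one particular running task.

```python
from dataclasses import dataclass


@dataclass
class DefaultCapture:
    # Hypothetical stand-in using the default eq=True.
    name: str


@dataclass(eq=False)
class IdentityCapture:
    # Same field, but equality and hashing stay identity-based.
    name: str


a, b = DefaultCapture("foo"), DefaultCapture("foo")
c, d = IdentityCapture("foo"), IdentityCapture("foo")

default_equal = a == b    # generated __eq__ compares fields
identity_equal = c == d   # inherited object.__eq__ compares identity

try:
    hash(a)
    default_hashable = True
except TypeError:
    # eq=True without frozen=True sets __hash__ to None.
    default_hashable = False

# Two distinct objects stay distinct in a set.
identity_hashable = len({c, d}) == 2

print(default_equal, default_hashable, identity_equal, identity_hashable)
```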
