Skip to content

Instantly share code, notes, and snippets.

@graingert
Created July 31, 2022 09:50
Show Gist options
  • Save graingert/5356401e196e3671abfa5e33c969a73e to your computer and use it in GitHub Desktop.
Save graingert/5356401e196e3671abfa5e33c969a73e to your computer and use it in GitHub Desktop.
Running Trio inside asyncio (via Trio guest mode) and getting the current asyncio task's context.
import asyncio
import collections.abc
import contextlib
import functools
import types
import httpx
import sniffio
import trio
# Sentinel yielded by as_trio() when its block exits. WrapCoro turns it into
# StopIteration, so trio sees its main task finish while the wrapped asyncio
# coroutine stays suspended at that yield (to be resumed later by asyncio).
TRIO_DONE = object()
class WaitTaskRescheduled:
    """A Future-like object that parks an asyncio Task until someone wakes it.

    When a coroutine running in an asyncio Task yields an instance, the Task
    treats it as the future it is waiting on. It calls ``get_loop()`` and
    ``add_done_callback(fn, context=...)``, and later forwards
    ``Task.cancel()`` to ``cancel()``. Nothing here ever completes on its own.
    The wakeup callback is handed to *add_done_callback*, and the owner
    decides when to call it. Cancellation requests go to *abort_func*.

    Args:
        add_done_callback: called once as ``add_done_callback(fn, context)``
            with the Task's wakeup callback and its context.
        abort_func: zero-argument callable invoked at most once, on the first
            cancellation request.
    """

    # Lets asyncio.Task recognise a yielded instance as a Future to wait on.
    _asyncio_future_blocking = True
    # Class-level fallbacks. The instance attributes set in __init__ shadow
    # them until they are consumed (deleted); after that these ``None``
    # values show through and mark the callback as already used.
    _add_done_callback = None
    _abort_func = None

    def __init__(self, add_done_callback, abort_func):
        self._add_done_callback = add_done_callback
        self._abort_func = abort_func

    def cancel(self, *args, **kwargs):
        """Forward the first cancellation request to ``abort_func``.

        Always returns True; see the note below for why.
        """
        abort = self._abort_func
        if abort is None:
            # Already forwarded. Only level-triggered cancellation is
            # supported, so the earlier request is still in effect and there
            # is nothing more to do. (Deleting again would raise
            # AttributeError into the caller of Task.cancel.)
            return True
        # break a reference cycle and only support level cancel
        del self._abort_func
        abort()
        # asyncio.Task.cancel calls:
        #
        # if self._fut_waiter is not None:
        #     if self._fut_waiter.cancel(msg=msg):
        #         # Leave self._fut_waiter; it may be a Task that
        #         # catches and ignores the cancellation so we may have
        #         # to cancel it again later.
        #         return True
        # # It must be the case that self.__step is already scheduled.
        # self._must_cancel = True
        # self._cancel_message = msg
        #
        # fut_waiter (that's us) needs to return True; otherwise
        # task._must_cancel is set to True, which means that when we wake up
        # the task it will call coro.throw(CancelledError)!
        return True

    def get_loop(self):
        """Return the running event loop (asyncio checks it matches the Task's)."""
        return asyncio.get_running_loop()

    def add_done_callback(self, fn, *, context):
        """Hand the Task's wakeup callback *fn* and its *context* to the owner.

        Raises:
            AssertionError: if called more than once.
        """
        forward = self._add_done_callback
        # Check before deleting: once the instance attribute is gone, the
        # class-level None makes every later call land here.
        if forward is None:
            raise AssertionError("only one task can listen to a Future at a time")
        # break a reference cycle
        del self._add_done_callback
        forward(fn, context)
@types.coroutine
def _async_yield(v):
return (yield v)
async def get_context():
    """Return the ``contextvars.Context`` the current asyncio Task steps in.

    The Task is parked on a WaitTaskRescheduled. The Task hands that object
    its wakeup callback together with its context; the context is captured
    and the Task is rescheduled on the next loop iteration.

    Raises:
        asyncio.CancelledError: if the Task was cancelled while parked. The
            cancellation request arrives through ``abort_func``. Without it,
            WaitTaskRescheduled.cancel() reports success and the request would
            otherwise be silently lost.
    """
    context = None
    cancel_requested = False

    def abort_func():
        # The wakeup is already scheduled (see add_done_callback), so just
        # remember the request and honour it once the Task resumes.
        nonlocal cancel_requested
        cancel_requested = True

    def add_done_callback(reschedule, context_):
        nonlocal context
        context = context_
        asyncio.current_task().get_loop().call_soon(reschedule, NullFuture())

    await _async_yield(
        WaitTaskRescheduled(add_done_callback=add_done_callback, abort_func=abort_func)
    )
    if cancel_requested:
        raise asyncio.CancelledError
    return context
@collections.abc.Coroutine.register
class WrapCoro:
    """Awaitable proxy that steps another coroutine inside a given context.

    It is registered as a virtual ``collections.abc.Coroutine`` so that it can
    be awaited from trio's main task. Each ``send``/``throw`` is executed via
    ``context.run``. If the wrapped coroutine yields the ``TRIO_DONE``
    sentinel, the proxy raises ``StopIteration`` so its awaiter sees it as
    finished, while the wrapped coroutine stays suspended at that yield.
    """

    def __init__(self, coro, context):
        self._coro = coro
        self._context = context

    def _step(self, method, *args):
        """Run one step of the wrapped coroutine and translate TRIO_DONE."""
        yielded = self._context.run(method, *args)
        if yielded is TRIO_DONE:
            raise StopIteration
        return yielded

    def __await__(self):
        # The proxy is its own iterator for ``await``.
        return self

    def __iter__(self):
        return self

    def __next__(self):
        return self.send(None)

    def send(self, v):
        return self._step(self._coro.send, v)

    def throw(self, *exc_info):
        return self._step(self._coro.throw, *exc_info)
class NullFuture:
    """Placeholder for an already-completed Future whose result is ``None``.

    It is handed to the rescheduling callbacks in this module, which only ever
    call ``result()`` on it.
    """

    def result(self):
        """Return ``None``; this placeholder never carries a value or an error."""
        return
class NullContext:
    """Context-like object whose ``run`` calls the function without switching context.

    sniffio stores the current async library on the context rather than in a
    thread-local. Presumably this class exists so that code stepped through
    it sees the context of whoever calls ``run`` (here, trio's task).
    """

    def run(self, fn, /, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` in the caller's context and return the result."""
        outcome = fn(*args, **kwargs)
        return outcome
def done_callback(outcome, call_soon, callback, context):
    """Trio guest-run completion hook that reschedules the parked asyncio Task.

    Args:
        outcome: the guest run's result. It is discarded. as_trio always hands
            control back via TRIO_DONE, so trio's main task is expected to
            finish with None.
        call_soon: the host loop's ``call_soon``.
        callback: the Task's wakeup callback.
        context: the context in which to run *callback*.
    """
    del outcome
    wakeup_arg = NullFuture()
    call_soon(callback, wakeup_arg, context=context)
@contextlib.asynccontextmanager
async def as_trio():
    """Run the body of ``async with as_trio():`` under trio on the asyncio loop.

    On entry, the current asyncio Task parks on a WaitTaskRescheduled. When
    the Task passes over its wakeup callback, a trio guest run is started on
    the same loop. Its main task drives this Task's own coroutine (wrapped in
    WrapCoro), so the body executes inside trio.

    On exit, yielding TRIO_DONE makes WrapCoro report completion to trio.
    trio's ``done_callback`` then reschedules the asyncio Task, which resumes
    after the block.

    Cancelling the asyncio Task while it is inside the block calls
    ``cancel_scope.cancel()``.

    NOTE(review): if ``start_guest_run`` raises when called from ``call_soon``
    (e.g. trio already running on this thread), the error presumably goes to
    the loop's exception handler and the Task stays parked forever. Verify.

    Yields:
        None
    """
    cancel_scope = trio.CancelScope()

    # Trio's main task: await the wrapped asyncio coroutine inside
    # cancel_scope, so cancelling the scope cancels everything run under trio.
    async def trio_main(coro):
        with cancel_scope:
            return await coro

    # Called (via WaitTaskRescheduled.add_done_callback) with the Task's
    # wakeup callback and context once it parks on the object yielded below.
    def add_done_callback(callback, context):
        task = asyncio.current_task()
        loop = task.get_loop()
        # Deferred with call_soon, presumably so the Task finishes parking
        # before trio starts resuming its coroutine.
        loop.call_soon(
            functools.partial(
                trio.lowlevel.start_guest_run,
                functools.partial(
                    trio_main, WrapCoro(task.get_coro(), context=NullContext())
                ),
                run_sync_soon_not_threadsafe=loop.call_soon,
                run_sync_soon_threadsafe=loop.call_soon_threadsafe,
                # When the guest run ends, wake the asyncio Task in its context.
                done_callback=functools.partial(
                    done_callback,
                    call_soon=loop.call_soon,
                    callback=callback,
                    context=context,
                ),
            )
        )

    # suspend the current task so we can use its coro
    await _async_yield(
        WaitTaskRescheduled(
            add_done_callback=add_done_callback, abort_func=cancel_scope.cancel
        )
    )
    try:
        yield
    finally:
        # tell our WrapCoro that trio is done; execution continues past this
        # await only once asyncio reschedules the Task.
        await _async_yield(TRIO_DONE)
async def demo(client):
    """Fetch https://google.com with *client* and print the response object."""
    print(await client.get("https://google.com"))
async def main():
    """Demo: switch from asyncio into trio and back on the same event loop.

    Steps:
    1. Schedule ``task.cancel()`` for 1 second later and print the Task's
       context.
    2. Inside ``as_trio()``, run two ``demo`` requests in a trio nursery
       alongside a 10-second trio sleep. The sleep is expected to be cut
       short by that cancel.
    3. Print the context again and repeat the requests with plain asyncio.
    """
    task = asyncio.current_task()
    # Cancel this Task after 1s. While inside as_trio() the request is
    # forwarded to trio's cancel scope (WaitTaskRescheduled -> abort_func).
    task.get_loop().call_later(1, task.cancel)
    print(await get_context())
    try:
        async with as_trio():
            # Presumably reports "trio" here.
            print(sniffio.current_async_library())
            async with httpx.AsyncClient() as client:
                async with trio.open_nursery() as nursery:
                    nursery.start_soon(demo, client)
                    nursery.start_soon(demo, client)
                    await trio.sleep(10)
    except trio.Cancelled:
        # The 1-second cancel is expected to surface here as trio.Cancelled
        # once as_trio() has handed control back to asyncio.
        print("cancelled")
    print(await get_context())
    # Back on plain asyncio after the as_trio() block.
    print(sniffio.current_async_library())
    async with httpx.AsyncClient() as client:
        await asyncio.gather(demo(client), demo(client))
if __name__ == "__main__":
    # asyncio owns the event loop; trio only runs as a guest inside main().
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment