Skip to content

Instantly share code, notes, and snippets.

@graingert
Last active July 31, 2022 09:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save graingert/d20fdaa41511c4cccb756259ee477444 to your computer and use it in GitHub Desktop.
Save graingert/d20fdaa41511c4cccb756259ee477444 to your computer and use it in GitHub Desktop.
import asyncio
import collections.abc
import contextlib
import functools
import types
import httpx
import sniffio
import trio
# Sentinel yielded up the coroutine chain when an ``as_trio()`` block exits.
# WrapCoro turns it into StopIteration, so trio's ``await`` on the wrapped
# asyncio coroutine completes instead of receiving the sentinel.
TRIO_DONE = object()
class DumbFuture:
    """Minimal future-like object that an asyncio Task can block on.

    asyncio's Task treats any yielded object with a truthy
    ``_asyncio_future_blocking`` attribute as a future to wait on. This class
    implements only the hooks that path needs: ``get_loop``,
    ``add_done_callback`` and ``cancel``. It forwards the last two to the
    callables given at construction.

    Args:
        add_done_callback: called as ``add_done_callback(fn, context)`` the
            first (and only permitted) time a task registers its wakeup.
        on_cancel: zero-argument callable run the first time ``cancel`` is
            called; later calls do not run it again.
    """

    # Marks instances as "futures" for asyncio.Task's duck-typing check.
    _asyncio_future_blocking = True
    # Class-level defaults; instances overwrite them in __init__ and reset
    # them to None once consumed.
    _add_done_callback = None
    _on_cancel = None

    def __init__(self, add_done_callback, on_cancel):
        self._add_done_callback = add_done_callback
        self._on_cancel = on_cancel

    def cancel(self, *args, **kwargs):
        """Request cancellation and always report that it was accepted.

        Cancellation is level-triggered: ``on_cancel`` runs only on the first
        call. Repeated calls (asyncio allows ``Task.cancel()`` to be called
        again) are no-ops that still return True.
        """
        # Swap the callable out instead of ``del``-ing it. With ``del`` only
        # the class-level None remained, so a second cancel() raised
        # AttributeError. Resetting to None still breaks the reference cycle.
        on_cancel, self._on_cancel = self._on_cancel, None
        if on_cancel is not None:
            on_cancel()
        # asyncio.Task.cancel calls:
        #
        #     if self._fut_waiter is not None:
        #         if self._fut_waiter.cancel(msg=msg):
        #             # Leave self._fut_waiter; it may be a Task that
        #             # catches and ignores the cancellation so we may have
        #             # to cancel it again later.
        #             return True
        #     # It must be the case that self.__step is already scheduled.
        #     self._must_cancel = True
        #     self._cancel_message = msg
        #
        # fut_waiter (that's us) needs to return True. Otherwise
        # task._must_cancel is set to True, which means that when we wake up
        # the task it will call coro.throw(CancelledError)!
        return True

    def get_loop(self):
        """Return the running event loop (RuntimeError if none is running)."""
        return asyncio.get_running_loop()

    def add_done_callback(self, fn, *, context=None):
        """Forward a task's wakeup callback to the constructor-supplied hook.

        Raises:
            AssertionError: if called more than once. Only one task may wait
                on a given DumbFuture.
        """
        # Swap out instead of ``del``: with ``del``, a second call raised
        # AttributeError before the None check below could ever be reached.
        register, self._add_done_callback = self._add_done_callback, None
        if register is None:
            raise AssertionError("only one task can listen to a Future at a time")
        register(fn, context)
@types.coroutine
def _async_yield(v):
return (yield v)
@collections.abc.Coroutine.register
class WrapCoro:
    """Coroutine-protocol adapter that steps another coroutine.

    Every ``send``/``throw`` is forwarded to the wrapped coroutine through
    ``context.run``. If the wrapped coroutine yields the ``TRIO_DONE``
    sentinel, the adapter raises ``StopIteration`` instead of passing the
    sentinel on, which completes any ``await`` on this adapter with ``None``.
    """

    def __init__(self, coro, context):
        self._coro = coro
        self._context = context

    def _advance(self, method, *args):
        # Run one step of the wrapped coroutine in the configured context and
        # translate the TRIO_DONE sentinel into end-of-iteration.
        yielded = self._context.run(method, *args)
        if yielded is not TRIO_DONE:
            return yielded
        raise StopIteration

    def __await__(self):
        return self

    def __iter__(self):
        return self

    def __next__(self):
        return self.send(None)

    def throw(self, *exc_info):
        return self._advance(self._coro.throw, *exc_info)

    def send(self, v):
        return self._advance(self._coro.send, v)
class NullFuture:
    """Already-completed future stand-in whose result is always ``None``.

    Passed to an asyncio Task's wakeup callback. ``result()`` never raises,
    so the task simply resumes.
    """

    def result(self):
        # Nothing to report back to the woken task.
        return
class NullContext:
    """Stand-in for ``contextvars.Context`` whose ``run`` does not switch context.

    sniffio stores the current async library on the context and not in a
    thread-local. Calling straight through here means a step runs in
    whatever context the caller is already in, not a separately captured one.
    """

    def run(self, fn, /, *args, **kwargs):
        # No context switch: invoke in the caller's own context.
        outcome = fn(*args, **kwargs)
        return outcome
def done_callback(outcome, call_soon, callback, context):
    """Trio guest-run completion hook: schedule the asyncio wakeup callback.

    Args:
        outcome: result of the guest run's main function; discarded.
        call_soon: scheduler used to queue *callback* (e.g. ``loop.call_soon``).
        callback: the asyncio task's wakeup callback.
        context: passed through as ``call_soon``'s ``context=`` argument.
    """
    # Discarded: the original author notes it can only be None.
    # NOTE(review): an error outcome would also be dropped silently here.
    del outcome
    finished = NullFuture()
    call_soon(callback, finished, context=context)
@contextlib.asynccontextmanager
async def as_trio():
    """Run the body of ``async with as_trio():`` as trio code on the asyncio loop.

    Mechanism:

    * The current asyncio task is suspended on a ``DumbFuture``.
    * When the task registers its wakeup callback on that future, a trio guest
      run is scheduled on the task's loop. It uses
      ``trio.lowlevel.start_guest_run`` with ``loop.call_soon`` /
      ``loop.call_soon_threadsafe`` as the scheduling hooks.
    * The guest run's main function awaits the task's own coroutine, wrapped
      in ``WrapCoro``, so the ``async with`` body executes under trio.
    * On exit, the ``finally`` clause yields ``TRIO_DONE``. This ends trio's
      await on that coroutine, and trio's ``done_callback`` then queues the
      task's wakeup, so the code after the block resumes under asyncio.

    Cancelling the asyncio task while it is suspended calls the DumbFuture's
    ``cancel``, which is wired to cancel the trio cancel scope around the guest
    run's main function.
    """
    cancel_scope = trio.CancelScope()

    # Trio-side entry point: await the hijacked asyncio coroutine inside the
    # cancel scope that asyncio-side cancellation targets.
    async def trio_main(coro):
        with cancel_scope:
            return await coro

    # Invoked via DumbFuture.add_done_callback with the asyncio task's wakeup
    # callback and its context.
    def add_done_callback(callback, context):
        task = asyncio.current_task()
        loop = task.get_loop()
        # The guest run is queued with call_soon rather than started inline.
        # This is presumably so trio does not start driving the task's
        # coroutine while the asyncio Task step that called us is still on
        # the stack. NOTE(review): confirm.
        loop.call_soon(
            functools.partial(
                trio.lowlevel.start_guest_run,
                functools.partial(
                    trio_main, WrapCoro(task.get_coro(), context=NullContext())
                ),
                run_sync_soon_not_threadsafe=loop.call_soon,
                run_sync_soon_threadsafe=loop.call_soon_threadsafe,
                # When trio finishes, queue the task's wakeup so it resumes
                # its coroutine under asyncio.
                done_callback=functools.partial(
                    done_callback,
                    call_soon=loop.call_soon,
                    callback=callback,
                    context=context,
                ),
            )
        )

    # suspend the current task so we can use its coro
    await _async_yield(
        DumbFuture(
            add_done_callback=add_done_callback, on_cancel=cancel_scope.cancel
        )
    )
    try:
        yield
    finally:
        # tell our WrapCoro that trio is done (it turns this sentinel into
        # StopIteration, completing trio_main's await)
        await _async_yield(TRIO_DONE)
async def demo(client):
    """Fetch https://google.com with *client* and print the response object."""
    print(await client.get("https://google.com"))
async def main():
    """Demo: run trio code (httpx under trio) inside an asyncio task, then asyncio.

    A timer cancels this asyncio task after one second. The trio section is
    expected to surface that as ``trio.Cancelled``, which is caught and
    reported. The same httpx requests then run under plain asyncio.
    """
    task = asyncio.current_task()
    # Keep the TimerHandle so the timer can be disarmed. The cancellation is
    # meant for the trio section only; left armed, it would cancel the
    # asyncio section below if the trio section ever ended in under a second.
    cancel_timer = task.get_loop().call_later(1, task.cancel)
    try:
        async with as_trio():
            print(sniffio.current_async_library())
            async with httpx.AsyncClient() as client:
                async with trio.open_nursery() as nursery:
                    nursery.start_soon(demo, client)
                    nursery.start_soon(demo, client)
                    await trio.sleep(10)
    except trio.Cancelled:
        print("cancelled")
    finally:
        # No-op if the timer has already fired.
        cancel_timer.cancel()
    print(sniffio.current_async_library())
    async with httpx.AsyncClient() as client:
        await asyncio.gather(demo(client), demo(client))
# Script entry point. There is no ``__main__`` guard, so importing this
# module also runs the demo.
asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment