-
-
Save richardsheridan/8f803a5a15831840846083a6cbcefbea to your computer and use it in GitHub Desktop.
Universal cross-thread unbuffered queue for trio, asyncio, and threads
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Rough draft of a Queue object that can be used simultaneously from
# sync threads and *multiple* trio and asyncio threads, all at once.
#
# If you don't have multiple threads each making their own separate call to
# Xio.run (i.e. trio.run or asyncio.run), then don't use this; there are
# simpler solutions. This was mostly an exercise to figure out whether and
# how this could be done.
#
# Currently, the test provides 94% coverage given a sufficient timeout. The
# remaining uncovered lines are (apparently rare) races and the durable
# asyncio shielding.
import random | |
import threading | |
import time | |
from collections import OrderedDict | |
from functools import partial | |
import queue | |
class CrossThreadUnbufferedFIFOQueue:
    """Unbuffered FIFO hand-off queue for sync threads, trio and asyncio.

    The queue never stores values. Whichever side (putter or getter) arrives
    first registers a wake-up callback; the opposite side later pops the
    oldest registered callback and calls it, which both transfers the value
    and wakes the waiting party.

    Every blocking flavour (``*_sync``, ``*_trio``, ``*_aio``) follows the
    same steps:

    1. try the non-blocking operation;
    2. take ``self._lock``, try again, and register a callback if that
       still fails;
    3. wait to be woken, unregistering the callback on timeout/cancellation
       unless a peer has already claimed it.

    NOTE(review): the pops and deletes done outside ``self._lock`` presumably
    rely on single ``OrderedDict`` operations being atomic under the GIL --
    confirm before running on a free-threaded interpreter.
    """

    def __init__(self):
        # Locks critical sections where getters and putters could block
        # simultaneously: "re-check the other side, then register" is done
        # as one step while holding this lock.
        self._lock = threading.Lock()
        # Used as FIFO queues of wake-up callbacks; value is always None.
        self._putters = OrderedDict()
        self._getters = OrderedDict()

    def put_nowait(self, value):
        """Hand ``value`` to the oldest waiting getter without blocking.

        Raises:
            queue.Full: if no getter is currently registered.
        """
        try:
            getter, _ = self._getters.popitem(last=False)
        except KeyError:
            raise queue.Full from None
        getter(value)

    def get_nowait(self):
        """Take a value from the oldest waiting putter without blocking.

        Returns:
            The value supplied by that putter.

        Raises:
            queue.Empty: if no putter is currently registered.
        """
        try:
            putter, _ = self._putters.popitem(last=False)
        except KeyError:
            raise queue.Empty from None
        return putter()

    def put_sync(self, value, timeout=None):
        """Blocking put for plain (non-async) threads.

        Args:
            value: object to hand to a getter.
            timeout: seconds to wait for a getter; ``None`` waits forever.

        Raises:
            queue.Full: if no getter took the value within ``timeout``.
        """
        try:
            self.put_nowait(value)
        except queue.Full:
            pass
        else:
            return

        # A lock that starts out held works as a one-shot wake-up signal.
        waker = threading.Lock()
        waker.acquire()

        def putter():
            waker.release()
            return value

        with self._lock:
            try:
                self.put_nowait(value)
            except queue.Full:
                pass
            else:
                return
            self._putters[putter] = None

        if waker.acquire(timeout=-1 if timeout is None else timeout):
            return

        # Timed out: withdraw the offer, unless a getter already popped it.
        try:
            del self._putters[putter]
        except KeyError:
            # A getter claimed the callback, so the value is (being) delivered.
            return
        else:
            raise queue.Full

    def get_sync(self, timeout=None):
        """Blocking get for plain (non-async) threads.

        Args:
            timeout: seconds to wait for a putter; ``None`` waits forever.

        Returns:
            The value handed over by a putter.

        Raises:
            queue.Empty: if no putter showed up within ``timeout``.
        """
        try:
            return self.get_nowait()
        except queue.Empty:
            pass

        # A lock that starts out held works as a one-shot wake-up signal.
        waker = threading.Lock()
        waker.acquire()
        value_shared = None

        def getter(value):
            nonlocal value_shared
            # Store the value before releasing so the woken thread sees it.
            value_shared = value
            waker.release()

        with self._lock:
            try:
                return self.get_nowait()
            except queue.Empty:
                pass
            self._getters[getter] = None

        if waker.acquire(timeout=-1 if timeout is None else timeout):
            return value_shared

        # Timed out: withdraw the request, unless a putter already popped it.
        try:
            del self._getters[getter]
        except KeyError:
            # A putter popped the callback and is about to call it (or just
            # did); wait for the value to land.
            waker.acquire()  # only briefly blocking
            return value_shared
        else:
            raise queue.Empty

    async def put(self, value):
        """Async put that dispatches on the running async library.

        Raises:
            RuntimeError: if sniffio reports a library other than trio or
                asyncio.
        """
        from sniffio import current_async_library

        async_library = current_async_library()
        if async_library == "trio":
            await self.put_trio(value)
        elif async_library == "asyncio":
            await self.put_aio(value)
        else:
            # One formatted string: passing the name as a second positional
            # argument would make the message render as a tuple repr.
            raise RuntimeError(f"Unsupported async library: {async_library}")

    async def get(self):
        """Async get that dispatches on the running async library.

        Returns:
            The value handed over by a putter.

        Raises:
            RuntimeError: if sniffio reports a library other than trio or
                asyncio.
        """
        from sniffio import current_async_library

        async_library = current_async_library()
        if async_library == "trio":
            return await self.get_trio()
        elif async_library == "asyncio":
            return await self.get_aio()
        else:
            # One formatted string: passing the name as a second positional
            # argument would make the message render as a tuple repr.
            raise RuntimeError(f"Unsupported async library: {async_library}")

    async def put_trio(self, value):
        """Put ``value`` from a trio task, parking the task until it is taken.

        If the task is cancelled while its callback is still registered, the
        callback is removed and the cancellation succeeds; once a getter has
        popped the callback the abort is refused and the put completes.
        """
        try:
            self.put_nowait(value)
        except queue.Full:
            pass
        else:
            return

        import trio

        token = trio.lowlevel.current_trio_token()
        task = trio.lowlevel.current_task()

        def putter():
            # Called by the getter; run_sync_soon hands the reschedule over
            # to the trio run that owns ``task``.
            token.run_sync_soon(trio.lowlevel.reschedule, task)
            return value

        with self._lock:
            try:
                self.put_nowait(value)
            except queue.Full:
                pass
            else:
                return
            self._putters[putter] = None

        def abort_fn(_):
            try:
                del self._putters[putter]
            except KeyError:
                # A getter already claimed the callback; a reschedule is
                # coming, so the cancellation cannot take effect here.
                return trio.lowlevel.Abort.FAILED
            else:
                return trio.lowlevel.Abort.SUCCEEDED

        await trio.lowlevel.wait_task_rescheduled(abort_fn)

    async def get_trio(self):
        """Get a value from a trio task, parking the task until one arrives.

        Cancellation follows the same rule as ``put_trio``: it succeeds only
        while this task's callback is still registered.

        Returns:
            The value handed over by a putter.
        """
        try:
            return self.get_nowait()
        except queue.Empty:
            pass

        import trio

        token = trio.lowlevel.current_trio_token()
        task = trio.lowlevel.current_task()

        def getter(value):
            import outcome

            # The value travels as the outcome the task is rescheduled with,
            # so it becomes the result of wait_task_rescheduled below.
            o = outcome.Value(value)
            token.run_sync_soon(trio.lowlevel.reschedule, task, o)

        with self._lock:
            try:
                return self.get_nowait()
            except queue.Empty:
                pass
            self._getters[getter] = None

        def abort_fn(_):
            try:
                del self._getters[getter]
            except KeyError:
                # A putter already claimed the callback; the value is on its
                # way, so refuse the cancellation.
                return trio.lowlevel.Abort.FAILED
            else:
                return trio.lowlevel.Abort.SUCCEEDED

        return await trio.lowlevel.wait_task_rescheduled(abort_fn)

    async def put_aio(self, value):
        """Put ``value`` from an asyncio task, waiting until it is taken.

        On cancellation, the callback is unregistered and ``CancelledError``
        propagates if no getter has claimed it yet. If a getter already
        popped the callback, cancellation is ignored and the call waits for
        the hand-off to finish, then returns normally.
        """
        try:
            self.put_nowait(value)
        except queue.Full:
            pass
        else:
            return

        import asyncio

        loop = asyncio.get_running_loop()
        fut = loop.create_future()

        def putter():
            # Called by the getter; call_soon_threadsafe passes the wake-up
            # to this task's event loop.
            loop.call_soon_threadsafe(fut.set_result, None)
            return value

        with self._lock:
            try:
                self.put_nowait(value)
            except queue.Full:
                pass
            else:
                return
            self._putters[putter] = None

        # shield() keeps ``fut`` itself alive when this task is cancelled.
        try:
            return await asyncio.shield(fut)
        except asyncio.CancelledError:
            try:
                del self._putters[putter]
            except KeyError:
                pass
            else:
                raise

        # durable shielded wait: https://stackoverflow.com/a/39692502/4504950
        while True:
            try:
                return await asyncio.shield(fut)
            except asyncio.CancelledError:
                continue

    async def get_aio(self):
        """Get a value from an asyncio task, waiting until one arrives.

        On cancellation, the callback is unregistered and ``CancelledError``
        propagates if no putter has claimed it yet. If a putter already
        popped the callback, cancellation is ignored and the call waits for
        the value, then returns it.

        Returns:
            The value handed over by a putter.
        """
        try:
            return self.get_nowait()
        except queue.Empty:
            pass

        import asyncio

        loop = asyncio.get_running_loop()
        fut = loop.create_future()

        def getter(value):
            # Called by the putter; call_soon_threadsafe passes the value to
            # this task's event loop.
            loop.call_soon_threadsafe(fut.set_result, value)

        with self._lock:
            try:
                return self.get_nowait()
            except queue.Empty:
                pass
            self._getters[getter] = None

        # shield() keeps ``fut`` itself alive when this task is cancelled.
        try:
            return await asyncio.shield(fut)
        except asyncio.CancelledError:
            try:
                del self._getters[getter]
            except KeyError:
                pass
            else:
                raise

        # durable shielded wait: https://stackoverflow.com/a/39692502/4504950
        while True:
            try:
                return await asyncio.shield(fut)
            except asyncio.CancelledError:
                continue
def put_thread(q):
    """Demo producer for a plain thread: sleep, then offer a value, forever.

    Each round draws a random value in [0, MAX_VAL), sleeps that long, and
    tries to put it with a freshly drawn random timeout, printing the result.
    """
    while True:
        payload = random.random() * MAX_VAL
        time.sleep(payload)
        print("thread putting", payload)
        patience = random.random() * MAX_VAL
        try:
            q.put_sync(payload, timeout=patience)
        except queue.Full:
            print("thread didn't put", payload)
            continue
        print("thread put", payload)
def get_thread(q):
    """Demo consumer for a plain thread: fetch a value, then sleep, forever.

    Each round tries to get with a random timeout. On success it sleeps for
    the received value; on timeout it sleeps for a fresh random duration.
    """
    while True:
        print("thread getting")
        patience = random.random() * MAX_VAL
        try:
            received = q.get_sync(timeout=patience)
        except queue.Empty:
            print("thread didn't get")
            time.sleep(random.random() * MAX_VAL)
            continue
        print("thread got", received)
        time.sleep(received)
async def put_trio(q):
    """Demo producer for a trio task: sleep, then offer a value, forever.

    The put is wrapped in ``trio.move_on_after`` with a random deadline, and
    the outcome (put or timed out) is printed each round.
    """
    import trio

    while True:
        payload = random.random() * MAX_VAL
        await trio.sleep(payload)
        print("trio putting", payload)
        patience = random.random() * MAX_VAL
        with trio.move_on_after(patience) as scope:
            await q.put(payload)
        outcome_text = "trio didn't put" if scope.cancelled_caught else "trio put"
        print(outcome_text, payload)
async def get_trio(q):
    """Demo consumer for a trio task: fetch a value, then sleep, forever.

    The get is wrapped in ``trio.move_on_after`` with a random deadline. On
    success the task sleeps for the received value; on timeout it sleeps for
    a fresh random duration.
    """
    import trio

    while True:
        print("trio getting")
        patience = random.random() * MAX_VAL
        with trio.move_on_after(patience) as scope:
            received = await q.get()
        if scope.cancelled_caught:
            print("trio didn't get")
            await trio.sleep(random.random() * MAX_VAL)
            continue
        print("trio got", received)
        await trio.sleep(received)
async def put_aio(q):
    """Demo producer for an asyncio task: sleep, then offer a value, forever.

    The put is bounded by ``asyncio.wait_for`` with a random timeout, and the
    outcome (put or timed out) is printed each round.
    """
    import asyncio

    while True:
        payload = random.random() * MAX_VAL
        await asyncio.sleep(payload)
        print("aio putting", payload)
        # Create the coroutine before drawing the timeout, matching the
        # left-to-right argument evaluation of a single wait_for(...) call.
        pending_put = q.put(payload)
        patience = random.random() * MAX_VAL
        try:
            await asyncio.wait_for(pending_put, patience)
        except asyncio.TimeoutError:
            print("aio didn't put", payload)
            continue
        print("aio put", payload)
async def get_aio(q):
    """Demo consumer for an asyncio task: fetch a value, then sleep, forever.

    The get is bounded by ``asyncio.wait_for`` with a random timeout. On
    success the task sleeps for the received value; on timeout it sleeps for
    a fresh random duration.
    """
    import asyncio

    while True:
        print("aio getting")
        # Create the coroutine before drawing the timeout, matching the
        # left-to-right argument evaluation of a single wait_for(...) call.
        pending_get = q.get()
        patience = random.random() * MAX_VAL
        try:
            received = await asyncio.wait_for(pending_get, patience)
        except asyncio.TimeoutError:
            print("aio didn't get")
            await asyncio.sleep(random.random() * MAX_VAL)
            continue
        print("aio got", received)
        await asyncio.sleep(received)
async def test_CrossThreadUnbufferedFIFOQueue():
    """Stress one queue from every supported kind of caller at once.

    Starts, against a single shared queue: a sync producer and consumer in
    worker threads, a trio producer and consumer in this run, a second trio
    run and an asyncio run for each role in further worker threads. The
    nursery's cancel scope deadline is then set OVERALL_TIMEOUT seconds ahead
    so the whole test is cancelled when it expires.
    """
    import asyncio
    import trio

    q = CrossThreadUnbufferedFIFOQueue()

    def in_worker_thread(sync_fn, *args):
        # Build the nursery task that runs ``sync_fn(*args)`` in a worker
        # thread via trio.to_thread.run_sync with cancellable=True.
        return partial(trio.to_thread.run_sync, sync_fn, *args, cancellable=True)

    async with trio.open_nursery() as nursery:
        # Plain blocking threads.
        nursery.start_soon(in_worker_thread(put_thread, q))
        nursery.start_soon(in_worker_thread(get_thread, q))
        # Tasks inside this trio run.
        nursery.start_soon(put_trio, q)
        nursery.start_soon(get_trio, q)
        # A separate trio run per worker thread.
        nursery.start_soon(in_worker_thread(trio.run, put_trio, q))
        nursery.start_soon(in_worker_thread(trio.run, get_trio, q))
        # A separate asyncio run per worker thread.
        nursery.start_soon(in_worker_thread(asyncio.run, put_aio(q)))
        nursery.start_soon(in_worker_thread(asyncio.run, get_aio(q)))
        nursery.cancel_scope.deadline = trio.current_time() + OVERALL_TIMEOUT
# Upper bound for the random values the demo workers exchange; the same
# values are used as sleep durations and timeouts.
MAX_VAL = 0.03
# Offset added to trio.current_time() to form the test nursery's deadline.
OVERALL_TIMEOUT = 5.0

if __name__ == "__main__":
    import trio

    trio.run(test_CrossThreadUnbufferedFIFOQueue)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment