Skip to content

Instantly share code, notes, and snippets.

@richardsheridan
Forked from njsmith/universal-trio-queue.py
Last active November 7, 2022 04:09
Show Gist options
  • Save richardsheridan/8f803a5a15831840846083a6cbcefbea to your computer and use it in GitHub Desktop.
Save richardsheridan/8f803a5a15831840846083a6cbcefbea to your computer and use it in GitHub Desktop.
Universal cross-thread unbuffered queue for trio, asyncio, and threads
# Rough draft of a Queue object that can be used simultaneously from
# sync threads + *multiple* trio and asyncio threads, all at once.
#
# If you don't have multiple threads each doing their own separate calls to Xio.run,
# then don't use this; there are simpler solutions. This was mostly an exercise to
# figure out if and how this could be done.
#
# Currently, given a sufficient timeout, the test provides 94% coverage. The
# uncovered remainder is the (apparently rare) race-handling branches and the
# durable asyncio shielding loops.
import random
import threading
import time
from collections import OrderedDict
from functools import partial
import queue
class CrossThreadUnbufferedFIFOQueue:
    """Unbuffered (rendezvous) FIFO queue shared by threads, trio and asyncio.

    The queue never stores values: a put completes only when it is matched
    with a get, and vice versa. Waiting putters and getters are kept in arrival
    order and served first-come, first-served. Blocking threads
    (``put_sync``/``get_sync``), trio tasks and asyncio tasks -- each possibly
    inside its own event loop in its own thread -- may all use one instance
    at the same time.

    The non-blocking operations raise ``queue.Full`` / ``queue.Empty`` (as the
    stdlib ``queue`` module does) when no counterpart is currently waiting.
    """

    def __init__(self):
        # Locks critical sections where getters and putters could block
        # simultaneously: "try to match, else register as a waiter" must be
        # atomic, or a putter and a getter could both decide to wait.
        self._lock = threading.Lock()
        # Waiting callbacks used as FIFO queues (insertion order); the dict
        # value is always None. A putter callback takes no arguments and
        # returns the value; a getter callback receives the value.
        self._putters = OrderedDict()
        self._getters = OrderedDict()

    def put_nowait(self, value):
        """Hand *value* to the longest-waiting getter without blocking.

        Raises:
            queue.Full: if no getter is currently waiting.
        """
        try:
            getter, _ = self._getters.popitem(last=False)
        except KeyError:
            raise queue.Full from None
        # The getter callback wakes its waiting thread/task with the value.
        getter(value)

    def get_nowait(self):
        """Take the value of the longest-waiting putter without blocking.

        Raises:
            queue.Empty: if no putter is currently waiting.
        """
        try:
            putter, _ = self._putters.popitem(last=False)
        except KeyError:
            raise queue.Empty from None
        # The putter callback wakes its waiting thread/task and returns the value.
        return putter()

    @staticmethod
    def _lock_timeout(timeout):
        """Convert a queue-style *timeout* into a ``Lock.acquire`` timeout.

        ``None`` means "wait forever", which ``Lock.acquire`` spells ``-1``.
        Negative values are rejected (as ``queue.Queue`` does) so that the
        error is raised before any waiter has been registered.

        Raises:
            ValueError: if *timeout* is negative.
        """
        if timeout is None:
            return -1
        if timeout < 0:
            raise ValueError("'timeout' must be a non-negative number")
        return timeout

    def put_sync(self, value, timeout=None):
        """Block the calling thread until a getter takes *value*.

        Args:
            value: the object to hand over.
            timeout: maximum seconds to wait, or None to wait indefinitely.

        Raises:
            queue.Full: if no getter took the value within *timeout*.
            ValueError: if *timeout* is negative and no getter is waiting.
        """
        try:
            self.put_nowait(value)
        except queue.Full:
            pass
        else:
            return
        waker = threading.Lock()
        waker.acquire()

        def putter():
            # Called by the matching getter (from any thread): wake us up.
            waker.release()
            return value

        with self._lock:
            # Re-check under the lock so a getter that registered after the
            # unlocked fast path above cannot be missed.
            try:
                self.put_nowait(value)
            except queue.Full:
                pass
            else:
                return
            lock_timeout = self._lock_timeout(timeout)
            self._putters[putter] = None
        try:
            woken = waker.acquire(timeout=lock_timeout)
        except BaseException:
            # Interrupted while blocked (e.g. KeyboardInterrupt): withdraw the
            # offer so a later getter cannot take a value from a put that is
            # reporting failure. If a getter already popped us, the value is
            # delivered even though we re-raise.
            self._putters.pop(putter, None)
            raise
        if woken:
            return
        try:
            del self._putters[putter]
        except KeyError:
            # A getter popped us just as we timed out; its putter() call
            # returns the value, so the put succeeded after all.
            return
        raise queue.Full

    def get_sync(self, timeout=None):
        """Block the calling thread until a putter supplies a value.

        Args:
            timeout: maximum seconds to wait, or None to wait indefinitely.

        Returns:
            The value handed over by the matching putter.

        Raises:
            queue.Empty: if no value arrived within *timeout*.
            ValueError: if *timeout* is negative and no putter is waiting.
        """
        try:
            return self.get_nowait()
        except queue.Empty:
            pass
        waker = threading.Lock()
        waker.acquire()
        value_shared = None

        def getter(value):
            # Called by the matching putter (from any thread): store the value
            # first, then wake us up.
            nonlocal value_shared
            value_shared = value
            waker.release()

        with self._lock:
            # Re-check under the lock (see put_sync).
            try:
                return self.get_nowait()
            except queue.Empty:
                pass
            lock_timeout = self._lock_timeout(timeout)
            self._getters[getter] = None
        try:
            woken = waker.acquire(timeout=lock_timeout)
        except BaseException:
            # Interrupted while blocked: withdraw so a later putter does not
            # hand its value to a getter nobody is waiting on.
            self._getters.pop(getter, None)
            raise
        if woken:
            return value_shared
        try:
            del self._getters[getter]
        except KeyError:
            # A putter popped us just as we timed out and is about to call
            # getter(); wait for the value instead of dropping it.
            waker.acquire()  # only briefly blocking
            return value_shared
        raise queue.Empty

    async def put(self, value):
        """Put *value* using whichever async library (trio/asyncio) is running.

        Raises:
            RuntimeError: if the running async library is not supported.
        """
        from sniffio import current_async_library

        async_library = current_async_library()
        if async_library == "trio":
            await self.put_trio(value)
        elif async_library == "asyncio":
            await self.put_aio(value)
        else:
            raise RuntimeError(f"Unsupported async library: {async_library}")

    async def get(self):
        """Get a value using whichever async library (trio/asyncio) is running.

        Raises:
            RuntimeError: if the running async library is not supported.
        """
        from sniffio import current_async_library

        async_library = current_async_library()
        if async_library == "trio":
            return await self.get_trio()
        elif async_library == "asyncio":
            return await self.get_aio()
        else:
            raise RuntimeError(f"Unsupported async library: {async_library}")

    async def put_trio(self, value):
        """Trio task version of put_sync: wait until a getter takes *value*.

        Cancellation succeeds only while no getter has popped our putter; after
        that the abort fails and the task keeps waiting for the handoff.
        """
        try:
            self.put_nowait(value)
        except queue.Full:
            pass
        else:
            return
        import trio

        token = trio.lowlevel.current_trio_token()
        task = trio.lowlevel.current_task()

        def putter():
            # May run in another thread: reschedule our task via its run's token.
            token.run_sync_soon(trio.lowlevel.reschedule, task)
            return value

        with self._lock:
            try:
                self.put_nowait(value)
            except queue.Full:
                pass
            else:
                return
            self._putters[putter] = None

        def abort_fn(_):
            try:
                del self._putters[putter]
            except KeyError:
                # Already popped by a getter; putter() will reschedule us.
                return trio.lowlevel.Abort.FAILED
            return trio.lowlevel.Abort.SUCCEEDED

        await trio.lowlevel.wait_task_rescheduled(abort_fn)

    async def get_trio(self):
        """Trio task version of get_sync: wait until a putter supplies a value.

        Cancellation succeeds only while no putter has popped our getter; after
        that the abort fails and the task waits to receive the value.
        """
        try:
            return self.get_nowait()
        except queue.Empty:
            pass
        import outcome
        import trio

        token = trio.lowlevel.current_trio_token()
        task = trio.lowlevel.current_task()

        def getter(value):
            # May run in another thread: resume our task with the value.
            token.run_sync_soon(trio.lowlevel.reschedule, task, outcome.Value(value))

        with self._lock:
            try:
                return self.get_nowait()
            except queue.Empty:
                pass
            self._getters[getter] = None

        def abort_fn(_):
            try:
                del self._getters[getter]
            except KeyError:
                # Already popped by a putter; getter() will reschedule us.
                return trio.lowlevel.Abort.FAILED
            return trio.lowlevel.Abort.SUCCEEDED

        return await trio.lowlevel.wait_task_rescheduled(abort_fn)

    async def put_aio(self, value):
        """asyncio task version of put_sync: wait until a getter takes *value*.

        Cancellation is honoured only while no getter has popped our putter;
        after that, further cancellations are absorbed until the handoff
        completes.
        """
        try:
            self.put_nowait(value)
        except queue.Full:
            pass
        else:
            return
        import asyncio

        loop = asyncio.get_running_loop()
        fut = loop.create_future()

        def putter():
            # May run in another thread: resolve fut on our own loop.
            loop.call_soon_threadsafe(fut.set_result, None)
            return value

        with self._lock:
            try:
                self.put_nowait(value)
            except queue.Full:
                pass
            else:
                return
            self._putters[putter] = None
        try:
            # shield() keeps fut itself from being cancelled along with us.
            return await asyncio.shield(fut)
        except asyncio.CancelledError:
            try:
                del self._putters[putter]
            except KeyError:
                # A getter already took the value; wait for the handoff below.
                pass
            else:
                # Withdrawn before any getter saw us: let the cancellation through.
                raise
        # durable shielded wait: https://stackoverflow.com/a/39692502/4504950
        while True:
            try:
                return await asyncio.shield(fut)
            except asyncio.CancelledError:
                continue

    async def get_aio(self):
        """asyncio task version of get_sync: wait until a putter supplies a value.

        Cancellation is honoured only while no putter has popped our getter;
        after that, further cancellations are absorbed until the value arrives,
        so it is never dropped.
        """
        try:
            return self.get_nowait()
        except queue.Empty:
            pass
        import asyncio

        loop = asyncio.get_running_loop()
        fut = loop.create_future()

        def getter(value):
            # May run in another thread: resolve fut on our own loop.
            loop.call_soon_threadsafe(fut.set_result, value)

        with self._lock:
            try:
                return self.get_nowait()
            except queue.Empty:
                pass
            self._getters[getter] = None
        try:
            return await asyncio.shield(fut)
        except asyncio.CancelledError:
            try:
                del self._getters[getter]
            except KeyError:
                # A putter already handed us a value; wait for it below.
                pass
            else:
                raise
        # durable shielded wait: https://stackoverflow.com/a/39692502/4504950
        while True:
            try:
                return await asyncio.shield(fut)
            except asyncio.CancelledError:
                continue
def put_thread(q):
    """Plain-thread producer for the stress test; loops forever.

    Each round draws a random value, sleeps for that many seconds, then offers
    the value via ``q.put_sync`` with a random timeout and reports the result.
    """
    while True:
        item = random.random() * MAX_VAL
        time.sleep(item)
        print("thread putting", item)
        patience = random.random() * MAX_VAL
        try:
            q.put_sync(item, timeout=patience)
        except queue.Full:
            status = "thread didn't put"
        else:
            status = "thread put"
        print(status, item)
def get_thread(q):
    """Plain-thread consumer for the stress test; loops forever.

    After a successful ``q.get_sync`` it sleeps for the received value; after
    a timeout it sleeps for a fresh random interval instead.
    """
    while True:
        print("thread getting")
        patience = random.random() * MAX_VAL
        try:
            received = q.get_sync(timeout=patience)
        except queue.Empty:
            print("thread didn't get")
            pause = random.random() * MAX_VAL
        else:
            print("thread got", received)
            pause = received
        time.sleep(pause)
async def put_trio(q):
    """Trio-task producer for the stress test; loops forever.

    Each round sleeps for a random value, then tries ``q.put`` under a random
    ``move_on_after`` deadline and reports whether the put happened.
    """
    import trio

    while True:
        item = random.random() * MAX_VAL
        await trio.sleep(item)
        print("trio putting", item)
        with trio.move_on_after(random.random() * MAX_VAL) as scope:
            await q.put(item)
        status = "trio didn't put" if scope.cancelled_caught else "trio put"
        print(status, item)
async def get_trio(q):
    """Trio-task consumer for the stress test; loops forever.

    Tries ``q.get`` under a random ``move_on_after`` deadline; sleeps for the
    received value on success, or for a fresh random interval on timeout.
    """
    import trio

    while True:
        print("trio getting")
        with trio.move_on_after(random.random() * MAX_VAL) as scope:
            received = await q.get()
        if scope.cancelled_caught:
            print("trio didn't get")
            pause = random.random() * MAX_VAL
        else:
            print("trio got", received)
            pause = received
        await trio.sleep(pause)
async def put_aio(q):
    """asyncio-task producer for the stress test; loops forever.

    Each round sleeps for a random value, then tries ``q.put`` wrapped in
    ``asyncio.wait_for`` with a random timeout and reports the result.
    """
    import asyncio

    while True:
        item = random.random() * MAX_VAL
        await asyncio.sleep(item)
        print("aio putting", item)
        try:
            await asyncio.wait_for(q.put(item), random.random() * MAX_VAL)
        except asyncio.TimeoutError:
            status = "aio didn't put"
        else:
            status = "aio put"
        print(status, item)
async def get_aio(q):
    """asyncio-task consumer for the stress test; loops forever.

    Tries ``q.get`` wrapped in ``asyncio.wait_for`` with a random timeout;
    sleeps for the received value on success, or a random interval on timeout.
    """
    import asyncio

    while True:
        print("aio getting")
        try:
            received = await asyncio.wait_for(q.get(), random.random() * MAX_VAL)
        except asyncio.TimeoutError:
            print("aio didn't get")
            pause = random.random() * MAX_VAL
        else:
            print("aio got", received)
            pause = received
        await asyncio.sleep(pause)
async def test_CrossThreadUnbufferedFIFOQueue():
    """Stress test: drive one queue from every supported kind of caller at once.

    Starts a producer and a consumer in each of: trio worker threads, trio
    tasks in the current run, a separate ``trio.run`` per worker thread, and a
    separate ``asyncio.run`` per worker thread. The nursery is cancelled after
    OVERALL_TIMEOUT seconds.
    """
    import asyncio
    import trio

    q = CrossThreadUnbufferedFIFOQueue()

    def threaded(fn, *args):
        # Run fn(*args) via trio.to_thread.run_sync with cancellable=True.
        return partial(trio.to_thread.run_sync, fn, *args, cancellable=True)

    async with trio.open_nursery() as nursery:
        jobs = [
            (threaded(put_thread, q),),
            (threaded(get_thread, q),),
            (put_trio, q),
            (get_trio, q),
            (threaded(trio.run, put_trio, q),),
            (threaded(trio.run, get_trio, q),),
            (threaded(asyncio.run, put_aio(q)),),
            (threaded(asyncio.run, get_aio(q)),),
        ]
        for job in jobs:
            nursery.start_soon(*job)
        nursery.cancel_scope.deadline = trio.current_time() + OVERALL_TIMEOUT
# Upper bound, in seconds, for every random sleep and per-operation timeout
# used by the stress-test workers.
MAX_VAL = 0.03
# Total duration, in seconds, before the stress-test nursery is cancelled.
OVERALL_TIMEOUT = 5.0

if __name__ == "__main__":
    import trio

    trio.run(test_CrossThreadUnbufferedFIFOQueue)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment