Last active
November 28, 2018 03:35
-
-
Save njsmith/40b7b7f65e5f433789153c7b668ce643 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import deque, OrderedDict | |
import trio | |
class MultiGetQueue:
    """A bounded FIFO queue whose consumers can wait on several queues at once.

    Producers call ``put``; consumers call ``get`` or the module-level
    ``multi_get``, which waits on a list of queues and returns the first
    value available from any of them.
    """

    def __init__(self, max_size):
        self._max_size = max_size
        # {task: abort func} -- tasks blocked in multi_get() on this queue.
        # The abort func unregisters the task from *every* queue it is
        # waiting on, so handing it a value through one queue removes it
        # everywhere.
        # Probably should make Task._abort_func public, or maybe even change
        # the reschedule code to call the abort func in general?
        self._get_wait = OrderedDict()
        # {task: queued value} -- tasks blocked in put() because the buffer
        # is at capacity.
        self._put_wait = OrderedDict()
        # Invariants:
        #   if len(self._q) < self._max_size, then self._put_wait is empty
        #   if len(self._q) > 0, then self._get_wait is empty
        self._q = deque()

    async def put(self, obj):
        """Put *obj* on the queue.

        If a getter is already waiting, *obj* is handed to the oldest one
        directly. Otherwise *obj* is buffered if there is room. If there is
        no room, this task blocks until a getter takes the value or the put
        is cancelled.
        """
        await trio.hazmat.yield_if_cancelled()
        if self._get_wait:
            assert not self._q
            # Peek instead of popitem(): the waiter's abort func removes the
            # task from every queue it registered with, *including this one*,
            # so the task must still be present here when it runs. The abort
            # func takes trio's raise_cancel argument (unused); pass None.
            task, abort_fn = next(iter(self._get_wait.items()))
            abort_fn(None)
            # The getter resumes with (queue, value) as multi_get's result.
            trio.hazmat.reschedule(task, trio.Value((self, obj)))
            await trio.hazmat.yield_briefly_no_cancel()
        elif len(self._q) < self._max_size:
            self._q.append(obj)
            await trio.hazmat.yield_briefly_no_cancel()
        else:
            task = trio.current_task()
            self._put_wait[task] = obj

            def abort_fn(_):
                # Cancelled before any getter took our value: withdraw it.
                del self._put_wait[task]
                return trio.hazmat.Abort.SUCCEEDED

            await trio.hazmat.yield_indefinitely(abort_fn)

    async def get(self):
        """Remove and return the next value, blocking until one is available."""
        _, value = await multi_get([self])
        return value
async def multi_get(queues):
    """Wait until any of *queues* has a value and return ``(queue, value)``.

    Queues are checked in the order given. If none currently has a value,
    the calling task registers as a getter on every queue. It then sleeps
    until a ``put`` on one of them hands it a value, or until the call is
    cancelled.
    """
    await trio.hazmat.yield_if_cancelled()
    # Materialize once: the queues are iterated several times below, and a
    # one-shot iterator would leave this task registered on no queue.
    queues = list(queues)
    for queue in queues:
        if queue._put_wait:
            # A putter is blocked, so the buffer is at capacity (per the
            # class invariants). Move its value into the buffer and wake it.
            # No need to check max_size, because we pop an item off again
            # right below.
            task, value = queue._put_wait.popitem(last=False)
            queue._q.append(value)
            trio.hazmat.reschedule(task)
        if queue._q:
            value = queue._q.popleft()
            await trio.hazmat.yield_briefly_no_cancel()
            # Same (queue, value) shape that put() delivers to a blocked
            # getter via trio.Value((self, obj)).
            return queue, value
    # No queue had anything: wait on all of them.
    task = trio.current_task()

    def abort_fn(_=None):
        # Unregister from every queue. trio calls this on cancellation
        # (passing raise_cancel); put() calls it when it hands us a value.
        # pop() tolerates a queue that already dropped the task or that
        # appears more than once in `queues`.
        for queue in queues:
            queue._get_wait.pop(task, None)
        return trio.hazmat.Abort.SUCCEEDED

    for queue in queues:
        queue._get_wait[task] = abort_fn
    return await trio.hazmat.yield_indefinitely(abort_fn)
@sorcio: doh, yeah, those yields were totally in the wrong place. That's what happens when I write too fast :-). Fixed.
The above prioritizes queues in the order they're passed to `multi_get`, which is what @parity3 specifically wanted for their use case :-), and then is FIFO-fair on each individual queue. Go's `select`, by contrast, uses simple randomization.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The yields at lines 21, 27, and 52 (of the earlier revision) are suspicious: a concurrent task can pop/push in the meantime.
A conceptual issue I encountered with multi-get semantics is fairness. It should probably be delegated to application code (e.g. shuffle `queues` before passing it to the function). But I'm trying to decide whether it might be worth it for the library to implement more sophisticated policies, like favoring the most saturated queue, or queues that have waiters.