Skip to content

Instantly share code, notes, and snippets.

@alexshpilkin
Created January 7, 2021 22:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alexshpilkin/af9ebf11fec6d38baecd870a455f139f to your computer and use it in GitHub Desktop.
Save alexshpilkin/af9ebf11fec6d38baecd870a455f139f to your computer and use it in GitHub Desktop.
Trio channels with priorities
from heapq import heappush, heappop
from math import inf
from trio import BrokenResourceError, ClosedResourceError, EndOfChannel, WouldBlock
from trio.abc import ReceiveChannel, SendChannel
from trio.lowlevel import ParkingLot, checkpoint, checkpoint_if_cancelled, cancel_shielded_checkpoint
from trio._channel import MemoryChannelStats
from trio._util import NoPublicConstructor
class MemoryChannelState:
    """State shared by every send and receive handle of one priority channel.

    ``data`` is a binary heap (see :mod:`heapq`) of
    ``(priority(value), number, value)`` entries, so the entry with the
    smallest priority key is received first.  ``number`` is a per-channel
    sequence counter: entries with equal keys therefore come out in the order
    they were sent, and the heap never has to compare two values directly.
    """

    __slots__ = ('data', 'max_buffer_size', 'number', 'open_send_channels',
                 'open_receive_channels', 'priority', 'receivers',
                 'senders')

    def __init__(self, max_buffer_size, priority):
        self.max_buffer_size = max_buffer_size
        # Callable mapping a value to its sort key; smaller keys win.
        self.priority = priority
        self.data = []
        # Must be initialised here: with __slots__ an attribute that was never
        # assigned raises AttributeError, so otherwise the very first
        # send_nowait() would fail when it reads ``state.number``.
        self.number = 0
        self.open_send_channels = 0
        self.open_receive_channels = 0
        # Tasks parked in send() and receive() respectively.
        self.senders = ParkingLot()
        self.receivers = ParkingLot()

    def statistics(self):
        """Return a trio ``MemoryChannelStats`` snapshot of this channel."""
        return MemoryChannelStats(
            current_buffer_used=len(self.data),
            max_buffer_size=self.max_buffer_size,
            open_send_channels=self.open_send_channels,
            open_receive_channels=self.open_receive_channels,
            tasks_waiting_send=self.senders.statistics().tasks_waiting,
            tasks_waiting_receive=self.receivers.statistics().tasks_waiting,
        )
class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor):
    """Sending end of a priority memory channel.

    Obtain instances from ``open_memory_channel`` or :meth:`clone`.  Each
    value is pushed onto the shared heap under ``state.priority(value)``.
    """

    __slots__ = ('_closed', '_state')

    def __init__(self, state):
        self._state = state
        self._closed = False
        state.open_send_channels += 1

    def clone(self):
        """Return another send handle on the same channel.

        Raises:
            ClosedResourceError: if this handle has already been closed.
        """
        if self._closed:
            raise ClosedResourceError
        return MemorySendChannel._create(self._state)

    def statistics(self):
        """Return the shared channel's statistics snapshot."""
        return self._state.statistics()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close this handle synchronously; calling it again is a no-op.

        Tasks blocked in :meth:`send` on this handle wake up and raise
        ClosedResourceError.  When the last send handle closes, blocked
        receivers are woken so they can drain what is buffered and then
        observe EndOfChannel.
        """
        if self._closed:
            return
        self._closed = True
        state = self._state
        state.open_send_channels -= 1
        # The parking lot is shared by all clones, so we cannot wake only our
        # own waiters.  Waking everyone is safe: send() re-parks whenever its
        # retry would still block, while waiters on this handle see _closed.
        state.senders.unpark_all()
        if not state.open_send_channels:
            state.receivers.unpark_all()

    async def aclose(self):
        """Close this handle, then execute a Trio checkpoint."""
        self.close()
        await checkpoint()

    def send_nowait(self, value, *, _could_block=True):
        """Enqueue *value* without blocking.

        Raises:
            ClosedResourceError: if this handle is closed.
            BrokenResourceError: if every receive handle is closed.
            WouldBlock: if the buffer holds ``max_buffer_size`` or more values
                and no receiver is currently parked in ``receive()``.
        """
        state = self._state
        if self._closed:
            raise ClosedResourceError
        if not state.open_receive_channels:
            raise BrokenResourceError
        # A parked receiver is woken for this item right below, so it may go
        # in even when the buffer is "full"; this hand-off is what makes
        # max_buffer_size == 0 usable at all.
        if len(state.data) >= state.max_buffer_size and not state.receivers:
            assert _could_block
            raise WouldBlock
        key = state.priority(value)
        number = state.number
        state.number = number + 1
        heappush(state.data, (key, number, value))
        if state.receivers:
            state.receivers.unpark()

    async def send(self, value):
        """Send *value*, waiting while the channel cannot accept it.

        Raises:
            ClosedResourceError: if this handle is, or becomes, closed.
            BrokenResourceError: if all receive handles are, or become, closed.
        """
        await checkpoint_if_cancelled()
        try:
            self.send_nowait(value)
        except WouldBlock:
            pass
        else:
            await cancel_shielded_checkpoint()
            return
        # A wake-up is only a hint: another task may have used the freed slot
        # before we got to run (or a clone was closed), so retry until it
        # succeeds or raises, parking again in between.
        while True:
            await self._state.senders.park()
            try:
                self.send_nowait(value)
            except WouldBlock:
                continue
            return


class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor):
    """Receiving end of a priority memory channel.

    Values come out in ascending ``priority(value)`` order; values with equal
    keys come out in the order they were sent.
    """

    __slots__ = ('_closed', '_state')

    def __init__(self, state):
        self._state = state
        self._closed = False
        state.open_receive_channels += 1

    def clone(self):
        """Return another receive handle on the same channel.

        Raises:
            ClosedResourceError: if this handle has already been closed.
        """
        if self._closed:
            raise ClosedResourceError
        return MemoryReceiveChannel._create(self._state)

    def statistics(self):
        """Return the shared channel's statistics snapshot."""
        return self._state.statistics()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close this handle synchronously; calling it again is a no-op.

        Tasks blocked in :meth:`receive` on this handle wake up and raise
        ClosedResourceError.  When the last receive handle closes, buffered
        values are discarded and blocked senders observe BrokenResourceError.
        """
        if self._closed:
            return
        self._closed = True
        state = self._state
        state.open_receive_channels -= 1
        # Shared lot: waiters on other clones just retry and re-park.
        state.receivers.unpark_all()
        if not state.open_receive_channels:
            # Nothing can ever receive these values now; release them.
            state.data.clear()
            state.senders.unpark_all()

    async def aclose(self):
        """Close this handle, then execute a Trio checkpoint."""
        self.close()
        await checkpoint()

    def receive_nowait(self, *, _could_block=True):
        """Pop the value with the smallest priority key without blocking.

        With ``max_buffer_size == 0`` this cannot take a value from a sender
        that is blocked in ``send()``; only :meth:`receive` can.

        Raises:
            ClosedResourceError: if this handle is closed.
            EndOfChannel: if the buffer is empty and all send handles are closed.
            WouldBlock: if the buffer is empty but senders remain open.
        """
        state = self._state
        if self._closed:
            raise ClosedResourceError
        if not state.data:
            if not state.open_send_channels:
                raise EndOfChannel
            assert _could_block
            raise WouldBlock
        _priority, _number, value = heappop(state.data)
        # A slot just opened up; give one blocked sender a chance to use it.
        if state.senders:
            state.senders.unpark()
        return value

    async def receive(self):
        """Receive the next value, waiting while the buffer is empty.

        Raises:
            ClosedResourceError: if this handle is, or becomes, closed.
            EndOfChannel: once the buffer is empty and all send handles closed.
        """
        await checkpoint_if_cancelled()
        try:
            value = self.receive_nowait()
        except WouldBlock:
            pass
        else:
            await cancel_shielded_checkpoint()
            return value
        state = self._state
        while True:
            # A sender may be parked because the buffer was full (always true
            # when max_buffer_size == 0).  Now that a receiver is about to
            # wait, send_nowait() will accept its value, so let one retry.
            # unpark() only reschedules it; it runs after we have parked.
            if state.senders:
                state.senders.unpark()
            await state.receivers.park()
            try:
                return self.receive_nowait()
            except WouldBlock:
                # Another task took the item first; wait for the next one.
                continue
def open_memory_channel(max_buffer_size, *, priority=None):
    """Create a priority memory channel and return ``(send, receive)`` handles.

    Args:
        max_buffer_size: Number of values the buffer may hold; a non-negative
            ``int`` or ``math.inf``.
        priority: Optional callable mapping each sent value to its sort key.
            Values with smaller keys are received first.  When omitted,
            every value gets key ``0`` and the channel behaves FIFO.

    Raises:
        TypeError: if *max_buffer_size* is neither an ``int`` nor ``math.inf``.
        ValueError: if *max_buffer_size* is negative.
    """
    is_bounded = max_buffer_size != inf
    if is_bounded and not isinstance(max_buffer_size, int):
        raise TypeError("max_buffer_size must be an integer or math.inf")
    if max_buffer_size < 0:
        raise ValueError("max_buffer_size must be >= 0")
    sort_key = (lambda _value: 0) if priority is None else priority
    shared = MemoryChannelState(max_buffer_size, sort_key)
    send_end = MemorySendChannel._create(shared)
    receive_end = MemoryReceiveChannel._create(shared)
    return send_end, receive_end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment