Skip to content

Instantly share code, notes, and snippets.

@basak
Created April 6, 2021 02:27
Show Gist options
  • Save basak/007da3fc2448300d037c9bb008cc5e80 to your computer and use it in GitHub Desktop.
Save basak/007da3fc2448300d037c9bb008cc5e80 to your computer and use it in GitHub Desktop.
Trio broadcast implementation with "slowest consumer" backpressure

I've been using this code "in production" for a while now. It dates back to https://groups.google.com/g/python-tulip/c/J7tCcBU5TPA/m/NM7iBhhhEAAJ except that I converted it to Trio a while ago. It is intended to be lossless: if desired, you can be sure to receive every message sent after you start listening, without losing any. Producers are blocked on send() until the slowest consumer has received the message.

Since new consumers won't receive messages from before they began to listen, the point at which a consumer "begins listening" is important. This happens when the async iterator is created - i.e. when the for loop runs the implicit aiter(). If you do this as the first thing in a coroutine, you might expect all messages following a nursery.start_soon() call starting that coroutine to be picked up. But in practice, the for loop won't run the implicit aiter() until some time later, and so you won't see messages sent prior to that point. To avoid this, you must call aiter() yourself and pass that in, or use nursery.start() and arrange for the coroutine to call task_status.started().

See value.py for examples of use.

#!/usr/bin/python3
import weakref
import trio
class EventQueueReader:
    """Consumer endpoint handed out by an event queue.

    Items are pushed in through ``_send()`` and pulled out by iterating
    over this object with ``async for``. The two sides are connected by a
    trio memory channel created with a buffer size of 0.
    """

    def __init__(self, parent=None):
        """Create the reader and its unbuffered channel.

        ``parent`` is stored as a strong reference on the reader.
        """
        self._parent = parent
        send_channel, receive_channel = trio.open_memory_channel(0)
        self._send_c = send_channel
        self._receive_c = receive_channel
        # Back-reference from the receive channel to this reader.
        # NOTE(review): presumably this keeps the reader reachable from the
        # channel; confirm before removing.
        receive_channel._aioevent_EventQueueReader = self
        self._receive_c_aiter = receive_channel.__aiter__()

    async def _send(self, item, task_status=trio.TASK_STATUS_IGNORED):
        """Deliver ``item`` to this reader, waiting until it is received.

        ``task_status.started()`` is signalled before the (possibly long)
        wait on the channel begins.
        """
        # The channel has no buffer, so the send below does not return until
        # the consumer has actually taken the item. For an EventQueue this
        # generally means a send blocks until every reader has received the
        # event.
        #
        # That is somewhat error prone: one buggy reader can stall a send
        # indefinitely. An earlier version raised
        # RuntimeError("Can't write to a reader that isn't listening") when
        # self._send_c.statistics().tasks_waiting_receive was zero, but that
        # check cannot work in the general case. A reader may need to pass
        # through a checkpoint before it blocks on the queue again, and at
        # that moment we cannot tell whether it will come back to read or
        # not. See test_reader_takes_a_while() for an example.
        task_status.started()
        await self._send_c.send(item)

    def __aiter__(self):
        """Return the reader itself; it is its own async iterator."""
        return self

    async def __anext__(self):
        """Wait for and return the next item sent to this reader."""
        item = await self._receive_c_aiter.__anext__()
        return item
class EventQueue:
    """Broadcast queue: every item sent is delivered to every live reader."""

    def __init__(self):
        # Readers are tracked weakly, so a reader nobody references any more
        # drops out of the set instead of being kept alive by the queue.
        self._readers = weakref.WeakSet()

    def reader(self):
        """Create, register and return a new reader for this queue."""
        # The reader is given a reference to this queue so that the queue
        # stays alive for as long as the reader does. KeyedEventQueue relies
        # on this: its queues live only in a WeakValueDictionary, so without
        # the back-reference a queue could vanish while a reader still
        # exists. A later lookup of the same key would then build a fresh
        # EventQueue and incoming events would never reach the old reader.
        subscriber = EventQueueReader(parent=self)
        self._readers.add(subscriber)
        return subscriber

    async def send(self, item, task_status=trio.TASK_STATUS_IGNORED):
        """Send ``item`` to all current readers and wait for all of them.

        One task per reader is started in a private nursery; the nursery
        block only exits once every per-reader send has completed.
        ``task_status.started()`` is signalled after that point.
        """
        async with trio.open_nursery() as fanout:
            for subscriber in self._readers:
                fanout.start_soon(subscriber._send, item)
        task_status.started()

    def __aiter__(self):
        """Register a fresh reader and return its async iterator."""
        subscriber = self.reader()
        return subscriber.__aiter__()
class KeyedEventQueue:
    """Route events from a parent queue into per-key EventQueues.

    Look up the queue for a key with ``keyed[key]``. Queues are created on
    demand and held only weakly.
    """

    def __init__(self, nursery, parent, key_func, data_func=lambda x: x,
                 filter_func=lambda x: True):
        """Subscribe to ``parent`` and start the routing task in ``nursery``.

        ``filter_func(item)`` decides whether an item is routed at all,
        ``key_func(item)`` picks the destination key and ``data_func(item)``
        produces the value that is actually sent on.
        """
        self.parent_reader = parent.reader()
        self.key_func = key_func
        self.data_func = data_func
        self.filter_func = filter_func
        self._queues = weakref.WeakValueDictionary()
        nursery.start_soon(self._process_incoming)

    async def _process_incoming(self):
        """Forward each accepted parent item to the queue for its key."""
        async for event in self.parent_reader:
            if self.filter_func(event):
                # XXX filter_func and data_func exceptions need handling
                # gracefully
                destination_key = self.key_func(event)
                payload = self.data_func(event)
                await self[destination_key].send(payload)

    def __getitem__(self, k):
        """Return the EventQueue for key ``k``, creating it if necessary."""
        queue = self._queues.get(k)
        if queue is None:
            queue = EventQueue()
            self._queues[k] = queue
        return queue
class FilteredEventQueue:
    """Re-broadcast the items of a parent queue that pass a filter."""

    def __init__(self, nursery, parent, filter_func, data_func=lambda x: x):
        """Subscribe to ``parent`` and start the forwarding task in ``nursery``.

        Items for which ``filter_func(item)`` is truthy are transformed with
        ``data_func(item)`` and sent to this object's own internal queue.
        """
        self.parent_reader = parent.reader()
        self.filter_func = filter_func
        self.data_func = data_func
        self._queue = EventQueue()
        nursery.start_soon(self._process_incoming)

    async def _process_incoming(self):
        """Copy accepted items from the parent reader into the inner queue."""
        async for event in self.parent_reader:
            # XXX filter_func and data_func exceptions need handling
            # gracefully
            if not self.filter_func(event):
                continue
            payload = self.data_func(event)
            await self._queue.send(payload)

    def reader(self):
        """Return a new reader on the filtered stream."""
        return self._queue.reader()

    def __aiter__(self):
        """Return a new reader, which serves as the async iterator."""
        return self.reader()
import heapq
import itertools
import operator
import time
import dateutil
import trio
import aioevent
class BlockingAsyncIterable:
    """Async iterable whose ``__anext__`` waits on ``trio.sleep_forever()``.

    Used as the ``values`` stream of something that never changes.
    """

    def __aiter__(self):
        """Return this object; it is its own async iterator."""
        return self

    async def __anext__(self):
        """Block on ``trio.sleep_forever()`` instead of producing an item."""
        await trio.sleep_forever()
class ConstantValue:
    """A fixed value exposing ``.value`` and a ``.values`` stream.

    ``values`` is a BlockingAsyncIterable, whose iteration blocks rather
    than yielding.
    """

    def __init__(self, value):
        """Store ``value`` and attach the blocking ``values`` stream."""
        self._value = value
        self.values = BlockingAsyncIterable()

    @property
    def value(self):
        """The value given at construction time."""
        return self._value
class ReadOnlyDynamicValue:
    """Base for values that announce their changes on a ``values`` queue.

    This class reads ``self.value`` but does not define it; subclasses are
    expected to supply it.
    """

    def __init__(self):
        """Create the event queue on which changes are broadcast."""
        self.values = aioevent.EventQueue()

    async def notify_value_changed(self):
        """Send the current ``self.value`` on ``self.values`` and await it."""
        queue = self.values
        await queue.send(self.value)
def get_values_attribute(obj):
    """Return ``obj.values`` if that attribute exists, otherwise ``obj``.

    This lets callers pass either a value object (which has a ``values``
    stream) or the stream itself.
    """
    return getattr(obj, "values", obj)
class DynamicValue(ReadOnlyDynamicValue):
    """A settable value that can also follow another value as its source.

    The value is changed with ``set_value()``, which notifies listeners, or
    by assigning ``source``, which starts a background task that copies each
    new value of the source into this one.
    """

    def __init__(self, nursery):
        """Create the value; ``nursery`` hosts the source-following task."""
        super().__init__()
        self.__nursery = nursery
        # Nursery of the currently running copy task, used only as a handle
        # for cancelling it when the source is replaced.
        self.__value_copy_nursery = None
        self.__source = None

    def set_initial_value(self, initial_value):
        """Set the value without notifying listeners."""
        self.__value = initial_value

    async def set_value(self, new_value):
        """Set the value and notify listeners of the change."""
        self.__value = new_value
        await self.notify_value_changed()

    def get_value(self):
        """Return the current value.

        Raises AttributeError if no value has been set yet.
        """
        return self.__value

    @property
    def value(self):
        """The current value (see ``get_value()``)."""
        return self.get_value()

    @property
    def source(self):
        """The values stream currently being followed, or None."""
        return self.__source

    @source.setter
    def source(self, values):
        """Follow ``values`` from now on, replacing any previous source.

        ``values`` may be a value object with a ``values`` stream, the
        stream itself, or None to stop following anything.
        """
        if self.__value_copy_nursery:
            self.__value_copy_nursery.cancel_scope.cancel()
            self.__value_copy_nursery = None
        self.__source = get_values_attribute(values)
        if self.__source is None:
            return
        # Create the iterator now, not inside the task, so that changes made
        # between this assignment and the task's first run are not missed.
        source_values_aiter = self.__source.__aiter__()
        try:
            initial_source_value = values.value
        except AttributeError:
            initial_source_value = None
        self.__nursery.start_soon(
            self.__value_copy,
            self.__source,
            source_values_aiter,
            initial_source_value,
        )

    async def __value_copy(self, values, values_aiter, initial_source_value):
        """Copy the initial value and then every new value from a source.

        Returns without copying if ``values`` is no longer the current
        source.
        """
        if values is not self.__source:
            # The source was replaced before this task got to run.
            return
        if initial_source_value is not None:
            await self.set_value(initial_source_value)
            # set_value() waits on listeners. This task has not yet
            # registered a nursery, so a source change made during that wait
            # could not cancel it. Check again so a superseded task stops
            # here instead of copying from the stale source and overwriting
            # the newer task's nursery handle.
            if values is not self.__source:
                return
        async with trio.open_nursery() as nursery:
            self.__value_copy_nursery = nursery
            async for new_value in values_aiter:
                await self.set_value(new_value)
async def as_any_values_change(nursery, valuess):
    """Merge several value streams into one async generator.

    One forwarding task per entry of ``valuess`` is started in ``nursery``;
    whatever any of them produces is yielded from here.
    """
    # The caller has to supply the nursery. We cannot open one ourselves
    # because this is an async generator and trio does not permit yielding
    # from inside a nursery block:
    # https://github.com/python-trio/trio/issues/264
    funnel_in, funnel_out = trio.open_memory_channel(0)

    async def send_values_to_queue(source, task_status=trio.TASK_STATUS_IGNORED):
        task_status.started()
        async for item in get_values_attribute(source):
            await funnel_in.send(item)

    for source in valuess:
        await nursery.start(send_values_to_queue, source)

    # Every forwarding task has signalled started() by this point.
    async for item in funnel_out:
        yield item
class DerivedValue(ReadOnlyDynamicValue):
    """A value computed from other values and kept up to date.

    Subclasses must provide ``calculate()``, which returns the derived value
    from the current state of their sources. It is called once at
    construction and again whenever any source reports a change.
    """

    def __init__(self, nursery, valuess=None):
        """Start following ``valuess`` in ``nursery`` and compute the value.

        ``valuess`` is an iterable of value objects or value streams. The
        default of None means "no sources"; previously it raised TypeError.
        """
        super().__init__()
        if valuess is None:
            valuess = ()
        # Create the iterators now so that no change made between
        # construction and the follower task's first run is missed.
        readers = [get_values_attribute(values).__aiter__() for values in valuess]
        changes = as_any_values_change(nursery, readers)
        nursery.start_soon(self.__follow_value_changes, changes)
        self.value = self.calculate()

    async def __follow_value_changes(self, changes):
        """Recalculate on every source change and notify if the result differs."""
        async for _ in changes:
            old_value = self.value
            new_value = self.calculate()
            if new_value != old_value:
                self.value = new_value
                await self.notify_value_changed()
class MaxValue(DerivedValue):
    """Derived value equal to the largest ``.value`` among its sources."""

    def __init__(self, nursery, sources):
        """Remember ``sources`` and start tracking them in ``nursery``."""
        # Must be set before the base initialiser, which calls calculate().
        self.sources = sources
        super().__init__(nursery, sources)

    def calculate(self):
        """Return the maximum of the sources' current values."""
        return max(map(operator.attrgetter("value"), self.sources))
class ValueRatio(DerivedValue):
    """Derived value equal to ``a.value / b.value``."""

    def __init__(self, nursery, a, b):
        """Remember both operands and start tracking them in ``nursery``."""
        # Must be set before the base initialiser, which calls calculate().
        self.a = a
        self.b = b
        super().__init__(nursery, [self.a, self.b])

    def calculate(self):
        """Return the current value of ``a`` divided by that of ``b``."""
        numerator = self.a.value
        denominator = self.b.value
        return numerator / denominator
class OrValue(DerivedValue):
    """Derived value equal to ``a.value or b.value``."""

    def __init__(self, nursery, a, b):
        """Remember both operands and start tracking them in ``nursery``."""
        # Must be set before the base initialiser, which calls calculate().
        self.a = a
        self.b = b
        super().__init__(nursery, [self.a, self.b])

    def calculate(self):
        """Return ``a``'s value if it is truthy, otherwise ``b``'s value."""
        first = self.a.value
        if first:
            return first
        return self.b.value
class ScheduledValues(ReadOnlyDynamicValue):
    """A value that changes over time according to a schedule.

    The schedule is an iterable of ``(timestamp, value)`` pairs in
    ascending timestamp order. Timestamps are compared against
    ``time_func()``.
    """

    def __init__(self, nursery, sequence, initial_value=False, time_func=time.time, sleep_func=trio.sleep):
        """Start following ``sequence`` in ``nursery``.

        ``initial_value`` is reported until the first scheduled event takes
        effect. ``time_func`` and ``sleep_func`` supply the clock and the
        async sleep, and can be replaced, e.g. for testing.
        """
        super().__init__()
        self._value = initial_value
        self._iterator = iter(sequence)
        self._time_func = time_func
        self._sleep_func = sleep_func
        nursery.start_soon(self._follow_schedule)

    @staticmethod
    def datetime_to_timestamp(datetime, tzinfo=None):
        """Convert a datetime to a POSIX timestamp.

        If ``tzinfo`` is supplied it is used, and ``datetime`` must be
        naive. If it is not supplied, a naive ``datetime`` is interpreted in
        the local time zone and an aware one in its own time zone.
        """
        if datetime.tzinfo is None:
            if not tzinfo:
                # Import the submodule explicitly: a bare "import dateutil"
                # does not reliably make dateutil.tz available.
                import dateutil.tz
                tzinfo = dateutil.tz.tzlocal()
            return datetime.replace(tzinfo=tzinfo).timestamp()
        assert tzinfo is None, "tzinfo must not be given for an aware datetime"
        return datetime.timestamp()

    @classmethod
    def sequence_from_rrule(cls, rrule, value, tzinfo=None):
        """Return a lazy ``(timestamp, value)`` sequence for one rrule."""
        return (
            (cls.datetime_to_timestamp(datetime, tzinfo), value)
            for datetime in rrule
        )

    @classmethod
    def sequence_from_rrule_value_pairs(cls, rrule_value_pairs, tzinfo=None):
        """Merge several ``(rrule, value)`` pairs into one ordered sequence.

        Each rrule must yield datetimes in ascending order; the result is
        the merge of all of them, ordered by timestamp.
        """
        sequences = [
            (
                (cls.datetime_to_timestamp(datetime, tzinfo), v)
                # zip() with repeat() binds this pair's value when the
                # generator is created, not when it is later consumed.
                for datetime, v in zip(rrule, itertools.repeat(value))
            )
            for rrule, value in rrule_value_pairs
        ]
        return heapq.merge(*sequences, key=operator.itemgetter(0))

    @classmethod
    def from_rrule(cls, nursery, rrule, value, **kwargs):
        """Build an instance from one rrule; ``kwargs`` go to ``__init__``."""
        sequence = cls.sequence_from_rrule(rrule, value)
        return cls(nursery, sequence, **kwargs)

    @classmethod
    def from_rrule_value_pairs(cls, nursery, rrule_value_pairs, **kwargs):
        """Build an instance from ``(rrule, value)`` pairs; ``kwargs`` go to ``__init__``."""
        sequence = cls.sequence_from_rrule_value_pairs(rrule_value_pairs)
        return cls(nursery, sequence, **kwargs)

    @property
    def value(self):
        """The value currently in effect."""
        return self._value

    async def _follow_schedule(self):
        """Apply scheduled values at their timestamps until the schedule ends.

        Returns quietly when the schedule is exhausted. A StopIteration
        escaping from a coroutine would otherwise be turned into a
        RuntimeError (PEP 479) and take down the nursery.
        """
        while True:
            now = self._time_func()
            try:
                next_timestamp, next_value = next(self._iterator)
            except StopIteration:
                return
            if next_timestamp < now:
                # Skip through to an event that is in the future. Instead of
                # flapping our value, set it once to the most recent event
                # that took place in the past.
                exhausted = False
                while next_timestamp < now:
                    final_past_value = next_value
                    try:
                        next_timestamp, next_value = next(self._iterator)
                    except StopIteration:
                        exhausted = True
                        break
                self._value = final_past_value
                await self.notify_value_changed()
                if exhausted:
                    # Nothing lies in the future; the last past value stands.
                    return
            delay = next_timestamp - now
            await self._sleep_func(delay)
            self._value = next_value
            await self.notify_value_changed()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment