Skip to content

Instantly share code, notes, and snippets.

@decentral1se
Last active May 17, 2020 18:41
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save decentral1se/42adc1f80f17655be6eb1f4a73ad7f0b to your computer and use it in GitHub Desktop.
Save decentral1se/42adc1f80f17655be6eb1f4a73ad7f0b to your computer and use it in GitHub Desktop.
Mux/demux example with trio (first steps...)
"""Mux/demux protocol toy example with Trio.
Naively done and just about working! Probably buggy. There is not much in the way
of error handling; however, it does showcase an approach for multiplexing a
streamed connection over multiple channels. All messages that are sent over the
stream are tagged with a channel ID. Each Protocol and Channel instance has
its own incoming and outgoing queue.
I'm trying to work towards having a robust mux/demux implementation and need
some help. Any comments on this would be greatly appreciated. Hopefully this
could also be helpful to other Trio beginners.
Inspired by https://github.com/python-trio/trio/issues/467.
"""
from contextlib import asynccontextmanager
from random import randint
from string import ascii_letters
from typing import ClassVar, Dict

import attr
from trio import (
    ClosedResourceError,
    MemoryReceiveChannel,
    MemorySendChannel,
    open_memory_channel,
    open_nursery,
    run,
    sleep,
)
from trio.abc import Stream
from trio.testing import memory_stream_pair
@attr.s(auto_attribs=True)
class Channel:
    """One logical channel multiplexed over a Protocol's shared stream.

    The channel never writes to ``stream`` itself: ``send`` queues messages
    on ``out_send`` for the protocol-level mux machinery to forward, and the
    protocol-level demux machinery delivers incoming payloads to
    ``in_send``, which ``recv_task`` drains.
    """

    id: int  # channel id; Protocol.channel registers the channel under it
    stream: Stream  # the protocol's shared stream (stored, not used here)
    protocol: "Protocol"  # owning protocol; its id labels log output
    in_send: MemorySendChannel  # incoming payloads are pushed here
    in_recv: MemoryReceiveChannel  # drained by recv_task
    out_send: MemorySendChannel  # written by send()
    out_recv: MemoryReceiveChannel  # drained by the protocol to go on the wire

    async def send(self, message):
        """Queue *message* (bytes) on this channel's outgoing queue.

        The message is queued as-is; callers in this demo prefix it with the
        channel id themselves so the peer can demultiplex it.
        """
        await self.out_send.send(message)

    async def recv_task(self):
        """Print a confirmation for every payload delivered to this channel.

        This shows that the protocol-level mux/demux machinery works end to
        end. The task returns normally when the incoming queue ends or is
        closed (for example, when the channel's context exits). It does not
        raise ClosedResourceError into the surrounding nursery.
        """
        try:
            async for message in self.in_recv:
                pid = self.protocol.id
                print(f'P{pid}::C{self.id} -> received {message.decode()}')
        except ClosedResourceError:
            # Our receive handle was closed underneath us: nothing more can
            # arrive, so finish quietly.
            return
@attr.s(auto_attribs=True)
class Protocol:
    """One end of a connection multiplexing several channels over one Stream.

    Wire format: fixed-size frames of ``MESSAGE_LENGTH`` bytes. The first byte
    is the ASCII channel id and the remaining bytes are the payload. A frame
    equal to ``KEEP_ALIVE_MESSAGE`` is a keep-alive and is never delivered to
    a channel.

    Data flow::

        Channel.send -> channel out queue -> mux_task -> protocol out queue
          -> send_task -> stream -> (peer) recv_task -> protocol in queue
          -> demux_task -> channel in queue

    The caller owns the four protocol-level memory channel handles and is
    expected to start the ``*_task`` coroutines in a nursery.
    """

    id: int  # protocol id; only used to label log output
    stream: Stream  # shared byte stream carrying frames for every channel
    in_send: MemorySendChannel  # recv_task -> demux_task
    in_recv: MemoryReceiveChannel
    out_send: MemorySendChannel  # mux_task / keep_alive_task -> send_task
    out_recv: MemoryReceiveChannel
    BUFFER: int = 15  # buffer size of each channel's memory queues
    MESSAGE_LENGTH: int = 2  # bytes per frame: 1-byte channel id + payload
    channels: Dict[int, Channel] = attr.Factory(dict)  # open channels by id

    # Keep-alive frame. Declared ClassVar so attrs does not turn it into an
    # __init__ argument.
    KEEP_ALIVE_MESSAGE: ClassVar[bytes] = b'KA'

    async def keep_alive_task(self):
        """Queue a keep-alive frame every 1-5 seconds (random).

        Demonstrates that the protocol itself can share the outgoing queue
        with its channels.
        """
        while True:
            await sleep(randint(1, 5))
            await self.out_send.send(self.KEEP_ALIVE_MESSAGE)

    async def mux_task(self, poll_interval: float = 0.1):
        """Forward every open channel's outgoing queue to the protocol's.

        Each channel gets its own forwarding task, so a quiet channel never
        holds up messages from busy ones. The channel registry is re-scanned
        every *poll_interval* seconds to pick up newly opened channels.
        """
        forwarded: Dict[int, Channel] = {}
        async with open_nursery() as nursery:
            while True:
                for key, channel in list(self.channels.items()):
                    # Identity check also catches an id that was closed and
                    # reopened with a fresh Channel object.
                    if forwarded.get(key) is not channel:
                        forwarded[key] = channel
                        nursery.start_soon(self._forward, channel)
                # Forget channels that have been deregistered.
                for key in [k for k in forwarded if k not in self.channels]:
                    del forwarded[key]
                await sleep(poll_interval)

    async def _forward(self, channel: "Channel"):
        """Copy messages from one channel's outgoing queue to the protocol's."""
        try:
            async for msg in channel.out_recv:
                await self.out_send.send(msg)
                print(f'P{self.id} -> mux {msg.decode()} for C{channel.id}')
        except ClosedResourceError:
            # The channel's context exited (closing its queues), or the
            # protocol's outgoing queue was closed: nothing left to forward.
            return

    async def demux_task(self):
        """Route each incoming frame's payload to the owning channel.

        Frames whose id byte is not an integer, or names a channel that is not
        open on this side, are reported and dropped. Such frames would
        otherwise crash the task with ValueError/KeyError. Returns when the
        incoming queue is closed by its sender.
        """
        async for frame in self.in_recv:
            tag, msg = bytes(frame[:1]), bytes(frame[1:])
            try:
                channel = self.channels[int(tag.decode())]
            except (ValueError, KeyError):
                print(f'P{self.id} -> dropped {frame!r}: unknown channel')
                continue
            try:
                await channel.in_send.send(msg)
            except ClosedResourceError:
                # Channel closed between the lookup and the delivery.
                print(f'P{self.id} -> dropped {frame!r}: channel closed')
                continue
            print(f"P{self.id} -> demux {msg.decode()} for C{tag.decode()}")

    async def send_task(self):
        """Write every frame from the outgoing queue to the stream."""
        async for message in self.out_recv:
            await self.stream.send_all(message)
            print(f'P{self.id} -> sent {message.decode()}')

    async def recv_task(self):
        """Read frames from the stream and queue them for demultiplexing.

        ``receive_some`` may return fewer bytes than requested, so bytes are
        accumulated until a whole ``MESSAGE_LENGTH`` frame is available.
        Keep-alive frames are discarded. On end-of-stream (``receive_some``
        returns ``b""``) the incoming queue is closed so that ``demux_task``
        finishes. Previously the loop spun forever and queued empty frames.
        """
        pending = bytearray()
        while True:
            chunk = await self.stream.receive_some(self.MESSAGE_LENGTH)
            if not chunk:
                await self.in_send.aclose()
                return
            pending += chunk
            while len(pending) >= self.MESSAGE_LENGTH:
                frame = bytes(pending[:self.MESSAGE_LENGTH])
                del pending[:self.MESSAGE_LENGTH]
                if frame == self.KEEP_ALIVE_MESSAGE:
                    continue
                await self.in_send.send(frame)

    @asynccontextmanager
    async def channel(self, id: int):
        """Open channel *id* with its own incoming and outgoing queues.

        The channel shares the protocol's stream but never writes to it
        directly: its messages go to its outgoing queue, and the protocol's
        mux machinery forwards them. On exit the channel is deregistered
        (so demux stops routing to it) and its receive task is cancelled
        before its queues are closed.
        """
        in_send, in_recv = open_memory_channel(self.BUFFER)  # type: ignore
        out_send, out_recv = open_memory_channel(self.BUFFER)  # type: ignore
        async with open_nursery() as nursery:
            async with in_send, in_recv, out_send, out_recv:
                channel = Channel(
                    id=id,
                    stream=self.stream,
                    protocol=self,
                    in_send=in_send,
                    in_recv=in_recv,
                    out_send=out_send,
                    out_recv=out_recv,
                )
                nursery.start_soon(channel.recv_task)
                self.channels[id] = channel
                try:
                    yield channel
                finally:
                    # Only remove our own registration; the id may already
                    # have been reused by a newer channel.
                    if self.channels.get(id) is channel:
                        del self.channels[id]
                    # Stop recv_task so it is not left waiting on a queue that
                    # is about to be closed.
                    nursery.cancel_scope.cancel()
async def _run_protocol(stream, protocol_id, duration=60):
    """Run one end of the mux/demux demo over *stream* for *duration* seconds.

    Opens the protocol's memory queues (buffer 15) and starts the protocol's
    background tasks. It also starts three chatty channels (ids 1-3), each
    sending its id followed by one ASCII letter per message.

    All tasks run in a nursery owned by this call and are cancelled before the
    memory queues they use are closed. The demo therefore stops cleanly
    instead of failing with ClosedResourceError.
    """
    in_send, in_recv = open_memory_channel(15)
    out_send, out_recv = open_memory_channel(15)
    async with in_send, in_recv, out_send, out_recv:
        protocol = Protocol(
            id=protocol_id,
            stream=stream,
            in_send=in_send,
            in_recv=in_recv,
            out_send=out_send,
            out_recv=out_recv,
        )

        async def chat(channel_id):
            # Single-digit ids only: the wire format gives the id one byte.
            async with protocol.channel(id=channel_id) as channel:
                for letter in ascii_letters:
                    message = (b'%d' % channel_id) + letter.encode()
                    await channel.send(message)
                    await sleep(randint(0, 5))

        async with open_nursery() as nursery:
            nursery.start_soon(protocol.recv_task)
            nursery.start_soon(protocol.send_task)
            nursery.start_soon(protocol.mux_task)
            nursery.start_soon(protocol.demux_task)
            nursery.start_soon(protocol.keep_alive_task)
            for channel_id in (1, 2, 3):
                nursery.start_soon(chat, channel_id)
            await sleep(duration)
            nursery.cancel_scope.cancel()


async def left_side_protocol(nursery, stream):
    """Run protocol 1 over *stream* for a minute.

    *nursery* is accepted for backward compatibility but is no longer used.
    Background tasks run in a nursery scoped to this call.
    """
    await _run_protocol(stream, protocol_id=1)


async def right_side_protocol(nursery, stream):
    """Run protocol 2 over *stream* for a minute.

    *nursery* is accepted for backward compatibility but is no longer used.
    Background tasks run in a nursery scoped to this call.
    """
    await _run_protocol(stream, protocol_id=2)
async def main():
    """Connect two protocol ends through an in-memory stream pair and run both."""
    left_side_stream, right_side_stream = memory_stream_pair()
    async with open_nursery() as nursery:
        nursery.start_soon(left_side_protocol, nursery, left_side_stream)
        nursery.start_soon(right_side_protocol, nursery, right_side_stream)


# Only run the (minute-long) demo when executed as a script, not on import.
if __name__ == "__main__":
    run(main)
@decentral1se
Copy link
Author

decentral1se commented Jan 17, 2020

Example output of this script is:

P2 -> mux 1a for C1
P1 -> mux 2a for C2
P1 -> sent 2a
P2 -> sent 1a
P1::C1 -> received a
P2::C2 -> received a
P2 -> demux a for C2
P1 -> demux a for C1
P2 -> mux 1b for C1
P2 -> sent 1b
P2 -> mux 3a for C3
P2 -> mux 2a for C2
P2 -> sent 3a
P1::C1 -> received b
P1 -> demux b for C1
P2 -> sent 2a
P1 -> demux a for C3
P1::C3 -> received a
P1::C2 -> received a
P1 -> demux a for C2
P2 -> sent KA
P2 -> mux 1c for C1
P2 -> sent 1c
P2 -> mux 3b for C3
P2 -> mux 2b for C2
P2 -> sent 3b
P1::C1 -> received c
P1 -> demux c for C1
P2 -> sent 2b
P1 -> demux b for C3
P1::C3 -> received b
P1 -> demux b for C2
P1::C2 -> received b
P2 -> mux 1d for C1
P2 -> sent 1d
P1 -> demux d for C1

@smurfix
Copy link

smurfix commented Jan 17, 2020

What's the point of all these `with move_on_after(1):` statements? IMHO they're either superfluous or may cause data loss.

@decentral1se
Copy link
Author

@smurfix, aha, you're right, I can drop them altogether. This is mainly due to my dodgy understanding of blocking/non-blocking. I thought I needed to include these to avoid stalling the program when there is nothing coming or going on the channel/stream. Will update the gist. Thanks!

@decentral1se
Copy link
Author

decentral1se commented May 17, 2020

I've since realised I can replace all this channel noise with a trio.Lock!

The stream itself is the queue that can handle the comms.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment