Skip to content

Instantly share code, notes, and snippets.

@phaustin
Forked from decentral1se/chmxdx_letters.py
Created May 19, 2020 01:58
Show Gist options
  • Save phaustin/ad790f9ecd7be3f97aafbf4d2fd60466 to your computer and use it in GitHub Desktop.
Save phaustin/ad790f9ecd7be3f97aafbf4d2fd60466 to your computer and use it in GitHub Desktop.
chmxdx_letters.py
from contextlib import asynccontextmanager
from string import ascii_letters
from typing import Dict
import attr
from trio import (
BrokenResourceError,
ClosedResourceError,
Lock,
MemoryReceiveChannel,
MemorySendChannel,
open_memory_channel,
open_nursery,
run,
)
from trio.abc import Stream
from trio.testing import memory_stream_pair
@attr.s(auto_attribs=True)
class Channel:
    """One logical channel multiplexed over a shared ``Protocol`` stream.

    Outgoing frames are handed to the owning protocol, whose ``asend`` writes
    them to the shared stream under a lock.  Incoming payloads are pushed into
    ``sendq`` by the protocol's receive loop and consumed here from ``recvq``.
    """

    id: int
    protocol: "Protocol"
    # The protocol's stream; stored for reference, not read by this class.
    stream: Stream
    # Producer end, written by the protocol's receive loop.
    sendq: MemorySendChannel
    # Consumer end, drained by recv_task below.
    recvq: MemoryReceiveChannel

    async def handler(self, msg: bytes):
        """Handle one incoming payload by printing it."""
        print(f"P{self.protocol.id}::C{self.id} got {msg.decode()}")

    async def asend(self, msg: bytes):
        """Send an already-framed message (id prefix + payload) via the protocol."""
        await self.protocol.asend(msg)

    async def recv_task(self):
        """Consume payloads until ``b"Z"`` arrives or the memory channel goes away.

        ``async for`` stops quietly if the send side is closed (EndOfChannel),
        instead of letting that exception escape and tear down the nursery.
        """
        try:
            async for msg in self.recvq:
                await self.handler(msg)
                # b"Z" is the last letter the senders emit; it ends this channel.
                if msg == b"Z":
                    return
        except (ClosedResourceError, BrokenResourceError):
            # Bug fix: the original literal lacked the f prefix and printed braces.
            print(f"P{self.protocol.id} memory channel down, exiting")
@attr.s
class Message:
    """A decoded frame: the destination channel id and its payload byte(s)."""

    # Channel id parsed from the first byte of the frame.
    id = attr.ib(type=int)
    # Remaining payload bytes of the frame.
    val = attr.ib(type=bytes)
@attr.s(auto_attribs=True)
class Protocol:
    """Multiplex several ``Channel`` objects over one shared trio ``Stream``.

    Wire format, as produced by the senders in this file: each frame is
    ``MESSAGE_LENGTH`` (2) bytes -- one ASCII digit naming the channel id,
    followed by a one-byte payload.  Channel ids must therefore be 0-9.
    """

    id: int
    # Held around every write so frames from concurrent channels never interleave.
    lock: Lock
    stream: Stream
    # Registered channels keyed by id; frames for unregistered ids are dropped.
    channels: Dict[int, Channel] = attr.Factory(dict)
    # Capacity of each channel's in-memory queue.
    BUFFER: int = 15
    # Size in bytes of one frame on the wire.
    MESSAGE_LENGTH: int = 2
    # NOTE(review): not referenced by any code visible here.
    TIMEOUT: int = 10

    def parse(self, msg: bytes) -> "Message":
        """Split a raw frame into a ``Message`` of (channel id, payload).

        Raises:
            ValueError: if the first byte is missing or is not a decimal digit.
        """
        channel_id, val = bytes(msg[0:1]), bytes(msg[1:2])
        return Message(id=int(channel_id), val=val)

    async def asend(self, msg: bytes):
        """Write one complete frame to the stream while holding ``lock``."""
        async with self.lock:
            await self.stream.send_all(msg)

    async def _receive_frame(self):
        """Read exactly ``MESSAGE_LENGTH`` bytes; return None if EOF comes first."""
        frame = bytearray()
        while len(frame) < self.MESSAGE_LENGTH:
            # receive_some may return fewer bytes than requested; b"" means EOF.
            chunk = await self.stream.receive_some(self.MESSAGE_LENGTH - len(frame))
            if not chunk:
                return None
            frame += chunk
        return bytes(frame)

    async def recv_task(self):
        """Route incoming frames to registered channels until the stream ends."""
        while True:
            try:
                resp = await self._receive_frame()
                if resp is None:
                    # Previously int(b"") raised ValueError here and crashed the router.
                    print(f"P{self.id} stream closed by peer, exiting")
                    return
                msg = self.parse(resp)
                channel = self.channels.get(msg.id)
                if channel is None:
                    continue
                await channel.sendq.send(msg.val)
            except (ClosedResourceError, BrokenResourceError):
                # Bug fix: the original literal lacked the f prefix and printed braces.
                print(f"P{self.id} stream down, exiting")
                return

    @asynccontextmanager
    async def channel(self, id: int):
        """Open channel ``id``: register it for routing and run its receive loop.

        Yields the new ``Channel``.  On exit, the nursery waits for the
        channel's ``recv_task`` to finish.  The id is then unregistered, so later
        frames for it are dropped instead of piling up in a queue nobody reads.
        A full queue would otherwise block ``recv_task`` for every channel.
        """
        sendq, recvq = open_memory_channel(self.BUFFER)
        channel = None
        try:
            async with open_nursery() as nursery:
                channel = Channel(
                    id=id,
                    stream=self.stream,
                    protocol=self,
                    sendq=sendq,
                    recvq=recvq,
                )
                nursery.start_soon(channel.recv_task)
                self.channels[id] = channel
                yield channel
        finally:
            # Only drop our own registration, in case the id was re-opened meanwhile.
            if channel is not None and self.channels.get(id) is channel:
                del self.channels[id]
async def channel(id, protocol):
    """Open channel ``id`` on ``protocol`` and send every ASCII letter over it.

    Each frame is the decimal id followed by one letter (``b"3a"`` ...
    ``b"3Z"`` for id 3).  When ``protocol.channel`` is ``Protocol.channel``,
    leaving the ``async with`` also waits for this channel's receive loop.
    """
    async with protocol.channel(id=id) as chan:
        prefix = b"%d" % id
        for ch in ascii_letters:
            frame = prefix + ch.encode()
            await chan.asend(frame)
async def left_side_protocol(stream):
    """Run protocol 1 on ``stream`` with channels 0-4, then stop its receive loop."""
    async with open_nursery() as nursery:
        protocol = Protocol(id=1, stream=stream, lock=Lock())
        nursery.start_soon(protocol.recv_task)
        async with open_nursery() as channels:
            for channel_id in range(5):
                # Pass arguments through start_soon rather than a lambda closing
                # over the loop variable: no late-binding hazard, and the task is
                # named after `channel` instead of `<lambda>`.
                channels.start_soon(channel, channel_id, protocol)
        # All channel tasks have finished. protocol.recv_task only ends on stream
        # errors, so cancel it explicitly to let the outer nursery close.
        nursery.cancel_scope.cancel()
async def right_side_protocol(stream):
    """Run protocol 2 on ``stream`` with channels 0-4, then stop its receive loop."""
    async with open_nursery() as nursery:
        protocol = Protocol(id=2, stream=stream, lock=Lock())
        nursery.start_soon(protocol.recv_task)
        async with open_nursery() as channels:
            for channel_id in range(5):
                # Pass arguments through start_soon rather than a lambda closing
                # over the loop variable: no late-binding hazard, and the task is
                # named after `channel` instead of `<lambda>`.
                channels.start_soon(channel, channel_id, protocol)
        # All channel tasks have finished. protocol.recv_task only ends on stream
        # errors, so cancel it explicitly to let the outer nursery close.
        nursery.cancel_scope.cancel()
async def main():
    """Connect the two protocol sides via an in-memory stream pair and run both."""
    left, right = memory_stream_pair()
    sides = ((left_side_protocol, left), (right_side_protocol, right))
    async with open_nursery() as n:
        for side_fn, side_stream in sides:
            n.start_soon(side_fn, side_stream)
if __name__ == "__main__":
    # Start the trio event loop only when executed as a script, not on import.
    run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment