from curio import Queue, CancelledError | |
class Port:
    """Base class for a process endpoint that can be attached to a channel."""

    def __init__(self):
        # A port starts out unattached; connect() later stores the channel here.
        self.chan = None
class InputPort(Port):
    """Receiving end of a channel."""

    async def recv(self):
        """Await the next token from the attached channel and return it."""
        return await self.chan.recv()
class OutputPort(Port):
    """Sending end of a channel."""

    async def send(self, val):
        """Pass ``val`` to the attached channel and await its ``send``."""
        channel = self.chan
        await channel.send(val)
def connect(a, b, name=''):
    """Join two ports of opposite direction with a newly created channel.

    The ports may be given in either order. The channel records the sending
    port as ``chan.l`` and the receiving port as ``chan.r``, and both ports
    get the channel stored in their ``chan`` attribute.

    Args:
        a: One of the two ports to connect.
        b: The other port; must be of the opposite direction to ``a``.
        name: Name passed to the ``Channel`` constructor.

    Raises:
        AssertionError: If either port already has a channel, or if the two
            ports are not one input port and one output port.
    """
    # Connect ports together by instantiating a channel
    chan = Channel(name)
    # A port may belong to at most one channel. Each message names the port
    # that actually failed the check.
    assert not a.chan, f"Channel {a} has already been connected!"
    assert not b.chan, f"Channel {b} has already been connected!"
    # The two ports must be of opposite type (input/output)
    if isinstance(a, InputPort):
        assert isinstance(b, OutputPort), f"Channel {a} and {b} are both input ports!"
        # b ---chan---> a
        chan.l = b
        chan.r = a
    else:
        assert isinstance(b, InputPort), f"Channel {a} and {b} are both output ports!"
        # a ---chan---> b
        chan.l = a
        chan.r = b
    # Now assign the channel to the two ports
    a.chan = chan
    b.chan = chan
class Process:
    """Base class for every process in the network.

    Each instance gets a unique, monotonically increasing ``id`` and is
    recorded in one of two class-level registries: ``producer_processes``
    for ``Producer`` instances and ``non_producer_processes`` for all others.

    Ports are created from class annotations: for every annotation whose
    type is a ``Port`` subclass, a fresh port instance is stored on the
    process under the annotated name.
    """

    # Id handed to the next process that is constructed.
    next_id = 0
    # Registries keyed by process id, shared by all processes.
    non_producer_processes = {}
    producer_processes = {}

    def __init__(self, name):
        """Register the process and inject its annotated ports.

        Args:
            name: Human-readable name used in log messages.
        """
        # Local import keeps this class independent of the file's import block.
        from typing import get_type_hints

        self.name = name
        self.id = Process.next_id
        Process.next_id += 1
        # Keep track of all source processes (join on these at the end), and
        # non-source processes (cancel on these at the end)
        if isinstance(self, Producer):
            Process.producer_processes[self.id] = self
        else:
            Process.non_producer_processes[self.id] = self
        # get_type_hints merges the annotations of every class in the MRO and
        # resolves string annotations. A subclass with no annotations then
        # yields an empty mapping instead of raising, and ports declared on a
        # base class are still injected when a subclass adds its own.
        for port_name, port_type in get_type_hints(type(self)).items():
            # Skip annotations that are not classes; issubclass would raise
            # TypeError on them.
            if isinstance(port_type, type) and issubclass(port_type, Port):
                print(f'injecting port({port_type}) {port_name} onto {self.__class__}:{self.name}')
                setattr(self, port_name, port_type())

    def __str__(self):
        return f"{self.name}.{self.id}"

    def __repr__(self):
        return f"{type(self).__name__}('{self.name}')"

    def message(self, m):
        """Print ``m`` prefixed with this process's ``name.id`` label."""
        print(f"{self}: {m}")
class Producer(Process):
    """Marker base class for processes that drive the system.

    Every process that injects values onto channels unconditionally must
    subclass this, so that it is registered as a producer.
    """
class Source(Producer):
    """Producer that sends a fixed value a fixed number of times on ``R``."""

    R: OutputPort

    def __init__(self, name, length, srcval):
        """Store the value to emit (``srcval``) and how many times (``length``)."""
        super().__init__(name)
        self.length = length
        self.val = srcval

    async def exec(self):
        """Send ``self.val`` on ``R`` ``self.length`` times, logging each step."""
        for _ in range(self.length):
            self.message(f"sending {self.val}")
            await self.R.send(self.val)
            self.message(f"sent {self.val}")
        self.message("terminated")
class Sink(Process):
    """Process that consumes tokens from ``L`` until it is cancelled."""

    L: InputPort

    def __init__(self, name):
        super().__init__(name)

    async def exec(self):
        """Receive and log tokens forever; report the total on cancellation.

        The ``CancelledError`` is caught and not re-raised, so the coroutine
        returns normally after printing the count.
        """
        received = 0
        try:
            while True:
                token = await self.L.recv()
                received += 1
                self.message(f"received {token}")
        except CancelledError:
            self.message(f"{received} tokens received")
class Buffer(Process):
    """Process that forwards every token received on ``L`` out through ``R``."""

    L: InputPort
    R: OutputPort

    def __init__(self, name):
        super().__init__(name)

    async def exec(self):
        """Loop forever: receive a token, log it, and send it on unchanged."""
        while True:
            token = await self.L.recv()
            self.message(f"received {token}")
            self.message(f"sending {token}")
            await self.R.send(token)
class Channel:
    """Named link between two ports, backed by a queue created with ``maxsize=1``."""

    def __init__(self, name):
        self.name = name
        # Max buffering of 1
        self.q = Queue(maxsize=1)

    async def send(self, val):
        """Put ``val`` on the underlying queue."""
        await self.q.put(val)

    async def recv(self):
        """Take the next token off the queue, mark it done, and return it."""
        token = await self.q.get()
        await self.q.task_done()
        return token

    async def close(self):
        """Await the queue's ``join()``."""
        await self.q.join()
async def run_all():
    """Run every registered process until all producers have finished.

    One task is spawned per process in ``Process.producer_processes`` and
    then per process in ``Process.non_producer_processes``. The producer
    tasks are joined in order, after which every non-producer task is
    cancelled.

    Raises:
        Whatever ``task.join()`` raises for a producer task. The non-producer
        tasks are still cancelled before the exception propagates.
    """
    source_tasks = [
        await spawn(proc.exec()) for proc in Process.producer_processes.values()
    ]
    other_tasks = [
        await spawn(proc.exec()) for proc in Process.non_producer_processes.values()
    ]
    try:
        # Now wait for all sources to end
        for task in source_tasks:
            await task.join()
    finally:
        # Cancel in a finally block so a failing join does not leave the
        # non-producer tasks running.
        for task in other_tasks:
            await task.cancel()
from curio import run, spawn | |
async def system():
    """Build a linear Source -> N Buffers -> Sink pipeline and run it."""
    N = 10  # How many buffers in our linear pipeline

    # Instantiate the processes
    src = Source('src1', 10, 1)
    buf = [Buffer(f'buf[{i}]') for i in range(N)]
    snk = Sink('snk')

    # Connect the processes with the channels
    connect(src.R, buf[0].L)
    # Chain each buffer's output to the next buffer's input.
    for upstream, downstream in zip(buf, buf[1:]):
        connect(upstream.R, downstream.L)
    connect(snk.L, buf[-1].R)

    await run_all()
# Script entry point: run the demo pipeline with with_monitor=True.
if __name__ == "__main__":
    pipeline = system()
    run(pipeline, with_monitor=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment