@minrk
Forked from costrouc/client.py
Created November 9, 2017 11:42
Asyncio Majordomo Protocol (18/MDP)
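The gist is split across four files: client.py (a DEALER-socket client that submits requests and collects replies), a demo entry-point script that launches clients, workers, and the broker as separate processes, scheduler.py (the ROUTER-socket broker plus the SchedulerCode frame constants), and worker.py (a DEALER-socket worker that processes requests and heartbeats the broker).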
# client.py
import asyncio
import logging

import zmq
import zmq.asyncio

from scheduler import SchedulerCode


class Client:
    DEFAULT_PROTOCOL = "tcp"
    DEFAULT_PORT = 8000
    DEFAULT_HOSTNAME = '0.0.0.0'

    def __init__(self, protocol=DEFAULT_PROTOCOL, port=DEFAULT_PORT, hostname=DEFAULT_HOSTNAME, loop=None):
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger('mdp.client')
        self.context = zmq.asyncio.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.uri = f'{protocol}://{hostname}:{port}'
        self.logger.info(f'Connecting ZMQ socket to {self.uri}')
        self.socket.connect(self.uri)

    async def submit(self, service, message):
        self.logger.debug(f'sending message to service {service}')
        await self.socket.send_multipart([
            b'', SchedulerCode.CLIENT, service, *message
        ])

    async def get(self):
        multipart_message = await self.socket.recv_multipart()
        _, _, service, *message = multipart_message
        self.logger.debug(f'receiving message from service {service}')
        return service, message

    def disconnect(self):
        self.socket.close()
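Note that submit() and get() are independent: the demo client below fires off all of its requests first and then drains replies. Because the scheduler picks a worker at random and forwards replies as workers finish, replies are not guaranteed to come back in submission order; get() simply returns the next (service, message frames) pair read off the socket.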
# demo script: launches scheduler, worker, and client processes
import logging
from multiprocessing import Process, Event
import asyncio

import zmq.asyncio

from client import Client
from worker import Worker
from scheduler import Scheduler


def init_logging():
    logging.basicConfig(level=logging.INFO)


def init_event_loop():
    loop = zmq.asyncio.ZMQEventLoop()
    asyncio.set_event_loop(loop)
    return loop


def init_client(stop_event):
    init_logging()

    async def create_work(loop):
        client = Client(loop=loop)
        N = 10
        MSG_SIZE = 1_000
        for i in range(N):
            if i % 1 == 0:
                print(f'[ Client ] {i+1} jobs submitted')
            await client.submit(b'hello.world', [b'o'*MSG_SIZE])
        for i in range(N):
            service, message = await client.get()
            if i % 1 == 0:
                print(f'[ Client ] {i+1} jobs completed')
        print('[ Client ] === DONE ===')
        client.disconnect()

    loop = init_event_loop()
    loop.run_until_complete(create_work(loop))


def init_scheduler(stop_event):
    init_logging()
    scheduler = Scheduler(stop_event, loop=init_event_loop())
    scheduler.run()


def init_worker(stop_event):
    counter = 0

    async def hello_world_worker(*message):
        nonlocal counter
        counter += 1
        print(f'[ Worker ] processing message {counter}')
        return (b'1', b'2', b'3')

    init_logging()
    loop = init_event_loop()
    worker = Worker(stop_event, loop=loop)
    loop.run_until_complete(worker.run(b'hello.world', hello_world_worker))


if __name__ == "__main__":
    NUM_CLIENTS = 1
    NUM_WORKERS = 1
    stop_event = Event()

    worker_processes = [Process(target=init_worker, args=(stop_event,)) for _ in range(NUM_WORKERS)]
    for worker in worker_processes:
        worker.start()

    scheduler_process = Process(target=init_scheduler, args=(stop_event,))
    scheduler_process.start()

    client_processes = [Process(target=init_client, args=(stop_event,)) for _ in range(NUM_CLIENTS)]
    for client in client_processes:
        client.start()
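The __main__ block above starts the worker, scheduler, and client processes but never sets stop_event or joins them, so the scheduler and workers keep running after the clients print DONE. A minimal shutdown step that could be appended (a sketch; the receive loops only check stop_event between messages, so terminate() is used as a fallback):

    # appended at the end of the __main__ block above
    for client in client_processes:
        client.join()
    stop_event.set()
    scheduler_process.terminate()
    for worker in worker_processes:
        worker.terminate()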
# scheduler.py
import asyncio
import concurrent.futures
import collections
import logging
import datetime as dt
import random
import uuid

import zmq
import zmq.asyncio


class Message:
    def __init__(self, client_id, message):
        self.date_added = dt.datetime.utcnow()
        self.client_id = client_id
        self.message = message


class SchedulerCode:
    WORKER = b"MDPW01"
    CLIENT = b"MDPC01"
    READY = bytes([1])
    REQUEST = bytes([2])
    REPLY = bytes([3])
    HEARTBEAT = bytes([4])
    DISCONNECT = bytes([5])


class Scheduler:
    DEFAULT_PROTOCOL = "tcp"
    DEFAULT_PORT = 8000
    DEFAULT_HOSTNAME = '0.0.0.0'

    def __init__(self, stop_event, protocol=DEFAULT_PROTOCOL, port=DEFAULT_PORT, hostname=DEFAULT_HOSTNAME, loop=None):
        self.stop_event = stop_event
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger('mdp.scheduler')
        self.context = zmq.asyncio.Context()
        self.socket = self.context.socket(zmq.ROUTER)
        self.logger.info(f'Binding ZMQ socket to {protocol}://{hostname}:{port}')
        self.socket.bind(f'{protocol}://{hostname}:{port}')
        self.messages = {}
        self.workers = {}
        self.services = collections.defaultdict(lambda: {'workers': set(), 'queue': asyncio.Queue(), 'task': None})

    async def _handle_client_message(self, client_id, multipart_message):
        service, *message_data = multipart_message
        self.logger.debug(f'adding client {client_id} message for service {service} to queue')
        message_uuid = uuid.uuid4().bytes
        message = Message(client_id, message_data)
        self.messages[message_uuid] = message
        await self.services[service]['queue'].put(message_uuid)

    async def _next_worker(self, service):
        return random.sample(service['workers'], 1)[0]  # scheduling logic
        # return service['workers'][0]

    async def _handle_service_queue(self, service):
        counter = 0
        try:
            while True:
                message_uuid = await service['queue'].get()
                message = self.messages[message_uuid]
                worker_id = await self._next_worker(service)
                counter += 1
                print(f'[Scheduler] Count {counter} sent to worker {worker_id}')
                self.workers[worker_id]['messages'].add(message_uuid)
                await self.socket.send_multipart([
                    worker_id, b'', SchedulerCode.WORKER, SchedulerCode.REQUEST,
                    message_uuid, b'', *message.message
                ])
                service['queue'].task_done()
        except asyncio.CancelledError:
            self.logger.info('stopping worker for service')

    async def _handle_worker_message(self, worker_id, multipart_message):
        message_type = multipart_message[0]
        if message_type == SchedulerCode.READY:
            service_name = multipart_message[1]
            service = self.services[service_name]
            self.logger.info(f'adding worker {worker_id} for service {service_name}')
            self.workers[worker_id] = {'service': service_name, 'messages': set()}
            service['workers'].add(worker_id)
            if len(service['workers']) == 1:
                service['task'] = asyncio.ensure_future(self._handle_service_queue(service))
        elif message_type == SchedulerCode.REPLY:
            message_uuid = multipart_message[1]
            self.workers[worker_id]['messages'].remove(message_uuid)
            message = self.messages.pop(message_uuid)
            self.logger.debug(f'sending client {message.client_id} message response from worker {worker_id}')
            print(f'[Scheduler] message done from worker {worker_id} for client {message.client_id}')
            await self.socket.send_multipart([
                message.client_id, b'', SchedulerCode.CLIENT,
                self.workers[worker_id]['service'], *multipart_message[3:]
            ])
        elif message_type == SchedulerCode.HEARTBEAT:
            self.logger.debug('responding with heartbeat')
            await self.socket.send_multipart([
                worker_id, b'', SchedulerCode.WORKER, SchedulerCode.HEARTBEAT
            ])
        elif message_type == SchedulerCode.DISCONNECT:
            if worker_id in self.workers:
                worker = self.workers[worker_id]
                service = self.services[worker['service']]
                if len(service['workers']) == 1:  # last worker
                    self.logger.info(f'canceling {worker["service"]} service queue task')
                    service['task'].cancel()
                    try:
                        await service['task']
                    except concurrent.futures.CancelledError:
                        pass
                    service['task'] = None
                self.logger.info(f'removing worker {worker_id} for service {worker["service"]} - rescheduling {len(worker["messages"])} messages')
                service['workers'].remove(worker_id)
                for message in worker['messages']:
                    await service['queue'].put(message)
                self.workers.pop(worker_id)

    async def on_recv_message(self):
        while not self.stop_event.is_set():
            multipart_message = await self.socket.recv_multipart()
            # the first frame is the ROUTER identity: a client id or a worker id,
            # depending on who sent the message
            client_id, _, message_sender, *message = multipart_message
            if message_sender == SchedulerCode.WORKER:
                await self._handle_worker_message(client_id, message)
            elif message_sender == SchedulerCode.CLIENT:
                await self._handle_client_message(client_id, message)
            else:
                raise ValueError()

    def run(self):
        self.loop.run_until_complete(self.on_recv_message())

    def disconnect(self):
        self.stop_event.set()
        self.socket.close()
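Because the scheduler socket is a ROUTER, every incoming message is prefixed with the sender's identity frame. The frame layouts seen on this socket, as implemented above:

    client request:      [client_id, b'', b'MDPC01', service, *body]
    worker READY:        [worker_id, b'', b'MDPW01', b'\x01', service]
    worker REPLY:        [worker_id, b'', b'MDPW01', b'\x03', message_uuid, b'', *result]
    request to worker:   [worker_id, b'', b'MDPW01', b'\x02', message_uuid, b'', *body]
    reply to client:     [client_id, b'', b'MDPC01', service, *result]

on_recv_message() strips the identity and empty delimiter and dispatches on the MDPC01/MDPW01 frame; each client request is tracked in self.messages under a generated uuid so that a worker DISCONNECT can requeue that worker's in-flight messages.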
# worker.py
import asyncio
import logging
import datetime as dt

import zmq
import zmq.asyncio

from scheduler import SchedulerCode


class Worker:
    DEFAULT_PROTOCOL = "tcp"
    DEFAULT_PORT = 8000
    DEFAULT_HOSTNAME = '0.0.0.0'

    def __init__(self, stop_event,
                 heartbeat_interval=2, heartbeat_timeout=10,
                 protocol=DEFAULT_PROTOCOL, port=DEFAULT_PORT, hostname=DEFAULT_HOSTNAME, loop=None):
        self.stop_event = stop_event
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger('mdp.worker')
        self.context = zmq.asyncio.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.uri = f'{protocol}://{hostname}:{port}'
        self.heartbeat_interval = heartbeat_interval
        self.heartbeat_timeout = heartbeat_timeout
        self.heartbeat_last_response = dt.datetime.utcnow()
        self.service = None
        self.service_handler = None
        self.queued_messages = asyncio.Queue()

    async def _handle_send_heartbeat(self):
        while not self.stop_event.is_set():
            if not self.socket.closed:
                self.logger.debug('sending heartbeat')
                await self.socket.send_multipart([
                    b'', SchedulerCode.WORKER, SchedulerCode.HEARTBEAT
                ])
            await asyncio.sleep(self.heartbeat_interval)

    async def _handle_check_heartbeat(self):
        while not self.stop_event.is_set():
            previous_heartbeat_check = dt.datetime.utcnow()
            await asyncio.sleep(self.heartbeat_timeout)
            if not self.socket.closed and \
               self.heartbeat_last_response < previous_heartbeat_check:
                self.logger.info(f'no response from broker in {self.heartbeat_timeout} seconds -- reconnecting')
                await self.disconnect()
                await self.connect()

    async def _handle_queued_messages(self):
        counter = 0
        while not self.stop_event.is_set():
            client_id, message = await self.queued_messages.get()
            result = await self.service_handler(*message)
            counter += 1
            print(f'[ Worker ] Counter {counter:5} completed Queue size: {self.queued_messages.qsize():5}')
            await self.socket.send_multipart([
                b'', SchedulerCode.WORKER, SchedulerCode.REPLY, client_id, b'', *result
            ])
            self.queued_messages.task_done()

    async def _on_recv_message(self):
        while not self.stop_event.is_set():
            multipart_message = await self.socket.recv_multipart()
            message_type = multipart_message[2]
            if message_type == SchedulerCode.REQUEST:
                # the fourth frame is the scheduler's message uuid; it is stored
                # as "client_id" here and echoed back unchanged in the REPLY
                _, _, message_type, client_id, _, *message = multipart_message
                self.logger.debug('broker sent request message')
                await self.queued_messages.put((client_id, message))
                self.heartbeat_last_response = dt.datetime.utcnow()
            elif message_type == SchedulerCode.HEARTBEAT:
                self.logger.debug('broker response heartbeat')
                self.heartbeat_last_response = dt.datetime.utcnow()
            elif message_type == SchedulerCode.DISCONNECT:
                self.logger.info('broker requests disconnect and reconnect')
                await self.disconnect()
                await self.connect()
            else:
                raise ValueError()  # unknown event type

    async def run(self, service, service_handler):
        self.service = service
        self.service_handler = service_handler
        await self.connect()
        await asyncio.gather(
            self._handle_send_heartbeat(),
            self._handle_check_heartbeat(),
            self._handle_queued_messages(),
            self._on_recv_message()
        )
        await self.disconnect()

    async def connect(self):
        self.logger.info(f'connecting ZMQ socket to {self.uri}')
        self.socket.connect(self.uri)
        await self.socket.send_multipart([
            b'', SchedulerCode.WORKER, SchedulerCode.READY, self.service
        ])

    async def disconnect(self):
        self.logger.info(f'disconnecting zmq socket from {self.uri}')
        await self.socket.send_multipart([
            b'', SchedulerCode.WORKER, SchedulerCode.DISCONNECT
        ])
        self.socket.disconnect(self.uri)
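The worker sends a HEARTBEAT frame every heartbeat_interval seconds and records the time of the broker's last response; if nothing has arrived within heartbeat_timeout, _handle_check_heartbeat() disconnects and reconnects, re-sending READY. The scheduler answers heartbeats but has no timeout of its own, so a worker that silently dies stays registered until it sends an explicit DISCONNECT.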