zeromq controller
from __future__ import print_function
import uuid
import time
import json
from six.moves import xrange
import zmq
from actors.worker import Worker
from multiprocessing import Process, Event
from pushpull.models import DBSession
from pushpull.models import Number
import transaction
# from sqlalchemy.orm import scoped_session
# from sqlalchemy.orm import sessionmaker
# from zope.sqlalchemy import ZopeTransactionExtension
# from sqlalchemy.ext.declarative import declarative_base
#
# DBSession = scoped_session(sessionmaker(extension=ZopeTransactionExtension()))
# Base = declarative_base()
import logging
FORMAT = '%(asctime)s %(levelname)-5.5s [%(name)s][%(threadName)s] %(message)s'
logging.basicConfig(format=FORMAT)
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)


class Job(object):
    def __init__(self, work):
        self.id = uuid.uuid4().hex
        self.work = work


class Controller(object):
    """
    Manage distribution of jobs to workers and collation of results.
    """
    CLIENT_PORT = 5754
    CONTROL_PORT = 5755

    def __init__(self, stop_event, port=CONTROL_PORT, client_port=CLIENT_PORT):
        self.logger = logging.getLogger(__name__)
        self.stop_event = stop_event
        self.context = zmq.Context()
        # Incoming PULL socket that receives work items from the client
        self.socket_client = self.context.socket(zmq.PULL)
        self.socket_client.bind('tcp://*:{0}'.format(client_port))
        self.result_poller = zmq.Poller()
        self.result_poller.register(self.socket_client, zmq.POLLIN)
        # ROUTER socket used to talk to the workers
        self.socket = self.context.socket(zmq.ROUTER)
        self.socket.bind('tcp://*:{0}'.format(port))
        self.workers = {}
        # We won't assign more than 50 jobs to a worker at a time; this ensures
        # reasonable memory usage and less shuffling when a worker dies.
        self.max_jobs_per_worker = 50
        # When/if a worker disconnects we'll put any unfinished work in here
        self._work_to_requeue = []

    def run(self):
        while True:
            job = None
            next_worker_id = None
            while next_worker_id is None:
                # First check if there are any worker messages to process. We
                # do this while checking for the next available worker so that
                # if it takes a while to find one we're still processing
                # incoming messages.
                while self.socket.poll(0):
                    # Note that we're using recv_multipart() here; this is a
                    # special method on the ROUTER socket that includes the
                    # id of the sender. It doesn't handle the JSON decoding
                    # automatically though, so we have to do that ourselves.
                    worker_id, message = self.socket.recv_multipart()
                    message = json.loads(message.decode('utf8'))
                    self._handle_worker_message(worker_id, message)
                # If there are no available workers (they all have 50 or
                # more jobs already) sleep for half a second.
                next_worker_id = self._get_next_worker_id()
                if next_worker_id is None:
                    time.sleep(0.5)
            if self._work_to_requeue:
                job = self._work_to_requeue.pop()
                LOG.debug('~~using requeued job~~: {0}'.format(repr(job.id)))
            else:
                socks = dict(self.result_poller.poll(0))
                # Did we receive a piece of work to do from the client?
                if socks.get(self.socket_client) == zmq.POLLIN:
                    result = self.socket_client.recv_json(zmq.DONTWAIT)
                    job = Job(result)
            if job is not None:
                self._send_job(job, next_worker_id)
            if self.stop_event.is_set():
                break
        self.stop_event.set()

    def _send_job(self, job, next_worker_id):
        # We've got a Job and an available worker_id, all we need to do
        # is send it. Note that we're now using send_multipart(), the
        # counterpart to recv_multipart(), to tell the ROUTER where our
        # message goes.
        self.logger.info('sending job %s to worker %s', job.id,
                         next_worker_id)
        self.workers[next_worker_id][job.id] = job
        self.socket.send_multipart(
            [next_worker_id, json.dumps((job.id, job.work)).encode('utf8')])

    def _get_next_worker_id(self):
        """Return the id of the next worker available to process work.

        Note that this will return None if no workers are available.
        """
        # It isn't strictly necessary since we're limiting the amount of work
        # we assign, but just to demonstrate that we're doing our own load
        # balancing we'll find the worker with the least work.
        if self.workers:
            worker_id, work = sorted(self.workers.items(),
                                     key=lambda x: len(x[1]))[0]
            if len(work) < self.max_jobs_per_worker:
                return worker_id
        # No worker is available. Our caller will have to handle this.
        return None

    def _handle_worker_message(self, worker_id, message):
        """Handle a message from the worker identified by worker_id.

        Expected messages:
            {'message': 'connect'}
            {'message': 'disconnect'}
            {'message': 'job_done', 'job_id': 'xxx', 'result': 'yyy'}
        """
        if message['message'] == 'connect':
            assert worker_id not in self.workers
            self.workers[worker_id] = {}
            self.logger.info('[%s]: connect', worker_id)
        elif message['message'] == 'disconnect':
            # Remove the worker so no more work gets added, and put any
            # remaining work into _work_to_requeue.
            remaining_work = self.workers.pop(worker_id)
            self._work_to_requeue.extend(remaining_work.values())
            self.logger.info('[%s]: disconnect, %s jobs requeued', worker_id,
                             len(remaining_work))
        elif message['message'] == 'job_done':
            result = message['result']
            job = self.workers[worker_id].pop(message['job_id'])
            self._process_results(worker_id, job, result)
        else:
            raise Exception('unknown message: %s' % message['message'])

    def _process_results(self, worker_id, job, result):
        # Persist the result as a Number row in its own transaction.
        xact = transaction.begin()
        t = Number(result)
        self.logger.info('number: {0}'.format(repr(t)))
        DBSession.add(t)
        xact.commit()
        self.logger.info('[{0}]: finished {1}, result: {2}'.format(
            repr(worker_id), repr(job.id), repr(result)))
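

# Illustrative only: this gist does not include the client that feeds the
# controller, so push_work() below is an assumed helper (not part of the
# original code) showing one way a producer could PUSH JSON work items to the
# PULL socket bound on CLIENT_PORT. The controller only requires that each
# message be JSON: it reads them with recv_json() and wraps each one in a Job
# unchanged.
def push_work(items, port=Controller.CLIENT_PORT, host='127.0.0.1'):
    context = zmq.Context()
    sender = context.socket(zmq.PUSH)
    sender.connect('tcp://{0}:{1}'.format(host, port))
    for item in items:
        # Each item becomes Job.work on the controller side.
        sender.send_json(item)
    sender.close()
    context.term()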


def run_worker(event):
    logging.basicConfig(level=logging.INFO)
    worker = Worker(event)
    worker.run()
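

# Illustrative only: the real worker lives in actors/worker.py, which is not
# part of this gist. ExampleWorker is an assumed sketch of a counterpart that
# speaks the protocol Controller._handle_worker_message expects
# (connect / job_done / disconnect); doubling the work item is a placeholder
# for whatever the real Worker actually computes.
class ExampleWorker(object):
    def __init__(self, stop_event, port=Controller.CONTROL_PORT):
        self.stop_event = stop_event
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        # A unique identity lets the controller's ROUTER socket address this
        # worker directly in send_multipart().
        self.socket.identity = uuid.uuid4().hex.encode('utf8')
        self.socket.connect('tcp://127.0.0.1:{0}'.format(port))

    def run(self):
        self.socket.send_json({'message': 'connect'})
        while not self.stop_event.is_set():
            # Jobs arrive as a JSON-encoded (job_id, work) pair.
            if self.socket.poll(100):
                job_id, work = self.socket.recv_json()
                self.socket.send_json({'message': 'job_done',
                                       'job_id': job_id,
                                       'result': work * 2})
        self.socket.send_json({'message': 'disconnect'})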


def run_controller(event):
    logging.basicConfig(level=logging.INFO)
    Controller(event).run()


def run():
    stop_event = Event()
    processes = []
    processes.append(Process(target=run_controller, args=(stop_event,)))
    # Start a few worker processes
    for i in range(10):
        processes.append(Process(target=run_worker, args=(stop_event,)))
    # To test out our disconnect messaging we'll also start one more worker
    # process with a different event that we'll stop shortly after starting.
    another_stop_event = Event()
    processes.append(Process(target=run_worker, args=(another_stop_event,)))
    for p in processes:
        p.start()
    try:
        time.sleep(5)
        another_stop_event.set()
        # The controller will set the stop event when it's finished, just
        # idle until then
        while not stop_event.is_set():
            time.sleep(1)
    except KeyboardInterrupt:
        stop_event.set()
        another_stop_event.set()
    print('waiting for processes to die...')
    for p in processes:
        p.join()
    print('all done')


if __name__ == '__main__':
    run()
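
# Rough usage sketch (an assumption, not part of the original gist): running
# this module starts the controller, ten workers, and one short-lived worker
# used to exercise the disconnect/requeue path. Work can then be fed in from
# another Python process with the illustrative push_work() helper above, e.g.
# push_work(range(100)). Each pushed item becomes a Job, is routed to a
# worker, and comes back as a job_done message that _process_results() stores
# via the Number model.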