Skip to content

Instantly share code, notes, and snippets.

@chancyk
Last active May 24, 2020 06:39
Show Gist options
  • Save chancyk/1d0f331e3bb73d4b322e6049e22fbb45 to your computer and use it in GitHub Desktop.
Save chancyk/1d0f331e3bb73d4b322e6049e22fbb45 to your computer and use it in GitHub Desktop.
PyZMQ RPC
"""
LICENSE: MIT
This is an experiment. Use at your own risk.
"""
import sys
import time
import zlib
import pickle
import random
from traceback import format_tb
from functools import partial
from threading import Thread
from multiprocessing import Process
import zmq
# A deliberately dumb DNS: each known server name maps to the ZeroMQ address
# it is reachable at. RPC_Client looks its connection address up here, and
# RPC_Server looks up the address its frontend binds to.
SERVERS = {
    'test.server': 'tcp://127.0.0.1:8000',
}
# How many times a send is retried after ZeroMQ raises `zmq.error.Again`
# before giving up. The pause before retry number N is
# `CLIENT_RETRY_DELAY * N`, so the wait grows with every attempt.
NUM_CLIENT_RETRIES = 10
# Base pause, in seconds, passed to `time.sleep` when a send fails with
# `zmq.error.Again`. 200ms seemed to give ZeroMQ enough time to catch up
# while keeping the number of retries low; re-measure it if performance
# becomes critical.
CLIENT_RETRY_DELAY = 0.2
class RPC_Exception(Exception):
    """Common base class of the exceptions defined by this RPC module."""
class RPC_ServerError(RPC_Exception):
    """Signals a problem in the RPC server itself rather than in an RPC function."""
class RPC_FunctionFailed(RPC_Exception):
    """Signals that the called RPC function itself raised an exception.

    `execute_rpc_fn` builds it with three arguments: the exception type as a
    string, the exception message as a string, and the formatted traceback
    lines.
    """
class RPC_UnknownFunction(RPC_Exception):
    """Signals that the requested function name was never exposed on the server."""
class RPC_FailedToRetry(RPC_Exception):
    """Signals that a ZMQ send kept failing after every allowed retry.

    `zmq_send` raises it once `zmq.error.Again` has occurred more than
    `NUM_CLIENT_RETRIES` times for a single message.
    """
def maybe_raise_exception(result):
    """Raise if a decoded message turns out to be an exception object.

    The server reports failures by sending an exception *instance* back in
    place of a normal value. This function turns such a payload into a
    raised exception; any other payload passes through untouched.

    Args:
        result: The decoded message payload.

    Raises:
        RPC_Exception: If `result` is an `RPC_FunctionFailed`. The remote
            traceback, exception type and message are printed first.
        Exception: If `result` is any other exception instance. The instance
            itself is raised, so its concrete type (for example
            `RPC_UnknownFunction`) can be caught by the caller.
    """
    if isinstance(result, RPC_FunctionFailed):
        # Built by `execute_rpc_fn` as (type string, message string,
        # formatted traceback lines).
        exc_type, exc_val, exc_tb = result.args
        hr = '-' * 100
        print('\n%s\n' % hr)
        print(' * RPC Function Failed:\n\n')
        print(''.join(exc_tb))
        print('')
        print(
            " * The remote function raised the following exception:\n\n "
            "%s: %s\n%s\n\n" % (exc_type, exc_val, hr)
        )
        raise RPC_Exception("An exception was raised by the RPC function.")
    if isinstance(result, Exception):
        # Raise the received instance instead of wrapping its text in a bare
        # `Exception`: the message is the same, but the type survives, which
        # keeps `except RPC_UnknownFunction:` (and friends) usable. Callers
        # that catch `Exception` are unaffected.
        raise result
def serialize(data):
    """Return `data` pickled into a bytes object."""
    pickled = pickle.dumps(data)
    return pickled
def deserialize(data):
    """Return the object stored in the pickled bytes `data`."""
    # NOTE(security): unpickling can execute arbitrary code. `decode_msg`
    # feeds bytes received from a socket into this function, so it is only
    # safe when every peer is trusted.
    restored = pickle.loads(data)
    return restored
def compress(data):
    """Return the bytes `data` compressed with zlib."""
    packed = zlib.compress(data)
    return packed
def decompress(data):
    """Return the original bytes of the zlib-compressed `data`."""
    unpacked = zlib.decompress(data)
    return unpacked
def encode_msg(data):
    """Turn `data` into wire bytes: serialize first, then compress."""
    return compress(serialize(data))
def decode_msg(msg):
    """Turn wire bytes back into an object, raising if it is an exception.

    The bytes are decompressed and then deserialized. The result is passed
    through `maybe_raise_exception` before it is returned.
    """
    raw = decompress(msg)
    payload = deserialize(raw)
    maybe_raise_exception(payload)
    return payload
class RPC_Client():
    """Client-side proxy that exposes a server's RPC functions as attributes.

    On construction the client asks the server for its function list (the
    built-in `fn_list` call) and stores one callable stub per name on the
    instance. Any other missing attribute also resolves to a stub, so
    `client.some_fn(...)` always turns into a request for `some_fn`.

    Every call opens a fresh REQ socket, sends one request, waits for one
    reply and closes the socket again.
    """

    def __init__(self, server, debug=False):
        """Connect to the server registered under `server` in `SERVERS`.

        Args:
            server: Name of the server; must be a key of `SERVERS`.
            debug: When true, a warning is printed each time a send has to
                be retried.
        """
        self._context = None
        # Never assigned anywhere else in this class; sockets are created
        # per request in `_request`.
        self._socket = None
        self._server = None
        self._debug = debug
        self._init_socket(server)

    def _init_socket(self, server):
        """Resolve the address, create the context, and fetch the stubs."""
        self._server = server
        self._address = SERVERS[self._server]
        self._context = zmq.Context()
        functions = self._call_remote_fn('fn_list')
        for fn_name in functions:
            self.__dict__[fn_name] = self._remote_stub(fn_name)

    def _remote_stub(self, fn_name):
        """Return a callable that invokes the remote function `fn_name`."""
        stub = partial(self._call_remote_fn, fn_name)
        stub.__name__ = fn_name
        return stub

    def _send(self, socket, data):
        """Encode and send `data`, retrying while ZeroMQ is not ready."""
        zmq_send(socket, data, warn_on_retry=self._debug)

    def _recv(self, socket):
        """Block for one reply and decode it (raises on an error reply)."""
        msg = socket.recv()
        return decode_msg(msg)

    def _request(self, fn_call):
        """Send one `(name, args, kwargs)` call and return the decoded reply.

        The socket is closed whether the exchange succeeds or raises.
        """
        socket = self._context.socket(zmq.REQ)
        try:
            # Configuring and connecting happen inside the `try` so that a
            # failure there still closes the socket.
            identity = u"rpc-client-{}"\
                .format(random.randint(10000, 99999))\
                .encode('ascii')
            socket.setsockopt(zmq.IDENTITY, identity)
            socket.connect(self._address)
            self._send(socket, fn_call)
            return self._recv(socket)
        finally:
            socket.close()

    def _call_remote_fn(self, fn_name, *args, **kwargs):
        """Call the remote function `fn_name` with the given arguments."""
        fn_call = (fn_name, args, kwargs)
        return self._request(fn_call)

    def __getattr__(self, attr):
        """Resolve a missing attribute to a stub for the remote function.

        Python only calls this after normal lookup has failed, so names
        already stored on the instance never reach the stub path.

        Raises:
            AttributeError: For dunder names. Protocol probes such as
                `__deepcopy__` or `__getstate__` must see "not present"
                rather than a stub that would call the server.
        """
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        # Only reachable when `__getattr__` is invoked directly.
        if attr in self.__dict__:
            return self.__dict__[attr]
        return self._remote_stub(attr)
def zmq_send(socket, msg, warn_on_retry=False):
    """Encode `msg` and send it without blocking, backing off on failure.

    The message is encoded once. Each time the non-blocking send raises
    `zmq.error.Again`, the function sleeps for `CLIENT_RETRY_DELAY` times the
    number of failures so far and tries again.

    Args:
        socket: The ZeroMQ socket to send on.
        msg: The object to encode and send.
        warn_on_retry: When true, print a warning before each retry sleep.

    Raises:
        RPC_FailedToRetry: After more than `NUM_CLIENT_RETRIES` failures.
    """
    payload = encode_msg(msg)
    failures = 0
    while True:
        try:
            socket.send(payload, zmq.NOBLOCK)
        except zmq.error.Again:
            failures += 1
            if failures > NUM_CLIENT_RETRIES:
                raise RPC_FailedToRetry(
                    "PyZMQ send raised `zmq.error.Again` more than %s times. "
                    "You may be trying to send too many messages at once."
                    % NUM_CLIENT_RETRIES
                )
            pause = CLIENT_RETRY_DELAY * failures
            if warn_on_retry:
                print(
                    "Warning: ZMQ failed to send. Retrying after %sms..."
                    % int(pause * 1000)
                )
            time.sleep(pause)
            continue
        # The send went through.
        return
def worker_routine(identity, functions, backend_address):
    """Run one worker: announce readiness, then serve requests until shutdown.

    Args:
        identity: Value used to build the socket identity
            `rpc-worker-<identity>`.
        functions: Mapping of function name to callable, handed to
            `dispatch` for every request.
        backend_address: Address of the server's backend socket.
    """
    # `RPC_Server.start` binds its backend on `zmq.Context.instance()` and
    # defaults to an inproc address. inproc endpoints exist only inside the
    # context that bound them, so the worker must use that same shared
    # context; a private `zmq.Context()` never reaches the backend.
    context = zmq.Context.instance()
    socket = context.socket(zmq.REQ)
    try:
        identity = u"rpc-worker-{}".format(identity).encode('ascii')
        socket.setsockopt(zmq.IDENTITY, identity)
        socket.connect(backend_address)
        socket.send(b"READY")
        for client, fn_call in poll_socket(socket, timeout=100):
            dispatch(client, socket, functions, fn_call)
    except zmq.error.ContextTerminated:
        # The server terminated the shared context: a normal shutdown.
        pass
    finally:
        # Context termination waits for every socket to be closed; drop any
        # unsent data so the worker does not hold the shutdown up.
        socket.close(linger=0)
def poll_socket(socket, timeout):
    """Yield `(client, data)` for each request that arrives on `socket`.

    This generator never finishes by itself. It polls with the given timeout
    and simply polls again when nothing is readable. Messages that decode to
    `None` are not yielded.

    Args:
        socket: The ZeroMQ socket to read three-part messages from.
        timeout: Poll timeout, passed straight to `Poller.poll`.
    """
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)
    while True:
        ready = dict(poller.poll(timeout))
        if ready.get(socket) != zmq.POLLIN:
            continue
        client, _delimiter, raw = socket.recv_multipart()
        payload = decode_msg(raw)
        if payload is not None:
            yield client, payload
def dispatch(client, socket, functions, fn_call):
    """Run the requested function and send its result back for `client`.

    The reserved name `fn_list` answers with the list of registered function
    names. A name missing from `functions` is answered with an
    `RPC_UnknownFunction` instance in place of a result.

    Args:
        client: Identity frame of the requesting client.
        socket: Socket the three-part reply is sent on.
        functions: Mapping of function name to callable.
        fn_call: A `(fn_name, args, kwargs)` triple.
    """
    fn_name, args, kwargs = fn_call
    if fn_name == 'fn_list':
        reply = list(functions)
    elif fn_name in functions:
        reply = execute_rpc_fn(functions[fn_name], args, kwargs)
    else:
        reply = RPC_UnknownFunction(
            "The function %s does not exist on the server. "
            "Use `RPC_Server.expose()` to register a function."
            % fn_name
        )
    # Reply layout: client identity, empty delimiter, encoded result.
    socket.send(client, zmq.SNDMORE)
    socket.send(b"", zmq.SNDMORE)
    zmq_send(socket, reply)
def execute_rpc_fn(rpc_fn, args, kwargs):
    """Call `rpc_fn(*args, **kwargs)` and return its result.

    If the call raises an `Exception`, nothing propagates; an
    `RPC_FunctionFailed` is returned instead. It holds the exception type as
    a string, the message as a string, and the formatted traceback lines.
    """
    try:
        return rpc_fn(*args, **kwargs)
    except Exception as error:
        return RPC_FunctionFailed(
            str(type(error)), str(error), format_tb(error.__traceback__)
        )
class RPC_Server():
    """RPC server: a ROUTER/ROUTER broker in front of a pool of worker threads.

    Functions are registered with `expose`. `start` binds a frontend socket
    for clients and a backend socket for workers, launches the workers, and
    then relays messages between the two until interrupted.
    """

    def __init__(self, server_name, debug=False):
        """Create a server for the address registered under `server_name`.

        Args:
            server_name: Name of the server; must be a key of `SERVERS`.
            debug: Stored on the instance; not read by the methods below.

        Raises:
            RPC_ServerError: If `server_name` is not in `SERVERS`.
        """
        # Validate before the lookup, otherwise an unknown name surfaces as
        # a bare KeyError and this error is never raised.
        if server_name not in SERVERS:
            raise RPC_ServerError(
                "Server <%s> is not in the SERVERS dictionary. Add it in "
                "zmq_rpc.py so that clients can find the server by name."
                % server_name
            )
        self._server_name = server_name
        self._debug = debug
        self._frontend_address = SERVERS[server_name]
        # An alternative TCP backend, b"tcp://127.0.0.1:8001", was left
        # commented out here by the author.
        self._backend_address = b"inproc://backend"
        self._functions = {}

    def expose(self, fn):
        """Register `fn` under its `__name__`; usable as a decorator.

        Returns:
            `fn` itself, unchanged.
        """
        self._functions[fn.__name__] = fn
        return fn

    def _start_workers(self, num_workers):
        """Start `num_workers` daemon threads running `worker_routine`."""
        for identity in range(num_workers):
            worker = Thread(
                target=worker_routine,
                args=(
                    identity,
                    self._functions,
                    self._backend_address
                )
            )
            worker.daemon = True
            worker.start()

    def _proxy_loop(self, frontend, backend):
        """Relay messages between clients and workers.

        The loop has no exit condition; it ends only when an exception
        (such as `KeyboardInterrupt`) propagates out of it.
        """
        workers = []
        poller = zmq.Poller()
        # Only poll for requests from backend until workers are available
        poller.register(backend, zmq.POLLIN)
        while True:
            # Timeouts below 50ms seem to poll faster than messages can be sent.
            # We have to have a timeout in order to Ctrl+C on Windows.
            sockets = dict(poller.poll(timeout=50))
            if backend in sockets:
                # Handle worker activity on the backend
                request = backend.recv_multipart()
                worker, empty, client = request[:3]
                if not workers:
                    # Poll for clients now that a worker is available
                    poller.register(frontend, zmq.POLLIN)
                workers.append(worker)
                if client != b"READY" and len(request) > 3:
                    # If client reply, send rest back to frontend
                    empty, reply = request[3:]
                    frontend.send_multipart([client, b"", reply])
            if frontend in sockets:
                # Get next client request and route it to the worker that
                # has been idle the longest (front of the list).
                client, empty, request = frontend.recv_multipart()
                worker = workers.pop(0)
                backend.send_multipart([worker, b"", client, b"", request])
                if not workers:
                    # Don't poll clients if no workers are available
                    poller.unregister(frontend)

    def start(self, num_workers=4):
        """Bind the sockets, launch the workers and run the relay loop.

        Blocks until `KeyboardInterrupt`; the sockets are closed and the
        context terminated on the way out.

        Args:
            num_workers: Number of worker threads to launch.
        """
        # Prepare our context and sockets
        context = zmq.Context.instance()
        # Socket to talk to clients
        frontend = context.socket(zmq.ROUTER)
        print(self._frontend_address)
        frontend.bind(self._frontend_address)
        # Socket to talk to workers
        backend = context.socket(zmq.ROUTER)
        backend.bind(self._backend_address)
        # Launch pool of worker threads
        self._start_workers(num_workers)
        print(" * Server %s listening at: %s"
              % (self._server_name, self._frontend_address))
        print(" - Press <Ctrl+C> to halt the server..")
        try:
            self._proxy_loop(frontend, backend)
        except KeyboardInterrupt:
            print(" * Halting..")
        finally:
            frontend.close()
            backend.close()
            context.term()
if __name__ == '__main__':
    # Smoke test: run a server and a client in two threads of this process
    # and check a single round trip.
    rpc_server = RPC_Server('test.server')

    @rpc_server.expose
    def add_one(x):
        return x + 1

    def start_rpc_server(rpc_server):
        rpc_server.start()

    def start_client():
        rpc_client = RPC_Client('test.server')
        assert rpc_client.add_one(1) == 2
        print("add_one(1) == 2")

    server_thread = Thread(
        target=start_rpc_server, args=(rpc_server,), daemon=True
    )
    server_thread.start()

    client_thread = Thread(target=start_client, daemon=True)
    client_thread.start()
    # Only the client is awaited; the daemon server thread goes away when
    # the process exits.
    client_thread.join()
@chancyk
Copy link
Author

chancyk commented May 22, 2020

Hi @wamdam! I've made it MIT licensed with the intent that you can use it however you wish. This was just an experiment to understand PyZMQ on Windows, so it hasn't been tested much, was written many years ago, and was never used in any production capacity. I hope it's useful to you in some capacity though!

@wamdam
Copy link

wamdam commented May 24, 2020

Thank you @chancyk!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment