Skip to content

Instantly share code, notes, and snippets.

@chancyk
Last active May 24, 2020 06:39
Show Gist options
  • Save chancyk/1d0f331e3bb73d4b322e6049e22fbb45 to your computer and use it in GitHub Desktop.
Save chancyk/1d0f331e3bb73d4b322e6049e22fbb45 to your computer and use it in GitHub Desktop.
PyZMQ RPC
"""
LICENSE: MIT
This is an experiment. Use at your own risk.
"""
import sys
import time
import zlib
import pickle
import random
from traceback import format_tb
from functools import partial
from threading import Thread
from multiprocessing import Process
import zmq
# Minimal name -> address registry (basically a dumb DNS). `RPC_Client` and
# `RPC_Server` both look a server up here by name to find the endpoint they
# should connect to or bind.
SERVERS = {
    'test.server': 'tcp://127.0.0.1:8000',
}

# Maximum number of times a send is retried after ZeroMQ raises
# `zmq.error.Again`. The pause before each retry is `CLIENT_RETRY_DELAY`
# multiplied by the number of retries made so far.
NUM_CLIENT_RETRIES = 10

# Base pause, in seconds, handed to `time.sleep` when a send fails with
# `zmq.error.Again`. 200ms seems to give ZeroMQ enough time to catch up while
# keeping the number of retries low. Worth re-measuring if performance ever
# becomes critical.
CLIENT_RETRY_DELAY = 0.2
class RPC_Exception(Exception):
    """Base class of every exception defined by this RPC module."""


class RPC_ServerError(RPC_Exception):
    """An exception raised by the RPC server itself (not by an RPC function)."""


class RPC_FunctionFailed(RPC_Exception):
    """The RPC function called raised an exception.

    Instances are built on the server with three arguments: the exception
    type as a string, the exception value as a string, and the formatted
    traceback lines.
    """


class RPC_UnknownFunction(RPC_Exception):
    """The requested function was not exposed to the server."""


class RPC_FailedToRetry(RPC_Exception):
    """A ZMQ send kept failing with `zmq.error.Again` after all retries.

    The number of retries allowed is set by `NUM_CLIENT_RETRIES`.
    """
def maybe_raise_exception(result):
    """Raise if *result* is an exception object sent back over the wire.

    Does nothing (returns ``None``) when *result* is ordinary data.

    Raises:
        RPC_Exception: when *result* is an `RPC_FunctionFailed`. The remote
            exception type, value and traceback carried in its ``args`` are
            printed to stdout first.
        Exception: when *result* is any other exception instance; the
            instance itself is re-raised so that its concrete type (for
            example `RPC_UnknownFunction`) stays catchable by the caller.
    """
    if isinstance(result, RPC_FunctionFailed):
        exc_type, exc_val, exc_tb = result.args
        hr = '-' * 100
        print('\n%s\n' % hr)
        print(' * RPC Function Failed:\n\n')
        print(''.join(exc_tb))
        print('')
        print(
            " * The remote function raised the following exception:\n\n "
            "%s: %s\n%s\n\n" % (exc_type, exc_val, hr)
        )
        raise RPC_Exception("An exception was raised by the RPC function.")
    if isinstance(result, Exception):
        # Re-raise the original instance rather than wrapping its text in a
        # bare `Exception`, which would throw away the exception type.
        raise result
def serialize(data):
    """Pickle *data* and return the resulting bytes."""
    pickled = pickle.dumps(data)
    return pickled
def deserialize(data):
    """Rebuild the Python object stored in the pickle bytes *data*.

    NOTE(security): this is `pickle.loads`, and in this module it is applied
    to bytes read from ZeroMQ sockets. Unpickling data from an untrusted peer
    can execute arbitrary code, so only use this between trusted endpoints.
    """
    restored = pickle.loads(data)
    return restored
def compress(data):
    """Return *data* (bytes) compressed with zlib at its default level."""
    packed = zlib.compress(data)
    return packed
def decompress(data):
    """Return the original bytes for the zlib-compressed *data*."""
    unpacked = zlib.decompress(data)
    return unpacked
def encode_msg(data):
    """Turn *data* into wire bytes: serialize it first, then compress it."""
    return compress(serialize(data))
def decode_msg(msg):
    """Turn wire bytes back into data: decompress, then deserialize.

    The decoded object is handed to `maybe_raise_exception` before being
    returned, so a message that carries an exception raises here instead of
    being returned to the caller.
    """
    raw = decompress(msg)
    payload = deserialize(raw)
    maybe_raise_exception(payload)
    return payload
class RPC_Client():
    """Client that calls functions exposed by an RPC server.

    The server is looked up by name in `SERVERS`. On construction the client
    asks the server for its function list (``'fn_list'``) and installs one
    callable stub per function on the instance. Any other public attribute
    that is not found is also turned into a remote-call stub, so
    ``client.some_fn(1, x=2)`` sends ``('some_fn', (1,), {'x': 2})``.

    Each request uses its own short-lived REQ socket, which is closed when
    the request finishes or fails.
    """

    def __init__(self, server, debug=False):
        """Connect to *server* (a key of `SERVERS`).

        *debug* enables the retry warnings printed by `zmq_send`.
        Raises ``KeyError`` if *server* is not a known name.
        """
        self._context = None
        self._socket = None
        self._server = None
        self._address = None
        self._debug = debug
        self._init_socket(server)

    def _init_socket(self, server):
        """Resolve the address, create the context and install the stubs."""
        self._server = server
        self._address = SERVERS[self._server]
        self._context = zmq.Context()
        for fn_name in self._call_remote_fn('fn_list'):
            self.__dict__[fn_name] = self._make_stub(fn_name)

    def _make_stub(self, fn_name):
        """Return a callable that invokes the remote function *fn_name*."""
        stub = partial(self._call_remote_fn, fn_name)
        stub.__name__ = fn_name
        return stub

    def _send(self, socket, data):
        """Encode *data* and send it on *socket*, retrying on `Again`."""
        zmq_send(socket, data, warn_on_retry=self._debug)

    def _recv(self, socket):
        """Block until a reply arrives on *socket* and return it decoded.

        `decode_msg` raises if the reply carries an exception.
        """
        msg = socket.recv()
        return decode_msg(msg)

    def _request(self, fn_call):
        """Send one ``(fn_name, args, kwargs)`` request and return the reply."""
        socket = self._context.socket(zmq.REQ)
        try:
            # 64 random bits instead of a 5-digit number: identities must be
            # unique among the peers connected to the server's ROUTER socket,
            # and a new identity is drawn for every request.
            identity = u"rpc-client-{:016x}"\
                .format(random.getrandbits(64))\
                .encode('ascii')
            socket.setsockopt(zmq.IDENTITY, identity)
            socket.connect(self._address)
            self._send(socket, fn_call)
            return self._recv(socket)
        finally:
            # linger=0: once we have the reply (or have given up) nothing
            # queued on this socket is worth waiting for.
            socket.close(linger=0)

    def _call_remote_fn(self, fn_name, *args, **kwargs):
        """Call the remote function *fn_name* with the given arguments."""
        fn_call = (fn_name, args, kwargs)
        return self._request(fn_call)

    def __getattr__(self, attr):
        """Treat unknown public attributes as remote function names.

        Python only calls this after normal lookup has failed. Names that
        start with an underscore raise ``AttributeError`` as usual, so
        protocol probes (``copy``, ``pickle``, ``hasattr``) and access to
        private state that was never set do not silently turn into stubs.
        """
        if attr.startswith('_'):
            raise AttributeError(
                "%r object has no attribute %r" % (type(self).__name__, attr)
            )
        return self._make_stub(attr)
def zmq_send(socket, msg, warn_on_retry=False):
    """Encode *msg* and send it on *socket* without blocking.

    Each time the send raises `zmq.error.Again` the call sleeps for
    `CLIENT_RETRY_DELAY` times the number of failures so far and tries again.
    When *warn_on_retry* is true a warning is printed before each sleep.

    Raises:
        RPC_FailedToRetry: once the send has failed more than
            `NUM_CLIENT_RETRIES` times.
    """
    payload = encode_msg(msg)
    failures = 0
    while True:
        try:
            socket.send(payload, zmq.NOBLOCK)
            return
        except zmq.error.Again:
            failures += 1
            if failures > NUM_CLIENT_RETRIES:
                raise RPC_FailedToRetry(
                    "PyZMQ send raised `zmq.error.Again` more than %s times. "
                    "You may be trying to send too many messages at once."
                    % NUM_CLIENT_RETRIES
                )
            pause = CLIENT_RETRY_DELAY * failures
            if warn_on_retry:
                print(
                    "Warning: ZMQ failed to send. Retrying after %sms..."
                    % int(pause * 1000)
                )
            time.sleep(pause)
def worker_routine(identity, functions, backend_address):
    """Body of one worker thread: serve RPC requests from the backend.

    Connects a REQ socket named ``rpc-worker-<identity>`` to
    *backend_address*, announces itself with ``READY`` and then dispatches
    every request yielded by `poll_socket` to the callables in *functions*.

    The socket is created on the process-wide `zmq.Context.instance()`,
    which is the context `RPC_Server.start` binds the backend on. ``inproc://``
    endpoints are only reachable from sockets of the context that bound
    them, so a private context here could never reach the backend.
    """
    context = zmq.Context.instance()
    socket = context.socket(zmq.REQ)
    try:
        worker_id = u"rpc-worker-{}".format(identity).encode('ascii')
        socket.setsockopt(zmq.IDENTITY, worker_id)
        socket.connect(backend_address)
        socket.send(b"READY")
        for client, fn_call in poll_socket(socket, timeout=100):
            dispatch(client, socket, functions, fn_call)
    except zmq.ContextTerminated:
        # The shared context was terminated (server shutdown); leave quietly.
        pass
    finally:
        # Context termination waits for every socket to be closed, so always
        # release this one, discarding anything still queued.
        socket.close(linger=0)
def poll_socket(socket, timeout):
    """Yield ``(client, data)`` for each request that arrives on *socket*.

    Polls *socket* forever, waiting at most *timeout* milliseconds per poll.
    Each incoming message must have three frames: client identity, an empty
    delimiter and the encoded payload. The payload is decoded with
    `decode_msg`; payloads that decode to ``None`` are skipped.
    """
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)
    while True:
        ready = dict(poller.poll(timeout))
        if ready.get(socket) != zmq.POLLIN:
            # Timed out with nothing to read; poll again.
            continue
        client, _delimiter, msg = socket.recv_multipart()
        data = decode_msg(msg)
        if data is not None:
            yield client, data
def dispatch(client, socket, functions, fn_call):
    """Run the requested function and send the outcome back to *client*.

    *fn_call* is a ``(fn_name, args, kwargs)`` triple. The reserved name
    ``'fn_list'`` answers with the names registered in *functions*. A name
    that is not registered answers with an `RPC_UnknownFunction` instance.
    The reply is sent on *socket* as three frames: the client identity, an
    empty delimiter and the encoded result.
    """
    fn_name, args, kwargs = fn_call
    if fn_name == 'fn_list':
        outcome = list(functions)
    elif fn_name in functions:
        outcome = execute_rpc_fn(functions[fn_name], args, kwargs)
    else:
        outcome = RPC_UnknownFunction(
            "The function %s does not exist on the server. "
            "Use `RPC_Server.expose()` to register a function."
            % fn_name
        )
    for frame in (client, b""):
        socket.send(frame, zmq.SNDMORE)
    zmq_send(socket, outcome)
def execute_rpc_fn(rpc_fn, args, kwargs):
    """Call ``rpc_fn(*args, **kwargs)`` and return its result.

    If the call raises an `Exception`, nothing propagates: an
    `RPC_FunctionFailed` is returned instead, carrying the exception type and
    value as strings plus the formatted traceback lines.
    """
    try:
        return rpc_fn(*args, **kwargs)
    except Exception as exc:
        return RPC_FunctionFailed(
            str(type(exc)), str(exc), format_tb(exc.__traceback__)
        )
class RPC_Server():
    """RPC server that load-balances client requests over worker threads.

    Clients talk to a ROUTER socket bound at the address registered for the
    server name in `SERVERS` (the frontend). Worker threads talk to a second
    ROUTER socket (the backend). `_proxy_loop` passes each client request to
    an idle worker and each worker reply back to its client.
    """

    def __init__(self, server_name, debug=False):
        """Create a server for *server_name*.

        Raises:
            RPC_ServerError: if *server_name* is not a key of `SERVERS`.
        """
        # Validate before indexing SERVERS so that an unknown name produces
        # the explanatory error instead of a bare KeyError.
        if server_name not in SERVERS:
            raise RPC_ServerError(
                "Server <%s> is not in the SERVERS dictionary. Add it in "
                "zmq_rpc.py so that clients can find the server by name."
                % server_name
            )
        self._server_name = server_name
        self._debug = debug
        self._frontend_address = SERVERS[server_name]
        self._backend_address = b"inproc://backend"
        self._functions = {}

    def expose(self, fn):
        """Register *fn* under its ``__name__``; usable as a decorator."""
        self._functions[fn.__name__] = fn
        return fn

    def _start_workers(self, num_workers):
        """Start *num_workers* daemon threads running `worker_routine`."""
        for identity in range(num_workers):
            worker = Thread(
                target=worker_routine,
                args=(
                    identity,
                    self._functions,
                    self._backend_address
                )
            )
            worker.daemon = True
            worker.start()

    def _proxy_loop(self, frontend, backend):
        """Shuttle messages between clients and workers until interrupted.

        Runs forever; it only exits through an exception such as
        ``KeyboardInterrupt``.
        """
        workers = []
        poller = zmq.Poller()
        # Only poll for requests from backend until workers are available
        poller.register(backend, zmq.POLLIN)
        while True:
            # Timeouts below 50ms seem to poll faster than messages can be
            # sent. We have to have a timeout in order to Ctrl+C on Windows.
            sockets = dict(poller.poll(timeout=50))
            if backend in sockets:
                # Handle worker activity on the backend
                request = backend.recv_multipart()
                worker, empty, client = request[:3]
                if not workers:
                    # Poll for clients now that a worker is available
                    poller.register(frontend, zmq.POLLIN)
                workers.append(worker)
                if client != b"READY" and len(request) > 3:
                    # If client reply, send rest back to frontend
                    empty, reply = request[3:]
                    frontend.send_multipart([client, b"", reply])
            if frontend in sockets:
                # Get next client request, route it to the worker that has
                # been idle the longest (front of the list).
                client, empty, request = frontend.recv_multipart()
                worker = workers.pop(0)
                backend.send_multipart([worker, b"", client, b"", request])
                if not workers:
                    # Don't poll clients if no workers are available
                    poller.unregister(frontend)

    def start(self, num_workers=4):
        """Bind both sockets, start the workers and serve until Ctrl+C.

        Both sockets are closed and the context terminated on the way out,
        including when binding fails.
        """
        # Prepare our context and sockets
        context = zmq.Context.instance()
        # Socket to talk to clients
        frontend = context.socket(zmq.ROUTER)
        # Socket to talk to workers
        backend = context.socket(zmq.ROUTER)
        try:
            print(self._frontend_address)
            frontend.bind(self._frontend_address)
            backend.bind(self._backend_address)
            # Launch pool of worker threads
            self._start_workers(num_workers)
            print(" * Server %s listening at: %s"
                  % (self._server_name, self._frontend_address))
            print(" - Press <Ctrl+C> to halt the server..")
            self._proxy_loop(frontend, backend)
        except KeyboardInterrupt:
            print(" * Halting..")
        finally:
            frontend.close()
            backend.close()
            context.term()
if __name__ == '__main__':
    # Smoke test: run a server and a client in the same process and check
    # one round trip through an exposed function.
    rpc_server = RPC_Server('test.server')

    @rpc_server.expose
    def add_one(x):
        return x + 1

    def start_rpc_server(server):
        server.start()

    def start_client():
        rpc_client = RPC_Client('test.server')
        assert rpc_client.add_one(1) == 2
        print("add_one(1) == 2")

    # Both threads are daemons so the process can exit once the client is
    # done, even though the server loop never returns by itself.
    server_thread = Thread(
        target=start_rpc_server, args=(rpc_server,), daemon=True
    )
    client_thread = Thread(target=start_client, daemon=True)

    server_thread.start()
    client_thread.start()
    client_thread.join()
@wamdam
Copy link

wamdam commented May 24, 2020

Thank you @chancyk!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment