Last active
May 24, 2020 06:39
-
-
Save chancyk/1d0f331e3bb73d4b322e6049e22fbb45 to your computer and use it in GitHub Desktop.
PyZMQ RPC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
LICENSE: MIT | |
This is an experiment. Use at your own risk. | |
""" | |
import sys | |
import time | |
import zlib | |
import pickle | |
import random | |
from traceback import format_tb | |
from functools import partial | |
from threading import Thread | |
from multiprocessing import Process | |
import zmq | |
# Known server names mapped to the ZeroMQ endpoint they listen on.
# RPC_Client looks the address up here by name and RPC_Server binds its
# frontend to the same entry, so both sides must agree. This is essentially a
# minimal, static name service ("a dumb DNS").
SERVERS = {
    'test.server': 'tcp://127.0.0.1:8000',
}

# How many times `zmq_send` retries a send that raised `zmq.error.Again`
# before giving up with `RPC_FailedToRetry`, i.e. at most
# NUM_CLIENT_RETRIES + 1 send attempts in total. Despite the name, `zmq_send`
# is used both by the client and by the server workers (see `dispatch`).
NUM_CLIENT_RETRIES = 10

# Base delay in seconds between retries. The n-th retry sleeps
# `CLIENT_RETRY_DELAY * n`, so with the values above a send may wait up to
# 0.2 * (1 + 2 + ... + 10) = 11 seconds in total before failing.
# 200 ms was found to give ZeroMQ time to catch up while keeping the number
# of retries low; re-measure if performance becomes critical.
CLIENT_RETRY_DELAY = 0.2
class RPC_Exception(Exception):
    """Common base class for the RPC-specific exceptions of this module.

    Catch it to handle any of them without naming the concrete subclass.
    """
class RPC_ServerError(RPC_Exception):
    """Error originating from the RPC server itself (e.g. misconfiguration),
    as opposed to a failure inside a remote function."""
class RPC_FunctionFailed(RPC_Exception):
    """A remote function raised while the server was executing it.

    The server builds it as ``RPC_FunctionFailed(exc_type, exc_value, tb)``,
    where the first two are strings and ``tb`` is the list of formatted
    traceback lines, so the client can unpack ``args`` to report the failure.
    """
class RPC_UnknownFunction(RPC_Exception):
    """A client asked for a function name that was never exposed on the server."""
class RPC_FailedToRetry(RPC_Exception):
    """Raised by `zmq_send` when a send keeps failing with `zmq.error.Again`.

    It is raised once the send has been retried more than
    `NUM_CLIENT_RETRIES` times. Only sends are retried; receives are not.
    """
def maybe_raise_exception(result):
    """Raise on the client if the server replied with an exception object.

    The server sends exception instances back as ordinary reply payloads
    (see `dispatch` / `execute_rpc_fn`). This turns them back into errors:

    * ``RPC_FunctionFailed``: the remote traceback and exception are printed
      and an ``RPC_Exception`` is raised.
    * any other ``Exception`` instance (e.g. ``RPC_UnknownFunction``): it is
      re-raised as-is, so callers can catch the specific RPC exception type.

    Any non-exception ``result`` is left alone and ``None`` is returned.
    """
    if isinstance(result, RPC_FunctionFailed):
        exc_type, exc_val, exc_tb = result.args
        hr = '-' * 100
        print('\n%s\n' % hr)
        print(' * RPC Function Failed:\n\n')
        print(''.join(exc_tb))
        print('')
        print(
            " * The remote function raised the following exception:\n\n "
            "%s: %s\n%s\n\n" % (exc_type, exc_val, hr)
        )
        raise RPC_Exception("An exception was raised by the RPC function.")
    elif isinstance(result, Exception):
        # Re-raise the received instance instead of wrapping it in a bare
        # Exception: str(error) is unchanged, but the concrete type (e.g.
        # RPC_UnknownFunction) survives for callers' `except` clauses.
        raise result
def serialize(data):
    """Pickle *data* into a bytes payload for transport."""
    payload = pickle.dumps(data)
    return payload
def deserialize(data):
    """Rebuild the Python object pickled in *data*.

    Security: ``pickle.loads`` can execute arbitrary code while loading, and
    in this module it is applied to bytes received over the network (via
    ``decode_msg``). Only let trusted peers connect to these endpoints.
    """
    obj = pickle.loads(data)
    return obj
def compress(data):
    """Deflate the bytes *data* with zlib at its default compression level."""
    packed = zlib.compress(data)
    return packed
def decompress(data):
    """Inflate zlib-compressed bytes back to the original payload."""
    unpacked = zlib.decompress(data)
    return unpacked
def encode_msg(data):
    """Convert a Python object to the wire format: pickle first, then zlib."""
    return compress(serialize(data))
def decode_msg(msg):
    """Inverse of `encode_msg`: inflate, unpickle, then check for errors.

    If the decoded object is an exception sent by the server, it is turned
    into a raised error by `maybe_raise_exception`; otherwise the decoded
    object is returned.
    """
    raw = decompress(msg)
    data = deserialize(raw)
    maybe_raise_exception(data)
    return data
class RPC_Client():
    """Client proxy exposing a named RPC server's functions as attributes.

    On construction it asks the server for its exposed function names
    (``fn_list``) and binds one callable attribute per name. Other attribute
    names are forwarded to the server when called, so functions exposed
    later can still be reached. Every call opens a fresh REQ socket, sends
    ``(fn_name, args, kwargs)`` and blocks until the reply arrives.
    """

    def __init__(self, server, debug=False):
        """Connect to the server registered as *server*.

        :param server: key of ``SERVERS`` giving the endpoint to connect to.
        :param debug: when true, print a warning each time a send is retried.
        :raises KeyError: if *server* is not a key of ``SERVERS``.
        """
        self._context = None
        self._socket = None
        self._server = None
        self._debug = debug
        self._init_socket(server)

    def _init_socket(self, server):
        """Resolve the address, create the context and bind remote functions."""
        self._server = server
        self._address = SERVERS[self._server]
        self._context = zmq.Context()
        functions = self._call_remote_fn('fn_list')
        for fn_name in functions:
            self.__dict__[fn_name] = self._make_remote_fn(fn_name)

    def _make_remote_fn(self, fn_name):
        """Return a callable that invokes *fn_name* on the server."""
        remote_fn = partial(self._call_remote_fn, fn_name)
        remote_fn.__name__ = fn_name
        return remote_fn

    def _send(self, socket, data):
        zmq_send(socket, data, warn_on_retry=self._debug)

    def _recv(self, socket):
        # Blocking recv: there is no receive timeout.
        msg = socket.recv()
        return decode_msg(msg)

    def _request(self, fn_call):
        """Send one call over a short-lived REQ socket and return the reply."""
        socket = self._context.socket(zmq.REQ)
        # The server's frontend ROUTER routes the reply by this identity and
        # will not accept a second peer using an identity that is already
        # taken, so a collision can stall or fail the request. 64 random bits
        # make collisions practically impossible (a 5-digit range did not).
        identity = u"rpc-client-{:016x}".format(
            random.getrandbits(64)).encode('ascii')
        socket.setsockopt(zmq.IDENTITY, identity)
        socket.connect(self._address)
        try:
            self._send(socket, fn_call)
            return self._recv(socket)
        finally:
            socket.close()

    def _call_remote_fn(self, fn_name, *args, **kwargs):
        fn_call = (fn_name, args, kwargs)
        return self._request(fn_call)

    def __getattr__(self, attr):
        """Forward unknown attribute lookups to the server as remote calls.

        Python only calls this when normal lookup fails, i.e. *attr* is
        neither an instance attribute nor defined on the class. Dunder names
        are refused so that protocol probes (``copy.deepcopy`` looking for
        ``__deepcopy__``, ``hasattr`` checks, ...) get ``AttributeError``
        instead of turning into network calls.
        """
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        return self._make_remote_fn(attr)
def zmq_send(socket, msg, warn_on_retry=False):
    """Encode *msg* and push it through *socket* without blocking.

    A send that hits ``zmq.error.Again`` is retried after a linearly growing
    pause (``CLIENT_RETRY_DELAY`` times the retry number). Once more than
    ``NUM_CLIENT_RETRIES`` retries have been needed, ``RPC_FailedToRetry`` is
    raised. With *warn_on_retry* set, each pause is announced on stdout.
    """
    payload = encode_msg(msg)
    attempt = 0
    while True:
        try:
            socket.send(payload, zmq.NOBLOCK)
        except zmq.error.Again:
            attempt += 1
            if attempt > NUM_CLIENT_RETRIES:
                raise RPC_FailedToRetry(
                    "PyZMQ send raised `zmq.error.Again` more than %s times. "
                    "You may be trying to send too many messages at once."
                    % NUM_CLIENT_RETRIES
                )
            pause = CLIENT_RETRY_DELAY * attempt
            if warn_on_retry:
                print(
                    "Warning: ZMQ failed to send. Retrying after %sms..."
                    % int(pause * 1000)
                )
            time.sleep(pause)
        else:
            return
def worker_routine(identity, functions, backend_address, context=None):
    """Serve RPC calls forwarded by the server's backend ROUTER.

    Connects a REQ socket to *backend_address*, announces itself with
    ``READY`` and then runs every request it is handed through `dispatch`,
    replying on the same socket.

    :param identity: worker number, used to build the socket identity.
    :param functions: mapping of exposed function names to callables.
    :param backend_address: endpoint of the server's backend ROUTER.
    :param context: ZeroMQ context to create the socket in. Defaults to
        ``zmq.Context.instance()``, the context ``RPC_Server.start`` binds its
        backend in. ``inproc://`` endpoints only connect sockets created from
        the same context, so a private ``zmq.Context()`` never reaches it.
    """
    if context is None:
        context = zmq.Context.instance()
    socket = context.socket(zmq.REQ)
    identity = u"rpc-worker-{}".format(identity).encode('ascii')
    socket.setsockopt(zmq.IDENTITY, identity)
    socket.connect(backend_address)
    try:
        socket.send(b"READY")
        for client, fn_call in poll_socket(socket, timeout=100):
            dispatch(client, socket, functions, fn_call)
    except zmq.ZMQError as e:
        # ETERM means the shared context is being terminated (server
        # shutdown): stop quietly. Anything else is a real error.
        if e.errno != zmq.ETERM:
            raise
    finally:
        # Context.term() waits for every socket of the context to be closed,
        # so close ours promptly and without lingering.
        socket.close(linger=0)
def poll_socket(socket, timeout):
    """Yield ``(client, fn_call)`` for each request arriving on *socket*.

    Polls in *timeout*-millisecond slices instead of blocking in ``recv``.
    Each message is received as three frames (client identity, empty
    delimiter, encoded payload); the payload goes through `decode_msg`.
    Decoded payloads that are ``None`` are skipped. The generator never
    finishes on its own.
    """
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)
    while True:
        events = dict(poller.poll(timeout))
        if events.get(socket) != zmq.POLLIN:
            continue  # timed out with nothing to read
        client, _empty, msg = socket.recv_multipart()
        fn_call = decode_msg(msg)
        if fn_call is not None:
            yield client, fn_call
def dispatch(client, socket, functions, fn_call):
    """Run one decoded request and send the reply back towards *client*.

    *fn_call* is ``(fn_name, args, kwargs)``. The reserved name ``fn_list``
    answers with the names of all functions registered via ``expose``. An
    unknown name is answered with an ``RPC_UnknownFunction`` instance (sent,
    not raised). The reply frames are: client identity, empty delimiter,
    then the encoded result via `zmq_send`.
    """
    fn_name, args, kwargs = fn_call
    if fn_name == 'fn_list':
        result = list(functions.keys())
    elif fn_name not in functions:
        result = RPC_UnknownFunction(
            "The function %s does not exist on the server. "
            "Use `RPC_Server.expose()` to register a function."
            % fn_name
        )
    else:
        result = execute_rpc_fn(functions[fn_name], args, kwargs)
    for frame in (client, b""):
        socket.send(frame, zmq.SNDMORE)
    zmq_send(socket, result)
def execute_rpc_fn(rpc_fn, args, kwargs):
    """Call ``rpc_fn(*args, **kwargs)`` and return what it returns.

    Exceptions do not propagate. They are captured and returned as an
    ``RPC_FunctionFailed`` holding the exception type and value as strings
    plus the formatted traceback lines, ready to be sent to the client.
    """
    try:
        return rpc_fn(*args, **kwargs)
    except Exception:
        failed_type, failed_value, failed_tb = sys.exc_info()
        return RPC_FunctionFailed(
            str(failed_type), str(failed_value), format_tb(failed_tb)
        )
class RPC_Server():
    """Named RPC server that load-balances calls over worker threads.

    Clients talk to a ROUTER "frontend" bound at ``SERVERS[server_name]``.
    Each request is handed over a ROUTER "backend" (``inproc://backend``) to
    an idle worker thread running ``worker_routine`` with the functions
    registered through ``expose``.
    """

    def __init__(self, server_name, debug=False):
        """Prepare (but do not start) the server named *server_name*.

        :param server_name: key of ``SERVERS`` giving the address to bind.
        :param debug: stored as ``_debug``; not read elsewhere in this class.
        :raises RPC_ServerError: if *server_name* is not a key of ``SERVERS``.
        """
        # Check membership before indexing SERVERS so an unknown name gets
        # this descriptive error instead of a bare KeyError.
        if server_name not in SERVERS:
            raise RPC_ServerError(
                "Server <%s> is not in the SERVERS dictionary. Add it in "
                "zmq_rpc.py so that clients can find the server by name."
                % server_name
            )
        self._server_name = server_name
        self._debug = debug
        self._frontend_address = SERVERS[server_name]
        self._backend_address = b"inproc://backend"
        self._functions = {}

    def expose(self, fn):
        """Register *fn* under its ``__name__``; usable as a decorator."""
        self._functions[fn.__name__] = fn
        return fn

    def _start_workers(self, num_workers):
        """Start *num_workers* daemon threads running ``worker_routine``."""
        for identity in range(num_workers):
            worker = Thread(
                target=worker_routine,
                args=(identity, self._functions, self._backend_address),
            )
            # Daemon threads do not keep the process alive on their own.
            worker.daemon = True
            worker.start()

    def _proxy_loop(self, frontend, backend):
        """Shuttle requests between clients and idle workers forever.

        Workers report on *backend* (first with ``READY``, later with each
        reply) and are queued as idle. Each client request from *frontend*
        goes to the worker that has been idle the longest. The frontend is
        only polled while at least one worker is idle. Returns only by
        raising (e.g. ``KeyboardInterrupt``).
        """
        # Identities of idle workers, oldest first.
        workers = []
        poller = zmq.Poller()
        # Only poll for requests from backend until workers are available
        poller.register(backend, zmq.POLLIN)
        while True:
            # Timeouts below 50ms seem to poll faster than messages can be sent.
            # We have to have a timeout in order to Ctrl+C on Windows.
            sockets = dict(poller.poll(timeout=50))
            if backend in sockets:
                # Frames: [worker, b"", b"READY"] or
                #         [worker, b"", client, b"", reply]
                request = backend.recv_multipart()
                worker, empty, client = request[:3]
                if not workers:
                    # First idle worker: start accepting client requests.
                    poller.register(frontend, zmq.POLLIN)
                workers.append(worker)
                if client != b"READY" and len(request) > 3:
                    # Forward the worker's reply to the waiting client.
                    empty, reply = request[3:]
                    frontend.send_multipart([client, b"", reply])
            if frontend in sockets:
                # Route to the longest-idle worker (FIFO queue).
                client, empty, request = frontend.recv_multipart()
                worker = workers.pop(0)
                backend.send_multipart([worker, b"", client, b"", request])
                if not workers:
                    # Don't poll clients if no workers are available
                    poller.unregister(frontend)

    def start(self, num_workers=4):
        """Bind both sockets, start the workers and proxy until Ctrl+C.

        :param num_workers: number of worker threads to start.
        """
        context = zmq.Context.instance()
        # Socket to talk to clients
        frontend = context.socket(zmq.ROUTER)
        # Socket to talk to workers
        backend = context.socket(zmq.ROUTER)
        # Binding happens inside the try so a failed bind (e.g. port already
        # in use) still closes the sockets and terminates the context.
        try:
            frontend.bind(self._frontend_address)
            backend.bind(self._backend_address)
            self._start_workers(num_workers)
            print(" * Server %s listening at: %s"
                  % (self._server_name, self._frontend_address))
            print(" - Press <Ctrl+C> to halt the server..")
            self._proxy_loop(frontend, backend)
        except KeyboardInterrupt:
            print(" * Halting..")
        finally:
            # linger=0 drops undeliverable replies so term() cannot hang
            # waiting to flush them.
            frontend.close(linger=0)
            backend.close(linger=0)
            context.term()
if __name__ == '__main__':
    # Self-test: serve `add_one` from a background thread and call it once
    # from a client thread in the same process.
    rpc_server = RPC_Server('test.server')

    @rpc_server.expose
    def add_one(x):
        return x + 1

    def start_client():
        rpc_client = RPC_Client('test.server')
        assert rpc_client.add_one(1) == 2
        print("add_one(1) == 2")

    server_thread = Thread(target=rpc_server.start)
    server_thread.daemon = True
    server_thread.start()

    client_thread = Thread(target=start_client)
    client_thread.daemon = True
    client_thread.start()
    # Only the client is joined; the daemon server thread ends with the process.
    client_thread.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @wamdam! I've made it MIT licensed with the intent that you can use it however you wish. This was just an experiment to understand PyZMQ on Windows, so it hasn't been tested much, was written many years ago, and was never used in any production capacity. I hope it's useful to you in some capacity though!