Skip to content

Instantly share code, notes, and snippets.

@akhilman
Created August 29, 2018 16:16
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save akhilman/6f4aa516a7317f36ebac427a9d392865 to your computer and use it in GitHub Desktop.
Save akhilman/6f4aa516a7317f36ebac427a9d392865 to your computer and use it in GitHub Desktop.
Asynchronous (future-based) aiozmq RPC client that works without an event loop.
"""
Based on a synchronous implementation of aiozmq.rpc.RPCClient:
https://gist.github.com/derfenix/f18e4a8f0ee9bad738c2b22106a3ad4d
"""
import functools
import logging
import os
import random
import struct
import sys
import time
from collections import ChainMap
from concurrent.futures import Future
from functools import partial
import zmq
from aiozmq.rpc.base import GenericError
from aiozmq.rpc.rpc import _default_error_table
from slivoglot.core.packer import PicklePacker
__all__ = ['RPCClient']
@functools.lru_cache()
def log():
    """Return this module's logger.

    The lookup is memoised, so every call after the first hands back the
    same ``logging.Logger`` object without going through ``getLogger``.
    """
    module_logger = logging.getLogger(__name__)
    return module_logger
class RPCFuture(Future):
    """Future for a single outstanding RPC request, resolved by polling.

    Nothing completes this future in the background.  ``result()`` and
    ``exception()`` make the owning client poll its socket until the reply
    for this request has been dispatched (or the client gives up).

    :param client: the object that sent the request; it must provide
        ``poll_until_done(future, timeout)``.
    :param req_id: the request counter identifying this call.
    """

    def __init__(self, client, req_id):
        super().__init__()
        self.client = client
        self.req_id = req_id

    def _pump(self, timeout):
        # Only drive the client's socket while our reply is still missing.
        if not self.done():
            self.client.poll_until_done(self, timeout)
        assert self.done()

    def result(self, timeout=None):
        """Return the call's result, polling for up to *timeout* seconds."""
        self._pump(timeout)
        return super().result()

    def exception(self, timeout=None):
        """Return the call's exception (or None), polling like ``result``."""
        self._pump(timeout)
        return super().exception()
class RPCClient(object):
    """Event-loop-free client for an aiozmq RPC server.

    Requests go out over a ZeroMQ DEALER socket.  Every call returns an
    :class:`RPCFuture`.  Replies are only read from the socket while
    :meth:`poll` runs, which happens automatically when a future's
    ``result()`` or ``exception()`` is requested.

    Unknown attribute names become remote calls, so ``client.add(1, 2)``
    is the same as ``client.call('add', 1, 2)``.
    """

    # Request header prefix: (pid % 2**16, random 16-bit salt).
    REQ_PREFIX = struct.Struct('=HH')
    # Request header suffix: (32-bit request counter, send timestamp).
    REQ_SUFFIX = struct.Struct('=Ld')
    # Reply header: (pid, salt, request counter, timestamp, is_error flag).
    RESP = struct.Struct('=HHLd?')

    def __init__(self, *, connect, timeout=None, error_table=None):
        """Create a DEALER socket connected to *connect*.

        :param connect: ZeroMQ endpoint passed to ``socket.connect``.
        :param timeout: default number of seconds :meth:`poll_until_done`
            (and therefore ``RPCFuture.result()``) waits when no explicit
            timeout is given; ``None`` waits forever.
        :param error_table: optional mapping from the exception-type
            identifier sent by the server to a local exception class,
            layered over aiozmq's default table.
        """
        self.timeout = timeout
        self.uri = connect
        # Identifies this client instance inside every request header.
        self.prefix = self.REQ_PREFIX.pack(os.getpid() % 0x10000,
                                           random.randrange(0x10000))
        # NOTE(review): PicklePacker presumably unpickles replies; only
        # connect to trusted servers.
        self.packer = PicklePacker()
        context = zmq.Context()
        self.socket = context.socket(zmq.DEALER)
        self.socket.connect(connect)
        # Outstanding requests: request counter -> RPCFuture.
        self.calls = {}
        self._counter = 0
        if error_table is None:
            self.error_table = _default_error_table
        else:
            self.error_table = ChainMap(error_table, _default_error_table)

    def __del__(self):
        # Read the instance dict directly.  If __init__ failed before the
        # socket was created, ``self.socket`` would fall through to
        # __getattr__ and yield an RPC stub instead of raising.
        socket = self.__dict__.get('socket')
        if socket is not None:
            socket.close()
            self.socket = None

    def _new_id(self):
        """Return ``(header_bytes, counter)`` for the next request.

        The counter wraps to 0 after 0xffffffff so that it always fits the
        unsigned 32-bit ``L`` field of REQ_SUFFIX.
        """
        self._counter += 1
        if self._counter > 0xffffffff:
            self._counter = 0
        return (self.prefix + self.REQ_SUFFIX.pack(self._counter, time.time()),
                self._counter)

    def __getattr__(self, item):
        """Turn an unknown attribute name into a remote-call stub.

        Python only calls this after normal lookup has failed.  Dunder names
        are refused, so protocol probes such as ``hasattr(obj, '__x__')`` or
        copy/pickle helpers get the AttributeError they expect instead of
        a callable that would send an RPC request.
        """
        if item.startswith('__') and item.endswith('__'):
            raise AttributeError('%r object has no attribute %r'
                                 % (type(self).__name__, item))
        return partial(self.call, item)

    def __call__(self, name, *args, **kwargs):
        """Alias for :meth:`call`."""
        return self.call(name, *args, **kwargs)

    def call(self, name: str, *args, **kwargs):
        """Send ``name(*args, **kwargs)`` and return an RPCFuture for it.

        The future is put in the RUNNING state before it is returned, so
        ``cancel()`` on it returns False.
        """
        binary_name = name.encode('utf-8')
        binary_args = self.packer.packb(args)
        binary_kwargs = self.packer.packb(kwargs)
        header, req_id = self._new_id()
        self.socket.send_multipart(
            [header, binary_name, binary_args, binary_kwargs])
        fut = RPCFuture(self, req_id)
        fut.set_running_or_notify_cancel()
        self.calls[req_id] = fut
        return fut

    def poll(self, timeout=None):
        """Wait up to *timeout* seconds for replies and dispatch them.

        :param timeout: seconds to wait for the first reply.  ``None``
            blocks until one arrives; ``0`` (or a negative value) only
            checks what is already queued.
        :returns: the number of reply messages processed.  This is 0 when
            nothing arrived or no requests are outstanding.
        """
        if not self.calls:
            return 0
        if timeout is not None:
            # zmq takes milliseconds, and -1 there means "wait forever",
            # so negative values are clamped to a non-blocking check.
            timeout = max(0, int(timeout * 1000))
        if not self.socket.poll(timeout=timeout):
            return 0
        received = 0
        # Drain every reply that is already queued, not just the first.
        while True:
            self.msg_received(self.socket.recv_multipart())
            received += 1
            if not self.socket.poll(timeout=0):
                break
        return received

    def poll_until_done(self, fut, timeout=None):
        """Poll the socket until *fut* is done.

        :param fut: a future created by :meth:`call` on this client.
        :param timeout: maximum number of seconds to wait.  ``None`` falls
            back to the client-wide ``self.timeout`` (where ``None`` means
            forever).  ``0`` performs a single non-blocking check.
        :raises TimeoutError: if *fut* is still pending when time runs out.
        """
        if timeout is None:
            timeout = self.timeout
        if timeout is None:
            while not fut.done():
                self.poll(None)
            return
        # monotonic() is unaffected by wall-clock adjustments mid-wait.
        deadline = time.monotonic() + timeout
        while not fut.done():
            self.poll(max(0.0, deadline - time.monotonic()))
            if not fut.done() and time.monotonic() >= deadline:
                raise TimeoutError('Timeout')

    def _translate_error(self, exc_type, exc_args, exc_repr):
        """Map a remote error to a local exception instance.

        Types found in ``error_table`` are instantiated with *exc_args*.
        Anything else becomes a ``GenericError``.
        """
        found = self.error_table.get(exc_type)
        if found is None:
            return GenericError(exc_type, exc_args, exc_repr)
        else:
            return found(*exc_args)

    def msg_received(self, data):
        """Resolve the pending future that matches the reply *data*.

        *data* is a two-frame message ``[header, payload]``.  The header is
        decoded with RESP and the payload with the packer.  Malformed
        replies and replies for unknown request ids are logged and dropped.
        """
        try:
            header, banswer = data
            pid, rnd, req_id, timestamp, is_error = self.RESP.unpack(header)
            answer = self.packer.unpackb(banswer)
        except Exception:
            log().critical("Cannot unpack %r", data, exc_info=sys.exc_info())
            return
        fut = self.calls.pop(req_id, None)
        if fut is None:
            log().critical("Unknown answer id: %d (%d %d %f %d) -> %s",
                           req_id, pid, rnd, timestamp, is_error, answer)
        elif fut.cancelled():
            log().debug("The future for request #%08x has been cancelled, "
                        "skip the received result.", req_id)
        elif is_error:
            try:
                exc = self._translate_error(*answer)
            except Exception as err:
                # Never leave the future pending: a waiter without a
                # timeout would otherwise block forever.
                log().exception("Cannot translate error answer %r", answer)
                exc = err
            fut.set_exception(exc)
        else:
            fut.set_result(answer)
def connect_rpc(**kwargs):
    """Build an :class:`RPCClient`, forwarding every keyword to its constructor."""
    client = RPCClient(**kwargs)
    return client
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment