Created May 14, 2017 16:35
-
-
Save thehesiod/67c4f1ba3596975671a4727c9c4e0df7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3.5 | |
import zmq.asyncio | |
import zmq | |
import asyncio | |
import pickle | |
import logging | |
import tempfile | |
import traceback | |
import multiprocessing | |
# Used both as the SNDBUF override in SocketCtx and as the length of the
# _EmptyMsg padding string.
_SEND_BUFFER_SIZE = 16384  # 16 KiB

# Master switch for SocketCtx's _EmptyMsg flush of PUSH sockets on exit; the
# per-instance send_empty_msg flag only takes effect when this is also True.
# NOTE: setting this to True causes the Terminators to get flushed out.
_SEND_EMPTYMSG = False
class Terminator:
    """End-of-round marker.

    Senders push one of these after each batch of payload messages; the
    receiver counts them to decide when every sender has finished.  It
    intentionally carries no payload (an earlier experiment padded it to
    _SEND_BUFFER_SIZE characters).
    """
class _EmptyMsg:
    """Padding message pushed by SocketCtx on exit; SocketCtx.recv_pyobj drops it."""

    def __init__(self):
        # Payload of exactly _SEND_BUFFER_SIZE characters.
        self.dummy = 'a' * _SEND_BUFFER_SIZE


# Pre-pickle the padding once so every flush reuses the same bytes.
_empty_msg = _EmptyMsg()
_empty_msg_bytes = pickle.dumps(_empty_msg, pickle.HIGHEST_PROTOCOL)
class SocketCtx:
    """Async context manager owning a single ZMQ socket for one ``async with``.

    On enter the socket is created from the given (or an internally created)
    context. Its send/receive high water marks and optional receive timeout are
    set, and it is bound or connected. On exit a PUSH socket may first be flushed
    with ``_EmptyMsg`` padding. Then the socket is closed, and the context is
    destroyed if this object created it.

    NOTE: ``recv_pyobj`` unpickles incoming data; only use it with trusted peers.
    """

    def __init__(self, socket_type, high_water_mark: int, socket_name: str, connect_method_name: str,
                 zmq_ctx: zmq.asyncio.Context, send_empty_msg=True, rcvtimeout=None):
        """
        Creates a ZMQ socket context object intended to ensure that `Terminator` is received last.

        :param socket_type: ZMQ socket type, e.g. zmq.PUSH or zmq.PULL
        :param high_water_mark: applied to both RCVHWM and SNDHWM
        :param socket_name: endpoint passed to the connect method
        :param connect_method_name: either "bind" or "connect"
        :param zmq_ctx: zmq.asyncio.Context to use (will not destroy during aexit). Pass None to have
            this class construct and own the context; an owned context is destroyed during aexit.
        :param send_empty_msg: set to True to flush a PUSH socket during `aexit` with _EmptyMsg.
            Only takes effect when the module-level _SEND_EMPTYMSG switch is also True.
        :param rcvtimeout: optional RCVTIMEO value (milliseconds in ZMQ); None keeps the ZMQ default
        """
        self._zmq_ctx = zmq_ctx
        # With no context supplied we create one in __aenter__ and destroy it on exit.
        self._owns_ctx = zmq_ctx is None
        self._socket = None
        self._socket_type = socket_type
        self._socket_name = socket_name
        self._high_water_mark = high_water_mark
        self._connect_method_name = connect_method_name
        self._send_empty_msg = _SEND_EMPTYMSG and send_empty_msg
        self._rcvtimeout = rcvtimeout

    @property
    def name(self):
        """The endpoint this socket binds or connects to."""
        return self._socket_name

    async def __aenter__(self):
        assert self._socket is None, "SocketCtx is already entered"
        if self._owns_ctx:
            self._zmq_ctx = zmq.asyncio.Context()
        try:
            self._socket = self._zmq_ctx.socket(self._socket_type)
            self._socket.RCVHWM = self._socket.SNDHWM = self._high_water_mark
            if self._rcvtimeout is not None:
                self._socket.RCVTIMEO = self._rcvtimeout
            if self._send_empty_msg:
                # Presumably sized so each _EmptyMsg fills the OS send buffer during the flush.
                self._socket.SNDBUF = _SEND_BUFFER_SIZE
            getattr(self._socket, self._connect_method_name)(self._socket_name)
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so release resources here.
            self._close()
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            if self._socket is not None and self._socket_type == zmq.PUSH and self._send_empty_msg:
                # Push HWM + 1 padding messages, intended to force everything queued
                # ahead of them (e.g. a trailing Terminator) through the pipe.
                for _ in range(self._high_water_mark + 1):
                    await self.send(_empty_msg_bytes)
        finally:
            self._close()

    def _close(self):
        """Close the socket and, if this object created it, destroy the context."""
        try:
            if self._socket is not None:
                self._socket.close()
                assert self._socket.closed
        finally:
            self._socket = None
            if self._owns_ctx and self._zmq_ctx is not None:
                self._zmq_ctx.destroy()
                self._zmq_ctx = None

    async def recv(self, *args, **kwargs):
        """Receive raw bytes; arguments are forwarded to the underlying socket."""
        return await self._socket.recv(*args, **kwargs)

    async def _recv_pyobj(self, *args, **kwargs):
        """Receive the next unpickled object, silently skipping _EmptyMsg padding."""
        while True:
            pyobj = await self._socket.recv_pyobj(*args, **kwargs)
            if pyobj.__class__ is not _EmptyMsg:
                return pyobj

    async def recv_pyobj(self, *args, **kwargs):
        """Receive the next unpickled object that is not _EmptyMsg padding."""
        return await self._recv_pyobj(*args, **kwargs)

    async def send(self, data, *args, **kwargs):
        """Send raw bytes; extra arguments (e.g. flags) are forwarded to the underlying socket."""
        return await self._socket.send(data, *args, **kwargs)
def sender(socket_name, num_iterations):
    """Push ``num_iterations`` rounds of payload to ``socket_name``.

    Each round sends five pickled dicts, each holding a 229114-character string,
    followed by one pickled ``Terminator``. The function creates and installs
    its own ZMQ event loop and closes it on return, so it is suitable as a
    ``multiprocessing.Process`` target.

    :param socket_name: ZMQ endpoint the PUSH socket connects to
    :param num_iterations: number of rounds, and therefore Terminators, to send
    """
    loop = zmq.asyncio.ZMQEventLoop()
    asyncio.set_event_loop(loop)

    obj = {'a': 'a' * 229114}
    obj_bytes = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    term_bytes = pickle.dumps(Terminator(), protocol=pickle.HIGHEST_PROTOCOL)

    try:
        zmq_ctx = zmq.asyncio.Context()
        try:
            async def doit():
                async with SocketCtx(zmq.PUSH, 2, socket_name, 'connect', zmq_ctx) as zmq_socket:
                    for num in range(num_iterations):
                        print("Sndr iter: {}/{}".format(num, num_iterations))
                        for _ in range(5):
                            await zmq_socket.send(obj_bytes)
                        # One Terminator closes each round.
                        await zmq_socket.send(term_bytes)

            loop.run_until_complete(doit())
        finally:
            # Destroy the context even when sending fails, so it is not leaked.
            zmq_ctx.destroy()
        print("worker finished sending")
    except Exception:
        traceback.print_exc()
        raise
    finally:
        loop.close()
async def receiver(connected_event: asyncio.Event, socket_name, num_terminators: int):
    """Bind a PULL socket and consume messages until ``num_terminators`` Terminators arrive.

    :param connected_event: set once the socket is bound, so callers can start senders
    :param socket_name: ZMQ endpoint to bind
    :param num_terminators: number of Terminator objects to wait for; values <= 0 return
        immediately after binding instead of waiting forever
    :return: total number of messages received (Terminators included, _EmptyMsg padding excluded)
    """
    zmq_ctx = zmq.asyncio.Context()
    try:
        async with SocketCtx(zmq.PULL, 2, socket_name, 'bind', zmq_ctx) as zmq_socket:
            connected_event.set()
            total_received = 0
            while num_terminators > 0:
                pyobj = await zmq_socket.recv_pyobj()
                print("Rcv:", pyobj.__class__, total_received, num_terminators)
                total_received += 1
                if isinstance(pyobj, Terminator):
                    num_terminators -= 1
                    if num_terminators == 0:
                        break
                # Presumably a deliberately slow consumer, so senders back up against the HWM.
                await asyncio.sleep(0.1)
        print("total received:", total_received)
        return total_received
    except Exception:
        traceback.print_exc()
        raise
    finally:
        # Always release the context, including when receiving fails or is cancelled.
        zmq_ctx.destroy()
async def ooo_test(loop):
    """Check that every sender's Terminators reach a slow PULL receiver.

    Binds a receiver on a temporary IPC endpoint and spawns ``num_procs`` sender
    processes, each pushing ``num_iters_per_proc`` rounds that end in a
    Terminator. It waits for the senders to exit, then waits for the receiver
    to count every Terminator.

    :param loop: event loop used to run the blocking ``Process.join`` calls in an executor
    :raises RuntimeError: if the receiver finishes before binding or a sender process fails
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        socket_name = "ipc://" + temp_dir + "/foobar"
        connected_event = asyncio.Event()
        num_procs = 6
        num_iters_per_proc = 4
        # Each sender iteration ends with exactly one Terminator.
        num_sends = num_procs * num_iters_per_proc
        recvr_task = asyncio.ensure_future(receiver(connected_event, socket_name, num_sends))

        # Wait for the bind, but do not hang forever if the receiver dies first.
        bound_task = asyncio.ensure_future(connected_event.wait())
        await asyncio.wait([bound_task, recvr_task], return_when=asyncio.FIRST_COMPLETED)
        if not connected_event.is_set():
            bound_task.cancel()
            await recvr_task  # re-raises the receiver's failure, if any
            raise RuntimeError("receiver finished before binding {}".format(socket_name))

        procs = [
            multiprocessing.Process(target=sender, args=[socket_name, num_iters_per_proc])
            for _ in range(num_procs)
        ]
        for p in procs:
            p.start()
        print("waiting for sends to complete down")
        for p in procs:
            await loop.run_in_executor(None, p.join)

        failed = [p.exitcode for p in procs if p.exitcode != 0]
        if failed:
            # A failed sender may never deliver all its Terminators, so the receiver could wait forever.
            recvr_task.cancel()
            try:
                await recvr_task
            except asyncio.CancelledError:
                pass
            raise RuntimeError("sender processes failed with exit codes {}".format(failed))

        print("waiting to receive all terminators")
        await recvr_task
        print("done")
if __name__ == '__main__':
    # ZMQ-aware event loop so zmq.asyncio sockets can be awaited.
    _loop = zmq.asyncio.ZMQEventLoop()
    asyncio.set_event_loop(_loop)
    logging.captureWarnings(True)
    try:
        _loop.run_until_complete(ooo_test(_loop))
    finally:
        # Release the loop's resources (selector, default executor) even on failure.
        _loop.close()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.