Skip to content

Instantly share code, notes, and snippets.

@thehesiod
Created May 14, 2017 16:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thehesiod/67c4f1ba3596975671a4727c9c4e0df7 to your computer and use it in GitHub Desktop.
Save thehesiod/67c4f1ba3596975671a4727c9c4e0df7 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3.5
import zmq.asyncio
import zmq
import asyncio
import pickle
import logging
import tempfile
import traceback
import multiprocessing
# Module-wide switch for SocketCtx's shutdown flush.  Only when this is True
# (and the SocketCtx was created with send_empty_msg=True) does a PUSH socket
# pad itself with _EmptyMsg objects on exit, which is what gets queued
# Terminators flushed out to the receiver.
_SEND_EMPTYMSG = False
# Number of characters in each _EmptyMsg payload; also the SNDBUF value a
# SocketCtx applies when the flush is enabled.
_SEND_BUFFER_SIZE = 1024 * 16
class Terminator:
    """Stateless sentinel a sender pushes after each batch of payloads.

    Receivers recognise it with ``isinstance`` and count occurrences to know
    when every sender has finished.
    """
    # NOTE: optional padding (``self._dummy = 'a' * _SEND_BUFFER_SIZE``) is
    # deliberately left disabled, so instances carry no state.
class _EmptyMsg:
    """Padding object used to flush a PUSH socket when a SocketCtx exits.

    A pickled instance is sent as filler by ``SocketCtx.__aexit__`` and is
    dropped again by ``SocketCtx.recv_pyobj``, so it never reaches callers.
    """

    def __init__(self):
        # Payload length matches the SNDBUF SocketCtx sets when flushing.
        padding_size = _SEND_BUFFER_SIZE
        self.dummy = padding_size * 'a'
# Build the padding message once at import time; SocketCtx sends these same
# pre-pickled bytes for every flush message instead of re-pickling each time.
_empty_msg = _EmptyMsg()
_empty_msg_bytes = pickle.dumps(_empty_msg, protocol=pickle.HIGHEST_PROTOCOL)
class SocketCtx:
    """Async context manager owning one ZMQ socket for the span of an ``async with``.

    On entry the socket is created, its send/receive high-water marks (and
    optionally RCVTIMEO / SNDBUF) are set, and it is bound or connected.
    On exit a PUSH socket is optionally flushed with pickled ``_EmptyMsg``
    padding, then closed.  A context this object created itself
    (``zmq_ctx=None``) is destroyed as well.
    """

    def __init__(self, socket_type, high_water_mark: int, socket_name: str,
                 connect_method_name: str, zmq_ctx: zmq.asyncio.Context,
                 send_empty_msg=True, rcvtimeout=None):
        """
        Creates a ZMQ socket context object intended to ensure that `Terminator` is received last.
        :param socket_type: type of socket (e.g. zmq.PUSH or zmq.PULL)
        :param high_water_mark: value applied to both RCVHWM and SNDHWM
        :param socket_name: endpoint to pass to connect_method
        :param connect_method_name: either "bind" or "connect"
        :param zmq_ctx: zmq.asyncio.Context to use (will not destroy during aexit). Pass None to have this class
            construct and own the context; an owned context is destroyed during aexit.
        :param send_empty_msg: set to True if you would like to flush the socket during `aexit` with _EmptyMsg.
            Only takes effect when the module-level ``_SEND_EMPTYMSG`` switch is also True.
        :param rcvtimeout: if not None, assigned to the socket's RCVTIMEO option (milliseconds, per ZMQ)
        """
        self._zmq_ctx = zmq_ctx
        # Context created by __aenter__ when zmq_ctx is None; released in __aexit__.
        self._owned_ctx = None
        self._socket = None
        self._socket_type = socket_type
        self._socket_name = socket_name
        self._high_water_mark = high_water_mark
        self._connect_method_name = connect_method_name
        # The flush is gated globally so it can be turned off for the whole module.
        self._send_empty_msg = _SEND_EMPTYMSG and send_empty_msg
        self._rcvtimeout = rcvtimeout

    @property
    def name(self):
        """Endpoint this socket binds or connects to."""
        return self._socket_name

    async def __aenter__(self):
        assert not self._socket
        ctx = self._zmq_ctx
        if ctx is None:
            # Honor the documented contract: build a private context that we own.
            ctx = self._owned_ctx = zmq.asyncio.Context()
        sock = None
        try:
            sock = ctx.socket(self._socket_type)
            sock.RCVHWM = sock.SNDHWM = self._high_water_mark
            if self._rcvtimeout is not None:
                sock.RCVTIMEO = self._rcvtimeout
            if self._send_empty_msg:
                # Same size as one _EmptyMsg payload (both use _SEND_BUFFER_SIZE).
                sock.SNDBUF = _SEND_BUFFER_SIZE
            getattr(sock, self._connect_method_name)(self._socket_name)
        except BaseException:
            # Don't leak the socket (or an owned context) if setup or bind/connect fails;
            # linger=0 because nothing useful can be queued on a half-built socket.
            if sock is not None:
                sock.close(linger=0)
            self._release_owned_ctx()
            raise
        # Only publish the socket once it is fully set up, so a failed enter can be retried.
        self._socket = sock
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            if self._socket is not None and self._socket_type == zmq.PUSH and self._send_empty_msg:
                # Enough padding to exceed the high-water mark, pushing earlier
                # messages (e.g. Terminators) out ahead of it.
                for _ in range(self._high_water_mark + 1):
                    await self.send(_empty_msg_bytes)
        finally:
            try:
                if self._socket is not None:
                    self._socket.close()
                    assert self._socket.closed
                    self._socket = None
            finally:
                self._release_owned_ctx()

    def _release_owned_ctx(self):
        """Destroy the context created by __aenter__; no-op for a caller-supplied context."""
        ctx, self._owned_ctx = self._owned_ctx, None
        if ctx is not None:
            # NOTE: per libzmq semantics, terminating a context may block until
            # messages queued on closed sockets are sent or their LINGER expires.
            ctx.destroy()

    async def recv(self, *args, **kwargs):
        """Receive raw bytes; arguments are forwarded to the socket's ``recv``."""
        return await self._socket.recv(*args, **kwargs)

    async def _recv_pyobj(self, *args, **kwargs):
        """Receive the next unpickled object, skipping ``_EmptyMsg`` padding."""
        while True:
            pyobj = await self._socket.recv_pyobj(*args, **kwargs)
            # Exact class check: only the padding class itself is dropped.
            if pyobj.__class__ is not _EmptyMsg:
                return pyobj

    async def recv_pyobj(self, *args, **kwargs):
        """Receive the next unpickled object with ``_EmptyMsg`` padding filtered out.

        NOTE(review): this unpickles whatever arrives on the socket; only use
        with trusted peers.
        """
        return await self._recv_pyobj(*args, **kwargs)

    async def send(self, data, *args, **kwargs):
        """Send raw bytes; extra arguments (e.g. flags) are forwarded to the socket's ``send``."""
        return await self._socket.send(data, *args, **kwargs)
def sender(socket_name, num_iterations):
    """Child-process entry point: push batches of pickled payloads, each followed by a Terminator.

    Runs on its own event loop.  For each of ``num_iterations`` iterations it
    sends five copies of a large pickled dict and then one pickled
    ``Terminator`` over a PUSH socket connected to ``socket_name``.

    :param socket_name: ZMQ endpoint the receiver has bound
    :param num_iterations: number of (5 payloads + 1 Terminator) batches to send
    :raises Exception: any failure is printed with its traceback and re-raised
    """
    loop = zmq.asyncio.ZMQEventLoop()
    asyncio.set_event_loop(loop)
    obj = {'a': 'a' * 229114}
    obj_bytes = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    term_bytes = pickle.dumps(Terminator(), protocol=pickle.HIGHEST_PROTOCOL)
    try:
        zmq_ctx = zmq.asyncio.Context()

        async def doit():
            async with SocketCtx(zmq.PUSH, 2, socket_name, 'connect', zmq_ctx) as zmq_socket:
                for num in range(num_iterations):
                    print("Sndr iter: {}/{}".format(num, num_iterations))
                    for _ in range(5):
                        await zmq_socket.send(obj_bytes)
                    await zmq_socket.send(term_bytes)

        loop.run_until_complete(doit())
        # Per libzmq semantics, termination waits for queued messages to be
        # delivered (subject to LINGER) before the process goes on to exit.
        zmq_ctx.destroy()
        print("worker finished sending")
    except Exception:
        traceback.print_exc()
        raise
    finally:
        # The loop was created here, so release it here on every path.
        loop.close()
async def receiver(connected_event: asyncio.Event, socket_name, num_terminators):
    """Bind a PULL socket and consume objects until ``num_terminators`` Terminators arrive.

    ``connected_event`` is set as soon as the socket is bound, so senders can
    be started.  After each message that does not complete the count, it
    sleeps 0.1 s (presumably to keep the receiver slow so senders reach their
    high-water mark).

    :param connected_event: set once the socket is bound
    :param socket_name: ZMQ endpoint to bind
    :param num_terminators: number of Terminator objects to wait for; <= 0 returns immediately
    :returns: total objects received, Terminators included (_EmptyMsg padding is filtered out)
    :raises Exception: any failure is printed with its traceback and re-raised
    """
    zmq_ctx = zmq.asyncio.Context()
    try:
        async with SocketCtx(zmq.PULL, 2, socket_name, 'bind', zmq_ctx) as zmq_socket:
            connected_event.set()
            total_received = 0
            remaining = num_terminators
            # Test before each recv so a non-positive count cannot block forever.
            while remaining > 0:
                pyobj = await zmq_socket.recv_pyobj()
                print("Rcv:", pyobj.__class__, total_received, remaining)
                total_received += 1
                if isinstance(pyobj, Terminator):
                    remaining -= 1
                    if remaining == 0:
                        break  # skip the trailing sleep once everything has arrived
                await asyncio.sleep(0.1)
        print("total received:", total_received)
        return total_received
    except Exception:
        traceback.print_exc()
        raise
    finally:
        # Released on failure/cancellation too; the PULL socket is already closed
        # by SocketCtx and has no outbound queue to linger on.
        zmq_ctx.destroy()
async def ooo_test(loop):
    """Run the ordering test: one in-process receiver, several sender processes.

    Starts ``receiver`` on an IPC endpoint inside a temporary directory and
    waits until it has bound.  It then launches ``num_procs`` ``sender``
    processes, each sending ``num_iters_per_proc`` Terminators, and waits for
    the receiver to count all of them.

    :param loop: event loop used to join the sender processes in an executor
    :raises RuntimeError: if the receiver exits before the senders start, or a sender exits non-zero
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        socket_name = "ipc://" + temp_dir + "/foobar"
        connected_event = asyncio.Event()
        num_procs = 6
        num_iters_per_proc = 4
        num_sends = num_procs * num_iters_per_proc
        recvr_task = asyncio.ensure_future(receiver(connected_event, socket_name, num_sends))
        try:
            # Wait for the bind, but stop early if the receiver dies first;
            # otherwise a bind failure would leave us waiting forever.
            bound = asyncio.ensure_future(connected_event.wait())
            await asyncio.wait([recvr_task, bound], return_when=asyncio.FIRST_COMPLETED)
            if recvr_task.done():
                bound.cancel()
                recvr_task.result()  # re-raises the receiver's exception, if any
                raise RuntimeError("receiver exited before the senders were started")
            procs = [
                multiprocessing.Process(target=sender, args=[socket_name, num_iters_per_proc])
                for _ in range(num_procs)
            ]
            for p in procs:
                p.start()
            print("waiting for sends to complete down")
            for p in procs:
                await loop.run_in_executor(None, p.join)
            failed = [p for p in procs if p.exitcode != 0]
            if failed:
                # A dead sender never delivers its Terminators, so awaiting the
                # receiver below would hang forever.
                raise RuntimeError("{} of {} sender processes failed".format(len(failed), num_procs))
            print("waiting to receive all terminators")
            await recvr_task
        finally:
            # Never leave the receiver running past the temp dir's lifetime.
            if not recvr_task.done():
                recvr_task.cancel()
                try:
                    await recvr_task
                except asyncio.CancelledError:
                    pass
        print("done")
if __name__ == '__main__':
    # Script entry point: run ooo_test on a ZMQ-aware event loop.
    _loop = zmq.asyncio.ZMQEventLoop()
    asyncio.set_event_loop(_loop)
    # Route warnings through logging's 'py.warnings' logger.
    logging.captureWarnings(True)
    try:
        _loop.run_until_complete(ooo_test(_loop))
    finally:
        # Close explicitly so the loop's resources are released even if the test fails.
        _loop.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment