Skip to content

Instantly share code, notes, and snippets.

@mortbauer
Forked from mivade/zmq_auth.py
Created September 5, 2023 06:18
Show Gist options
  • Save mortbauer/6005ae8afabe6ef438957e63b331fbc6 to your computer and use it in GitHub Desktop.
Save mortbauer/6005ae8afabe6ef438957e63b331fbc6 to your computer and use it in GitHub Desktop.
ZeroMQ Curve authentication demo
"""Simple demonstration of using ZMQ's Curve authentication.
This demo is adapted from the examples given in the `PyZMQ repository`__. Key
differences include:
* Using ``setsockopt`` to set Curve parameters instead of setting attributes
directly (help out your IDE!)
* Integration with ``asyncio``
__ https://github.com/zeromq/pyzmq/tree/master/examples
"""
from abc import ABC, abstractmethod
import asyncio
from contextlib import AbstractContextManager
import os.path
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, Optional, Tuple, Union
import zmq
from zmq.asyncio import Context
import zmq.auth
from zmq.auth.asyncio import AsyncioAuthenticator
def generate_keys(key_dir: str) -> Dict[str, Tuple[str, str]]:
    """Create the server and client Curve certificates used by this demo.

    Parameters
    ----------
    key_dir
        Directory in which the certificate files are written.

    Returns
    -------
    dict
        Maps ``"server"`` and ``"client"`` to the ``(public_key_file,
        secret_key_file)`` pair returned by :func:`zmq.auth.create_certificates`.
    """
    pairs: Dict[str, Tuple[str, str]] = {}
    # Server certificates are created first, then the client's, matching
    # the order the demo has always used.
    for role in ("server", "client"):
        public_file, secret_file = zmq.auth.create_certificates(key_dir, role)
        pairs[role] = (public_file, secret_file)
    return pairs
class BaseServer(AbstractContextManager):
    """Base ZMQ server with CurveZMQ authentication support.

    The constructor starts an :class:`AsyncioAuthenticator` on ``ctx`` and
    binds a Curve server socket to ``address``. Leaving the ``with`` block
    (i.e. calling :meth:`__exit__`) closes the socket and stops the
    authenticator.

    Parameters
    ----------
    address
        ZMQ address string to listen on. Must use the ``tcp://`` transport.
    secret_key_path
        Path to the server secret key file.
    client_key_dir
        Path to the directory containing authorized client public keys. When
        not given, accept connections from any client that knows the server's
        public key.
    ctx
        A :class:`Context`. If not given, ``Context.instance()`` is used.
    socket_type
        Type of socket to create. If not given (or ``None``), ``zmq.REP`` will
        be used.
    allowed_ips
        IP addresses passed to the authenticator's ``allow`` list. Connections
        from addresses not on that list are rejected, so remote clients must
        be listed here. Defaults to ``("127.0.0.1",)``; pass an empty tuple to
        skip IP filtering.

    Raises
    ------
    ValueError
        If ``address`` is not a TCP address, or ``secret_key_path`` does not
        contain a secret key.
    FileNotFoundError
        If ``secret_key_path`` is not an existing file.
    NotADirectoryError
        If ``client_key_dir`` is given but is not an existing directory.
    """

    def __init__(
        self,
        address: str,
        secret_key_path: Union[str, Path],
        client_key_dir: Optional[Union[str, Path]] = None,
        ctx: Optional[Context] = None,
        socket_type: Optional[int] = zmq.REP,
        allowed_ips: Tuple[str, ...] = ("127.0.0.1",),
    ):
        if not address.startswith("tcp://"):
            raise ValueError("CurveZMQ only works over TCP")
        # Explicit exceptions rather than ``assert`` so validation still
        # happens when Python runs with -O.
        if not os.path.isfile(secret_key_path):
            raise FileNotFoundError(
                f"Server secret key file not found: {secret_key_path}"
            )
        if client_key_dir is not None and not os.path.isdir(client_key_dir):
            raise NotADirectoryError(
                f"Client key directory not found: {client_key_dir}"
            )
        # Load the key pair before starting anything, so a bad key file
        # cannot leave a running authenticator or an open socket behind.
        public_key, secret_key = zmq.auth.load_certificate(secret_key_path)
        if secret_key is None:
            raise ValueError(
                f"{secret_key_path} contains no secret key; "
                "pass the server's secret key file"
            )

        self.address = address
        self.socket_type = zmq.REP if socket_type is None else socket_type
        self.ctx = ctx or Context.instance()
        self._secret_key_file = secret_key_path
        self._client_key_dir = client_key_dir
        auth_location = (
            str(client_key_dir)
            if client_key_dir is not None
            else zmq.auth.CURVE_ALLOW_ANY
        )

        # Configure the authenticator
        self.auth = AsyncioAuthenticator(context=self.ctx)
        self.auth.configure_curve(domain="*", location=auth_location)
        if isinstance(allowed_ips, str):
            # Guard against a bare string being unpacked character by character.
            allowed_ips = (allowed_ips,)
        if allowed_ips:
            self.auth.allow(*allowed_ips)
        self.auth.start()

        # Configure the listening socket; undo everything if setup fails.
        self.socket = None
        try:
            self.socket = self.ctx.socket(self.socket_type)
            self.socket.setsockopt(zmq.CURVE_PUBLICKEY, public_key)
            self.socket.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
            self.socket.setsockopt(zmq.CURVE_SERVER, True)
            self.socket.bind(self.address)
        except BaseException:
            # Cleanup-and-reraise: the caller still sees the original error.
            self._close()
            raise

    def _close(self):
        """Close the listening socket (if open) and stop the authenticator."""
        try:
            if self.socket is not None and not self.socket.closed:
                self.socket.close()
        finally:
            self.auth.stop()

    def __exit__(self, *_exc):
        # Returning None lets any in-flight exception propagate.
        self._close()

    @abstractmethod
    async def run(self):
        """Implement this method to send and/or receive messages."""
class EchoServer(BaseServer):
    """Server that replies to every request with the request itself.

    A ``b"quit"`` request is echoed back like any other message, after which
    the server leaves its ``with self`` block and stops.
    """

    async def run(self):
        with self:
            msg = None
            # The echo is sent before the quit check, so the client always
            # receives a reply to its final request.
            while msg != b"quit":
                msg = await self.socket.recv()
                await self.socket.send(msg)
            print("Server exiting upon request")
class BaseClient(ABC):
    """Base Curve-authenticated client class.

    The constructor loads the client's key pair and the server's public key,
    creates a socket, configures Curve on it and connects to ``address``.
    Call :meth:`close` (or use the instance in a ``with`` block) to release
    the socket.

    Parameters
    ----------
    address
        ZMQ address for the server.
    server_public_key_path
        Path to the server's public key.
    secret_key_path
        Path to the client's secret key.
    ctx
        Optional :class:`Context`. If not given, ``Context.instance()`` is used.
    socket_type
        Type of socket to create. If not given (or ``None``), ``zmq.REQ`` will
        be used.

    Raises
    ------
    ValueError
        If ``secret_key_path`` does not contain a secret key.
    """

    def __init__(
        self,
        address: str,
        server_public_key_path: Union[str, Path],
        secret_key_path: Union[str, Path],
        ctx: Optional[Context] = None,
        socket_type: Optional[int] = zmq.REQ,
    ):
        # Read all key material before creating the socket so that a missing
        # or malformed key file cannot leak a half-configured socket.
        public_key, secret_key = zmq.auth.load_certificate(secret_key_path)
        if secret_key is None:
            raise ValueError(
                f"{secret_key_path} contains no secret key; "
                "pass the client's secret key file"
            )
        server_key, _ = zmq.auth.load_certificate(server_public_key_path)

        self.ctx = ctx or Context.instance()
        self.address = address
        self.socket = self.ctx.socket(
            zmq.REQ if socket_type is None else socket_type
        )
        try:
            # Configure client keys
            self.socket.setsockopt(zmq.CURVE_PUBLICKEY, public_key)
            self.socket.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
            # Register the server public key with the socket
            self.socket.setsockopt(zmq.CURVE_SERVERKEY, server_key)
            self.socket.connect(self.address)
        except BaseException:
            # Cleanup-and-reraise: the caller still sees the original error.
            self.socket.close()
            raise

    def close(self) -> None:
        """Close the client socket; calling it again is a no-op."""
        if not self.socket.closed:
            self.socket.close()

    def __enter__(self):
        return self

    def __exit__(self, *_exc):
        # Returning None lets any in-flight exception propagate.
        self.close()

    @abstractmethod
    async def run(self) -> None:
        """Implement this coroutine to communicate with the server."""
class EchoClient(BaseClient):
    """Client that sends ten numbered greetings, then tells the server to quit."""

    async def run(self) -> None:
        sock = self.socket
        for n in range(10):
            await sock.send(f"Hello, world {n}".encode())
            print("Client received", await sock.recv())
            await asyncio.sleep(1)
        # Request shutdown; a REQ socket must still collect the echoed reply.
        await sock.send(b"quit")
        await sock.recv()
async def main():
    """Generate throwaway keys, then run the echo server and client together."""
    import logging

    # DEBUG level makes zmq.auth's log messages visible.
    logging.basicConfig(level=logging.DEBUG)
    address = "tcp://127.0.0.1:9999"
    with TemporaryDirectory() as key_dir:
        key_files = generate_keys(key_dir)
        server_secret = key_files["server"][1]
        server_public = key_files["server"][0]
        client_secret = key_files["client"][1]
        # The key directory doubles as the server's authorized-clients directory.
        server = EchoServer(address, server_secret, key_dir)
        client = EchoClient(address, server_public, client_secret)
        await asyncio.gather(server.run(), client.run())
if __name__ == "__main__":
    # asyncio.run creates, runs and closes a fresh event loop; calling
    # asyncio.get_event_loop() with no running loop is deprecated (3.10+)
    # and raises RuntimeError on Python 3.14.
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment