Skip to content

Instantly share code, notes, and snippets.

@trevorbernard
Forked from claws/authentication.py
Created December 8, 2013 18:31
Show Gist options
  • Save trevorbernard/7861712 to your computer and use it in GitHub Desktop.
Save trevorbernard/7861712 to your computer and use it in GitHub Desktop.
'''
An authentication module for pyzmq modelled on zauth from czmq.
The certificates used for CURVE authentication are assumed to be identical to
those generated by czmq.
'''
import glob
import json
import os
from threading import Thread
import zmq
from zmq.utils import z85
from zmq.eventloop.ioloop import IOLoop
from zmq.eventloop.zmqstream import ZMQStream
# Sentinel "location" for CURVE configuration: when the CURVE command
# carries this value, the agent accepts any client public key instead of
# loading certificates from a directory.
CURVE_ALLOW_ANY = '*'
def load_certificate(filename):
    '''
    Return the (public_key, secret_key) pair read from a czmq zcert file.

    Each key is the z85 text found inside the quotes of its
    ``public-key = "..."`` / ``secret-key = "..."`` line. A key that is not
    present in the file is returned as None (a public certificate, for
    example, has no secret key).

    Every line is stripped before parsing, so indented key lines are
    accepted; blank lines and ``#`` comment lines are skipped. The key name
    is split from its value at the FIRST '=', which tolerates missing or
    extra whitespace around it. This is safe because key names never
    contain '=' while z85 values may.

    Raises IOError/OSError if the file cannot be opened.
    '''
    public_key = None
    secret_key = None
    with open(filename, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            name, sep, value = line.partition('=')
            if not sep:
                # Section headers such as "metadata" / "curve" carry no value.
                continue
            name = name.strip()
            value = value.strip().replace('"', '')
            if name == 'public-key':
                public_key = value
            elif name == 'secret-key':
                secret_key = value
    return public_key, secret_key
def load_certificates(location):
    '''
    Build the allowed-client table from a directory of zcert certificates.

    Following the czmq convention, only public certificates (files whose
    name ends in ``.key``) are read. The result maps each public key found
    to the marker string 'OK'. A location that is not a directory yields
    an empty dict.
    '''
    if not os.path.isdir(location):
        return {}
    loaded = {}
    pattern = os.path.join(location, "*.key")
    for path in glob.glob(pattern):
        key, _unused_secret = load_certificate(path)
        if key:
            loaded[key] = 'OK'
    return loaded
class AuthAgentThread(Thread):
    '''
    Thread in which ZAP authentication actually happens.

    The thread owns two sockets:

    - ``pipe``: a PAIR socket connected to the front-end Authenticator. It
      carries configuration commands (ALLOW, DENY, PLAIN, CURVE, VERBOSE,
      TERMINATE). Each handled command is answered with ``b'OK'``; an
      unknown command is answered with ``b'ERROR'``.
    - ``zap``: a REP socket bound to ``inproc://zeromq.zap.01`` on which
      ZAP requests arrive. Every request receives exactly one reply.
    '''

    def __init__(self, context, endpoint, verbose=False):
        '''
        :param context: the zmq context whose connections are authenticated.
        :param endpoint: endpoint of the front-end PAIR socket to connect to.
        :param verbose: print trace messages when True.
        '''
        super(AuthAgentThread, self).__init__()
        self.context = context
        self.verbose = verbose
        self.allow_any = False   # CURVE: accept any client key
        self.zap = None          # REP socket, created in run()
        self.whitelist = []      # addresses explicitly allowed
        self.blacklist = []      # addresses explicitly denied
        self.passwords = {}      # PLAIN: username -> password
        self.certs = {}          # CURVE: z85 public key -> 'OK'
        # Create a socket to communicate back to the main thread.
        self.pipe = context.socket(zmq.PAIR)
        self.pipe.connect(endpoint)

    def run(self):
        ''' Serve pipe commands and ZAP requests until told to terminate. '''
        # Create the ZAP handler and get ready for requests.
        self.zap = self.context.socket(zmq.REP)
        try:
            self.zap.bind("inproc://zeromq.zap.01")
            poller = zmq.Poller()
            poller.register(self.pipe, zmq.POLLIN)
            poller.register(self.zap, zmq.POLLIN)
            while True:
                try:
                    socks = dict(poller.poll())
                except zmq.ZMQError:
                    break  # interrupted
                if socks.get(self.pipe) == zmq.POLLIN:
                    if self._handle_pipe():
                        break
                if socks.get(self.zap) == zmq.POLLIN:
                    self._authenticate()
        finally:
            # Always release both sockets, even if a handler raised, so the
            # context can still be terminated.
            self.pipe.linger = 1
            self.pipe.close()
            self.zap.linger = 1
            self.zap.close()

    def _send_zap_reply(self, sequence, status_code, status_text):
        '''
        Send a ZAP reply to the handler socket.

        The reply echoes ``sequence`` (the request id) and carries the
        given status code and text. The two trailing frames (user id and
        metadata in the ZAP specification) are left empty.
        '''
        if self.verbose:
            print("I: ZAP reply %s" % status_code)
        reply = [b"1.0", sequence, status_code, status_text, b"", b""]
        self.zap.send_multipart(reply)

    def _handle_pipe(self):
        '''
        Handle one command message from the front-end API.

        Command names are compared as byte strings, matching the frames the
        Authenticator sends (b'ALLOW', b'VERBOSE', ...).

        :returns: True if the thread should terminate, False otherwise.
        '''
        # Get the whole message off the pipe in one go.
        msg = self.pipe.recv_multipart()
        if not msg:
            return True
        command = msg[0]
        if self.verbose:
            print("I: auth received API command %s" % command)

        if command == b'ALLOW':
            address = msg[1]
            if address not in self.whitelist:
                self.whitelist.append(address)
        elif command == b'DENY':
            address = msg[1]
            if address not in self.blacklist:
                self.blacklist.append(address)
        elif command == b'PLAIN':
            # For now we don't do anything with domains (msg[1]).
            # The password table is replaced wholesale.
            self.passwords = json.loads(msg[2])
        elif command == b'CURVE':
            # For now we don't do anything with domains (msg[1]).
            # If location is CURVE_ALLOW_ANY, allow all clients. Otherwise
            # treat location as a directory that holds the certificates.
            location = msg[2]
            if location == CURVE_ALLOW_ANY:
                self.allow_any = True
            else:
                self.allow_any = False
                if os.path.isdir(location):
                    self.certs = load_certificates(location)
                elif self.verbose:
                    print("E: Invalid CURVE certs location: %s" % location)
        elif command == b'VERBOSE':
            self.verbose = msg[1] == b'1'
        elif command == b'TERMINATE':
            self.pipe.send(b'OK')
            return True
        else:
            print("E: invalid auth command from API: %s\n" % command)
            # Report the failure so the front end's status check sees it.
            self.pipe.send(b'ERROR')
            return False
        self.pipe.send(b'OK')
        return False

    def _authenticate_plain(self, username, password):
        '''
        Perform the ZAP authentication check for the PLAIN mechanism.

        :returns: True only if a password table is configured and
                  ``password`` equals the one stored for ``username``.
        '''
        if not self.passwords:
            if self.verbose:
                print("I: DENIED (PLAIN) no passwords defined")
            return False
        allowed = (username in self.passwords
                   and password == self.passwords[username])
        if self.verbose:
            # The password itself is deliberately never logged.
            print("I: %s (PLAIN) username=%s"
                  % ("ALLOWED" if allowed else "DENIED", username))
        return allowed

    def _authenticate_curve(self, client_key):
        '''
        Perform the ZAP authentication check for the CURVE mechanism.

        :param client_key: the client's binary public key from the request.
        :returns: True if any key is allowed (CURVE_ALLOW_ANY) or the key's
                  z85 text is among the loaded certificates.
        '''
        if self.allow_any:
            if self.verbose:
                print("I: ALLOWED (CURVE allow any client)")
            return True
        if not self.certs:
            if self.verbose:
                print("I: DENIED (CURVE) no certificates defined")
            return False
        # Convert the binary key to z85 text, the form the certs are keyed by.
        z85_client_key = z85.encode(client_key)
        allowed = z85_client_key in self.certs
        if self.verbose:
            if allowed:
                print("I: ALLOWED (CURVE) client_key=%s" % z85_client_key)
            else:
                print("I: DENIED (CURVE) client_key=%s" % z85_client_key)
        return allowed

    def _authenticate(self):
        '''
        Receive one ZAP request and send exactly one reply to it.

        Frames used: version, sequence (request id), domain, address,
        identity, mechanism, then the mechanism's credentials (username and
        password for PLAIN, the client key for CURVE). The address
        whitelist/blacklist is checked first; PLAIN and CURVE clients must
        also pass their credential check. Malformed requests are denied
        instead of raising, because an exception here would end the thread.
        '''
        msg = self.zap.recv_multipart()
        if len(msg) < 6:
            # The REP socket must answer every request.
            sequence = msg[1] if len(msg) > 1 else b""
            self._send_zap_reply(sequence, b"400", b"Malformed request")
            return
        version, sequence, domain, address, identity, mechanism = msg[:6]
        if version != b"1.0":
            self._send_zap_reply(sequence, b"400", b"Invalid version")
            return

        # Check if address is explicitly whitelisted or blacklisted.
        allowed = False
        denied = False
        if self.whitelist:
            if address in self.whitelist:
                allowed = True
                if self.verbose:
                    print("I: PASSED (whitelist) address=%s" % address)
            else:
                denied = True
                if self.verbose:
                    print("I: DENIED (not in whitelist) address=%s" % address)
        elif self.blacklist:
            if address in self.blacklist:
                denied = True
                if self.verbose:
                    print("I: DENIED (blacklist) address=%s" % address)
            else:
                allowed = True
                if self.verbose:
                    print("I: PASSED (not in blacklist) address=%s" % address)

        # Mechanism-specific checks.
        if not denied:
            if mechanism == b'NULL' and not allowed:
                # For NULL, we allow if the address wasn't blacklisted.
                if self.verbose:
                    print("I: ALLOWED (NULL)")
                allowed = True
            elif mechanism == b'PLAIN':
                # For PLAIN, even a whitelisted address must authenticate.
                credentials = msg[6:]
                allowed = (len(credentials) == 2
                           and self._authenticate_plain(*credentials))
            elif mechanism == b'CURVE':
                # For CURVE, even a whitelisted address must authenticate.
                allowed = len(msg) > 6 and self._authenticate_curve(msg[6])

        if allowed:
            self._send_zap_reply(sequence, b"200", b"OK")
        else:
            self._send_zap_reply(sequence, b"400", b"NO ACCESS")
class Authenticator(object):
    '''
    An Authenticator object takes over authentication for all incoming
    connections in its context.

    Note:

    - libzmq provides four levels of security: default NULL (which zauth
      does not see), and authenticated NULL, PLAIN, and CURVE, which zauth
      can see.
    - until you add policies, all incoming NULL connections are allowed
      (classic ZeroMQ behavior), and all PLAIN and CURVE connections are
      denied.

    All work is done by a background thread, the "agent", which we talk to
    over a pipe. This lets the agent do work asynchronously in the
    background while our application does other things. This is invisible
    to the caller, who sees a classic API.

    The configuration methods (allow, deny, verbose, configure_plain,
    configure_curve) only send a command down the pipe and return at once;
    they do not wait for the agent to apply it. The agent's status replies
    are handled by ``_on_message`` through a ZMQStream on
    ``IOLoop.instance()``, so they are only reported while that IOLoop runs.
    '''

    def __init__(self, context, verbose=False):
        # Set the attributes stop()/__del__ rely on before anything can
        # fail, so a half-constructed object is still finalised cleanly.
        self.context = context
        self.pipe = None
        self.pipestream = None
        self.thread = None
        self.pipe_endpoint = "inproc://{}.inproc".format(id(self))
        if zmq.zmq_version_info() < (4, 0):
            raise Exception("Security is only available in libzmq >= 4.0")
        self.start(verbose)

    def allow(self, address):
        '''
        Allow (whitelist) a single IP address. For NULL, all clients from
        this address will be accepted. For PLAIN and CURVE, they will be
        allowed to continue with authentication. You can call this method
        multiple times to whitelist multiple IP addresses. If you whitelist
        a single address, any non-whitelisted addresses are treated as
        blacklisted.
        '''
        self.pipe.send_multipart([b'ALLOW', address])

    def deny(self, address):
        '''
        Deny (blacklist) a single IP address. For all security mechanisms,
        this rejects the connection without any further authentication. Use
        either a whitelist or a blacklist, not both. If you define both a
        whitelist and a blacklist, only the whitelist takes effect.
        '''
        self.pipe.send_multipart([b'DENY', address])

    def verbose(self, enabled):
        '''
        Enable or disable verbose tracing of commands and activity.
        '''
        self.pipe.send_multipart([b'VERBOSE', b'1' if enabled else b'0'])

    def configure_plain(self, domain='*', passwords=None):
        '''
        Configure PLAIN authentication for a given domain.

        ``passwords`` maps usernames to passwords. A dict is JSON-encoded
        before sending; any other value (e.g. an already-encoded JSON
        string) is sent as-is. The agent replaces its whole password table
        with the one received. Nothing is sent when ``passwords`` is empty
        or None. The agent does not use ``domain`` yet; "*" is meant to
        cover all domains.
        '''
        if passwords:
            if isinstance(passwords, dict):
                passwords = json.dumps(passwords)
            self.pipe.send_multipart([b'PLAIN', domain, passwords])

    def configure_curve(self, domain='*', location=None):
        '''
        Configure CURVE authentication for a given domain.

        ``location`` is either CURVE_ALLOW_ANY, which accepts any client
        public key without checking, or a directory holding the public
        client certificates as ``*.key`` files in zcert_save() format. The
        directory is read once, when the agent handles this command; call
        this method again to pick up certificates added or removed later.
        If the directory does not exist, CURVE_ALLOW_ANY is switched off
        but previously loaded certificates stay in place. ``location`` must
        be given. The agent does not use ``domain`` yet.
        '''
        self.pipe.send_multipart([b'CURVE', domain, location])

    def start(self, verbose=False):
        '''
        Start performing ZAP authentication.

        Binds the front-end PAIR pipe (watched by a ZMQStream so agent
        replies reach _on_message), then starts the agent thread, which
        connects to it.
        '''
        # Create a socket to communicate with the auth thread.
        self.pipe = self.context.socket(zmq.PAIR)
        self.pipe.linger = 1
        self.pipestream = ZMQStream(self.pipe, IOLoop.instance())
        self.pipestream.on_recv(self._on_message)
        self.pipestream.bind(self.pipe_endpoint)
        self.thread = AuthAgentThread(self.context,
                                      self.pipe_endpoint, verbose=verbose)
        self.thread.start()

    def stop(self):
        '''
        Stop performing ZAP authentication.

        Asks a running agent thread to terminate and waits for it, then
        closes the front-end pipe. Safe to call more than once.
        '''
        if self.pipe is None:
            return
        if self.is_alive():
            # Only a live agent can receive TERMINATE: a send on a PAIR
            # socket with no peer may block indefinitely.
            self.pipe.send(b'TERMINATE')
            self.thread.join()
        self.thread = None
        if self.pipestream is not None:
            # Closing the stream also closes the socket and stops the
            # IOLoop from watching it.
            self.pipestream.close()
        else:
            self.pipe.close()
        self.pipe = None
        self.pipestream = None

    def is_alive(self):
        ''' Is the Auth thread currently running ? '''
        return bool(self.thread and self.thread.is_alive())

    def __del__(self):
        # Guard against objects whose attributes were never initialised.
        if getattr(self, 'pipe', None) is not None:
            self.stop()

    def _on_message(self, msg):
        '''
        Process a status message from the Auth thread; anything other than
        b"OK" is reported as an error.
        '''
        status = msg[0]
        if status != b"OK":
            print("E: status from auth thread indicates error: %s" % status)
if __name__ == '__main__':

    def wait_for(sock, event, timeout=500):
        '''
        Return True if ``sock`` signals ``event`` (zmq.POLLIN or
        zmq.POLLOUT) within ``timeout`` milliseconds.
        '''
        poller = zmq.Poller()
        poller.register(sock, event)
        return bool(dict(poller.poll(timeout)).get(sock, 0) & event)

    def can_connect(server, client):
        '''
        Return True if a message sent by ``server`` arrives unchanged at
        ``client`` over a loopback TCP connection.

        A PUSH socket blocks on send while it has no connected peer, and a
        peer whose authentication is denied may never become available. The
        message is therefore only sent once the server reports it is
        writable; otherwise the check would hang instead of returning False.
        '''
        iface = 'tcp://127.0.0.1'
        port = server.bind_to_random_port(iface)
        client.connect("%s:%i" % (iface, port))
        msg = [b"Hello World"]
        if not wait_for(server, zmq.POLLOUT):
            return False
        server.send_multipart(msg)
        if not wait_for(client, zmq.POLLIN):
            return False
        return client.recv_multipart() == msg

    def plain_pair(password):
        ''' Return a (server, client) PUSH/PULL pair set up for PLAIN. '''
        server = context.socket(zmq.PUSH)
        server.plain_server = True
        client = context.socket(zmq.PULL)
        client.plain_username = 'admin'
        client.plain_password = password
        return server, client

    def curve_pair():
        ''' Return a (server, client) PUSH/PULL pair set up for CURVE. '''
        server = context.socket(zmq.PUSH)
        server.curve_publickey = server_public
        server.curve_secretkey = server_secret
        server.curve_server = True
        client = context.socket(zmq.PULL)
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret
        client.curve_serverkey = server_public
        return server, client

    # NOTE(review): allow/deny/configure_* return before the agent thread
    # has applied them, so each following check may race the agent.
    context = zmq.Context()
    auth = Authenticator(context, verbose=True)

    # A default NULL connection should always succeed, and not
    # go through our authentication infrastructure at all.
    server = context.socket(zmq.PUSH)
    client = context.socket(zmq.PULL)
    assert can_connect(server, client)
    client.close()
    server.close()
    print("")

    # When we set a domain on the server, we switch on authentication
    # for NULL sockets, but with no policies, the client connection
    # will still be allowed.
    server = context.socket(zmq.PUSH)
    server.zap_domain = 'global'
    client = context.socket(zmq.PULL)
    assert can_connect(server, client)
    client.close()
    server.close()
    print("")

    # Blacklist 127.0.0.1; the connection should fail.
    auth.deny('127.0.0.1')
    server = context.socket(zmq.PUSH)
    server.zap_domain = 'global'
    client = context.socket(zmq.PULL)
    assert not can_connect(server, client)
    client.close()
    server.close()
    print("")

    # Whitelist 127.0.0.1, which overrides the blacklist.
    auth.allow('127.0.0.1')
    server = context.socket(zmq.PUSH)
    server.zap_domain = 'global'
    client = context.socket(zmq.PULL)
    assert can_connect(server, client)
    client.close()
    server.close()
    print("")

    # Attempt PLAIN authentication before any passwords are configured
    # in the authenticator.
    server, client = plain_pair('Password')
    assert not can_connect(server, client)
    client.close()
    server.close()
    print("")

    # Try PLAIN authentication.
    server, client = plain_pair('Password')
    auth.configure_plain(domain='*', passwords={'admin': 'Password'})
    assert can_connect(server, client)
    client.close()
    server.close()
    print("")

    # Attempt PLAIN using bogus credentials.
    server, client = plain_pair('Bogus')
    assert not can_connect(server, client)
    client.close()
    server.close()
    print("")

    # Test CURVE authentication.
    server_public, server_secret = load_certificate('.certs_private/server.key_secret')
    client_public, client_secret = load_certificate('.certs_private/client.key_secret')

    # Test without setting up any authentication.
    server, client = curve_pair()
    assert not can_connect(server, client)
    client.close()
    server.close()
    print("")

    # Test CURVE_ALLOW_ANY.
    auth.configure_curve(domain='*', location=CURVE_ALLOW_ANY)
    server, client = curve_pair()
    assert can_connect(server, client)
    client.close()
    server.close()
    print("")

    # Test full client authentication using certificates.
    auth.configure_curve(domain='*', location='.certs_public')
    server, client = curve_pair()
    assert can_connect(server, client)
    client.close()
    server.close()
    print("")

    # Remove the authenticator and check a normal connection works.
    auth.stop()
    del auth
    print("")
    server = context.socket(zmq.PUSH)
    client = context.socket(zmq.PULL)
    assert can_connect(server, client)
    client.close()
    server.close()
    context.term()
# **** Generated on 2013-12-08 12:24:52 by CZMQ ****
# ZeroMQ CURVE Public Certificate
# Exchange securely, or use a secure mechanism to verify the contents
# of this file after exchange. Store public certificates in your home
# directory, in the .curve subdirectory.
metadata
curve
public-key = "JcS5XDo0YrL<Q4At!O9l9of8K4aeUw9o7t*6pW!."
# **** Generated on 2013-12-08 12:24:52 by CZMQ ****
# ZeroMQ CURVE **Secret** Certificate
# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions.
metadata
curve
public-key = "JcS5XDo0YrL<Q4At!O9l9of8K4aeUw9o7t*6pW!."
secret-key = "(2SKOp#IZ6sef/1-t$Ltezd{w^sU*S6U:gNk6b0*"
# **** Generated on 2013-12-08 12:24:52 by CZMQ ****
# ZeroMQ CURVE Public Certificate
# Exchange securely, or use a secure mechanism to verify the contents
# of this file after exchange. Store public certificates in your home
# directory, in the .curve subdirectory.
metadata
curve
public-key = "p5Rm=1]QdW^>Z?dcqvI.vNq1yau:wl&$/rRd[rbn"
# **** Generated on 2013-12-08 12:24:52 by CZMQ ****
# ZeroMQ CURVE **Secret** Certificate
# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions.
metadata
curve
public-key = "p5Rm=1]QdW^>Z?dcqvI.vNq1yau:wl&$/rRd[rbn"
secret-key = "QgjNKGkF5(OPt1QrfTVPphex4gX0e^RIRJl6!8R5"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment