Skip to content

Instantly share code, notes, and snippets.

@predakanga
Created August 8, 2014 02:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save predakanga/16664fb392389c65aaff to your computer and use it in GitHub Desktop.
Save predakanga/16664fb392389c65aaff to your computer and use it in GitHub Desktop.
from __future__ import print_function
import os, os.path, sys, threading, signal, time, collections
from Queue import Queue
from SocketServer import UnixStreamServer, StreamRequestHandler, ThreadingMixIn
from click import command, option
from ldif import LDIFWriter
from uuid import UUID, uuid5
import oursql
class GenericPool(object):
    """Thread-safe pool of reusable objects with a cap on concurrent use.

    Subclasses provide :meth:`factory` (build a new object) and
    :meth:`dispose` (destroy one).  At most ``max`` objects can be checked
    out at the same time; :meth:`acquire` blocks while that many are in use.
    Every successful ``acquire()`` must be paired with exactly one
    ``release(obj)`` or ``discard(obj)``.
    """

    def __init__(self, min=3, max=5, preload=True):
        """Set up the pool.

        min     -- idle-object threshold: ``release()`` disposes of the
                   returned object once more than ``min`` objects are idle.
        max     -- maximum number of objects checked out simultaneously.
        preload -- when true, create ``min`` objects immediately.  This
                   calls ``factory()`` from inside ``__init__``, so a
                   subclass whose factory needs its own attributes must
                   pass ``False`` (or set them before calling up).
        """
        self._semaphore = threading.BoundedSemaphore(max)
        self._queue = Queue(max)
        self._min = min
        # Serialises the "is anything idle? then take it" step in acquire();
        # release() only ever adds to the queue, so it does not need it.
        self._lock = threading.Lock()
        if preload:
            for _ in range(min):
                self._fill_queue()

    def acquire(self):
        """Check an object out of the pool, creating one if none is idle.

        Blocks while ``max`` objects are already checked out.  If the
        factory raises, the reserved slot is handed back before the
        exception propagates, so a failing backend cannot exhaust the pool.
        """
        self._semaphore.acquire()
        try:
            with self._lock:
                # Only lock holders remove items, so a non-empty queue
                # guarantees that get() returns without blocking.
                if not self._queue.empty():
                    return self._queue.get()
            # Build outside the lock so slow factories do not serialise
            # other threads that could be served from the idle queue.
            return self.factory()
        except BaseException:
            self._semaphore.release()
            raise

    def release(self, obj):
        """Return ``obj`` to the pool (or dispose of it if enough are idle).

        The semaphore slot is given back exactly once.  The previous
        implementation went through discard() and then released again,
        which made BoundedSemaphore raise ValueError.
        """
        try:
            if self._queue.qsize() > self._min:
                self.dispose(obj)
            else:
                self._queue.put(obj)
        finally:
            self._semaphore.release()

    def _fill_queue(self):
        """Create one object and park it in the idle queue."""
        obj = self.factory()
        self._queue.put(obj)

    def discard(self, obj):
        """Destroy a checked-out object instead of returning it to the pool."""
        try:
            self.dispose(obj)
        finally:
            self._semaphore.release()

    def factory(self):
        """Create and return a new pooled object (subclass hook)."""
        raise NotImplementedError

    def dispose(self, obj):
        """Destroy a pooled object (subclass hook)."""
        raise NotImplementedError
class OurSqlPool(GenericPool):
    """Pool of oursql connections that are opened on first use."""

    def __init__(self, min=1, max=5, **args):
        """Create the pool.

        min, max -- forwarded to GenericPool.
        **args   -- keyword arguments passed verbatim to ``oursql.connect``
                    each time a new connection is needed.
        """
        # Preloading is switched off: factory() reads _connOpts, which is
        # only assigned after the base initialiser has returned.
        super(OurSqlPool, self).__init__(min, max, False)
        self._connOpts = args

    def factory(self):
        """Open a new connection with the stored options."""
        return oursql.connect(**self._connOpts)

    def dispose(self, conn):
        """Close a connection that is leaving the pool."""
        conn.close()
class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer):
    """Unix-domain stream server that handles each request in its own thread."""
class LdapUnbindException(Exception):
    """Raised by an unbind request; the handler ends without sending a RESULT."""
# Namespace from which every entryUUID is derived with uuid5, so the same
# entry always gets the same UUID.
nsUuid = UUID('01168943-44c8-434c-a369-257b3e0ecafb')
# Keyword arguments handed to oursql.connect() by the pool below.
# NOTE(review): database credentials (root account) are hard-coded in the
# source; consider loading them from configuration before deploying.
connectionOptions = dict(
    user='root',
    passwd='btn-dev',
    host='localhost',
    db='gazelle',
)
# Module-wide connection pool shared by the request handlers.
sql = OurSqlPool(**connectionOptions)
class LdapRequestHandler(StreamRequestHandler):
    """Base handler for one line-oriented LDAP request on a stream socket.

    A request is a command line followed by ``key: value`` lines and ends
    with a blank line (or EOF).  The lower-cased command selects the method
    of the same name, which receives the list of ``(key, value)`` tuples and
    returns an iterable of entry dicts.  Each entry is written as LDIF,
    followed by a ``RESULT`` block carrying a numeric code.
    """

    def unbind(self, params):
        """Handle an unbind request: end the exchange without a RESULT."""
        raise LdapUnbindException

    def _find_handler(self, cmd):
        """Return the bound method implementing ``cmd``.

        The command name comes straight off the socket, so it must not be
        able to reach private/dunder attributes, the socketserver plumbing
        (``handle``, ``setup``, ``finish``, ...) or ``dispatch`` itself.

        Raises NotImplementedError when no suitable method exists.
        """
        if (not cmd or cmd.startswith('_') or cmd == 'dispatch'
                or hasattr(StreamRequestHandler, cmd)):
            raise NotImplementedError
        handler = getattr(self, cmd, None)
        if not callable(handler):
            raise NotImplementedError
        return handler

    def dispatch(self, request):
        """Parse ``request`` and call the method named by its first line.

        request -- the request text, lines separated by ``"\\n"``.

        Returns whatever the selected method returns.  Raises
        NotImplementedError for an unknown command.  The lookup happens
        before the call, so an AttributeError raised inside a handler is
        no longer misreported as "not implemented".
        """
        lines = request.split("\n")
        cmd = lines[0].strip().lower()
        tuples = []
        for line in lines[1:]:
            if line.strip() == "":
                continue
            # partition() tolerates a line without ':' (empty value) where
            # split()/indexing used to raise IndexError.
            key, _, value = line.partition(':')
            tuples.append((key, value.lstrip()))
        return self._find_handler(cmd)(tuples)

    def handle(self):
        """Read one request from the socket, run it and write the response."""
        # Wall-clock time: time.clock() reports CPU time on Unix and no
        # longer exists in Python 3.8+.
        startTime = time.time()
        # Read the request: everything up to the first blank line or EOF
        # (readline() returns "" at EOF, which also ends the loop).
        request_lines = []
        line = self.rfile.readline().strip("\n")
        while line != "":
            request_lines.append(line)
            line = self.rfile.readline().strip("\n")
        buffer = "\n".join(request_lines)
        # Dispatch it and handle the results
        try:
            results = self.dispatch(buffer)
            writer = LDIFWriter(self.wfile)
            for res in results:
                dn = res['entryDN'][0]
                writer.unparse(dn, res)
                # Blank line between results
                self.wfile.write("\n")
            self.wfile.write("RESULT\n")
            self.wfile.write("code: 0\n")
            self.wfile.write("info: Request handled in {0} seconds\n".format(time.time() - startTime))
        except NotImplementedError:
            self.wfile.write("RESULT\n")
            self.wfile.write("code: 53\n")
            # Newline-terminated like every other result line.
            self.wfile.write("info: Not implemented here\n")
        except LdapUnbindException:
            pass
        except Exception as e:
            print("Exception occurred: {0}".format(e))
            import traceback
            # format_exc() takes an optional depth limit, not the exception.
            print(traceback.format_exc())
            self.wfile.write("RESULT\n")
            self.wfile.write("code: 1\n")
            # Because the message may have linebreaks, only send the exception type
            self.wfile.write("info: Exception occurred: {0}\n".format(e.__class__.__name__))
class BtnLdapRequestHandler(LdapRequestHandler):
    """Answer LDAP requests from a small static tree plus the Gazelle database.

    The organisation root, the two OUs and the admin role are hard-coded in
    ``objects``/``tree``.  User and group entries are built on demand from
    the ``users_main`` and ``permissions`` tables through the module-level
    ``sql`` connection pool.
    """

    # Operational attributes; search() always keeps these in filtered entries.
    operAttrs = ['entryDN', 'entryUUID', 'hasSubordinates', 'numSubordinates']
    # DNs of the two containers whose children come from the database.
    _usersDn = 'ou=users,dc=broadcasthe,dc=net'
    _groupsDn = 'ou=groups,dc=broadcasthe,dc=net'
    # Static entries, keyed by DN.
    objects = {
        'dc=broadcasthe,dc=net': {
            'entryDN': ['dc=broadcasthe,dc=net'],
            'entryUUID': [str(uuid5(nsUuid, "root"))],
            'objectClass': ['organization', 'dcObject'],
            'dc': ['broadcasthe'],
            'o': ['broadcasthe.net'],
            'hasSubordinates': ["TRUE"],
            'numSubordinates': [str(3)]
        },
        'ou=users,dc=broadcasthe,dc=net': {
            'entryDN': ['ou=users,dc=broadcasthe,dc=net'],
            'entryUUID': [str(uuid5(nsUuid, "users"))],
            'objectClass': ['organizationalUnit'],
            'ou': ['users'],
            'description': ["BTN Users"],
            'hasSubordinates': ["TRUE"],
            'numSubordinates': [str(1)]
        },
        'ou=groups,dc=broadcasthe,dc=net': {
            'entryDN': ['ou=groups,dc=broadcasthe,dc=net'],
            'entryUUID': [str(uuid5(nsUuid, "groups"))],
            'objectClass': ['organizationalUnit'],
            'ou': ['groups'],
            'description': ["BTN Groups"],
            'hasSubordinates': ["TRUE"],
            'numSubordinates': [str(1)]
        },
        'cn=admin,dc=broadcasthe,dc=net': {
            'entryDN': ['cn=admin,dc=broadcasthe,dc=net'],
            'entryUUID': [str(uuid5(nsUuid, "admin user"))],
            'objectClass': ['organizationalRole', 'simpleSecurityObject'],
            'cn': ['admin'],
            'userPassword': ['{MD5}6u1Y9EQhPKdtMy7wbt3cdw=='],
            'hasSubordinates': ["FALSE"],
            'numSubordinates': [str(0)]
        }
    }
    # Parent -> children layout of the static entries.  A dict value lists
    # further static children; [] and None mark nodes without any.
    tree = {
        'dc=broadcasthe,dc=net': {
            'ou=users,dc=broadcasthe,dc=net': [],
            'ou=groups,dc=broadcasthe,dc=net': [],
            'cn=admin,dc=broadcasthe,dc=net': None
        }
    }

    def find_subordinates(self, name):
        """Return the DNs of the static entries directly below ``name``.

        Returns an empty list when ``name`` is unknown or has no static
        children.
        """
        missing = object()

        def walk_tree(tree, part):
            if part in tree:
                return tree[part]
            # Only dicts can hold further nodes; the list/None placeholders
            # are leaves and must not be searched or indexed.
            for child in tree.values():
                if isinstance(child, dict):
                    found = walk_tree(child, part)
                    if found is not missing:
                        return found
            return missing

        resp = walk_tree(self.tree, name)
        if isinstance(resp, dict):
            return list(resp.keys())
        # Return the empty list when nothing's found
        return []

    def _user_entry(self, uid, name, email, password, group):
        """Build the LDAP entry dict for one users_main row."""
        # TODO: Again, we should never get a request for memberOf && !attrsonly but just in case...
        return {
            'entryDN': ["cn={0},ou=users,dc=broadcasthe,dc=net".format(name)],
            'entryUUID': [str(uuid5(nsUuid, "User: {0}".format(uid)))],
            'objectClass': ['inetOrgPerson', 'simpleSecurityObject'],
            'cn': [name],
            'sn': [name],
            'mail': [email],
            'userPassword': ["{{BCRYPT}}{0}".format(password)],
            'uid': [str(uid)],
            'memberOf': ["cn={0},ou=groups,dc=broadcasthe,dc=net".format(group)],
            # Upper case, matching every other entry (str(False) gave "False").
            'hasSubordinates': ["FALSE"],
            'numSubordinates': [str(0)]
        }

    def fetchUsers(self, limit, attrs):
        """Yield an entry dict for up to ``limit`` users.

        limit -- maximum number of rows; coerced with int() because it is
                 interpolated into the SQL text.
        attrs -- accepted for signature symmetry; not used here.

        The pooled connection is released even if the query fails or the
        consumer stops iterating early.
        """
        conn = sql.acquire()
        try:
            with conn.cursor() as c:
                # Unfortunately, oursql can't parameterize limits
                c.execute("SELECT um.ID, um.Username, um.Email, um.NewPassHash, p.Name FROM users_main AS um JOIN permissions AS p ON um.PermissionID = p.ID LIMIT {0}".format(int(limit)))
                for uid, name, email, password, group in c:
                    yield self._user_entry(uid, name, email, password, group)
        finally:
            sql.release(conn)

    def fetchGroups(self, limit, attrs):
        """Yield an entry dict for up to ``limit`` groups (permission classes).

        limit -- maximum number of rows; coerced with int() because it is
                 interpolated into the SQL text.
        attrs -- accepted for signature symmetry; not used here.
        """
        conn = sql.acquire()
        try:
            with conn.cursor() as c:
                # Unfortunately, oursql can't parameterize limits
                c.execute("SELECT id, name FROM permissions LIMIT {0}".format(int(limit)))
                for gid, name in c:
                    # TODO: Hopefully we never get a request for members && !attrsonly
                    yield {
                        'entryDN': ["cn={0},ou=groups,dc=broadcasthe,dc=net".format(name)],
                        'entryUUID': [str(uuid5(nsUuid, "Group: {0}".format(gid)))],
                        'objectClass': ['groupOfNames'],
                        'cn': [name],
                        'description': [name],
                        'hasSubordinates': ["FALSE"],
                        'numSubordinates': [str(0)]
                    }
        finally:
            sql.release(conn)

    def fetchSingleUser(self, dn):
        """Return the entry for the user named by ``dn``, or None if unknown.

        The user name is taken from the first RDN with its three-character
        ``cn=`` prefix removed.
        """
        parts = dn.split(",", 1)
        username = parts[0][3:]
        conn = sql.acquire()
        try:
            with conn.cursor() as c:
                c.execute("SELECT um.ID, um.Username, um.Email, um.NewPassHash, p.Name FROM users_main AS um JOIN permissions AS p ON um.PermissionID = p.ID WHERE um.Username = ?", (username,))
                row = c.fetchone()
        finally:
            sql.release(conn)
        if row is None:
            # Unknown user: unpacking None used to raise TypeError and leak
            # the connection.
            return None
        uid, name, email, password, group = row
        return self._user_entry(uid, name, email, password, group)

    def fetchSingleGroup(self, dn):
        """Return the entry for the group named by ``dn``, or None if no row matches.

        The query joins permissions to users_main, so a group without any
        member produces no rows and is reported as not found.
        """
        parts = dn.split(",", 1)
        groupname = parts[0][3:]
        toRet = {
            'entryDN': ["cn={0},ou=groups,dc=broadcasthe,dc=net".format(groupname)],
            'entryUUID': None,
            'objectClass': ['groupOfNames'],
            'cn': [groupname],
            'description': [groupname],
            'member': [],
            'hasSubordinates': ["FALSE"],
            'numSubordinates': [str(0)]
        }
        conn = sql.acquire()
        try:
            with conn.cursor() as c:
                c.execute("SELECT p.id, p.name, um.username FROM permissions AS p JOIN users_main AS um ON p.ID = um.PermissionID WHERE p.Name = ?", (groupname,))
                for gid, name, username in c:
                    if not toRet['entryUUID']:
                        toRet['entryUUID'] = [str(uuid5(nsUuid, "Group: {0}".format(gid)))]
                    toRet['member'].append("cn={0},ou=users,dc=broadcasthe,dc=net".format(username))
        finally:
            sql.release(conn)
        if toRet['entryUUID'] is None:
            # No rows: an entry with entryUUID None cannot be serialised.
            return None
        return toRet

    def search(self, params):
        """Yield the entries matching a search request.

        params -- ``(key, value)`` tuples; the keys read are ``base``,
                  ``scope`` ("0" base, "1" one level, "2" subtree),
                  ``attrs`` (space-separated attribute names, ``*`` for all)
                  and ``sizelimit``.

        Entries are produced base-first, and at most ``sizelimit`` of them.
        The search filter itself is not evaluated.
        """
        # Fold the params into a dict
        params = dict(params)
        requested = params['attrs'].split()
        wantAll = '*' in requested
        # Not sure if this is kosher, but it seems to be what ADS wants:
        # the operational attributes are always returned.  "+" contributes
        # nothing of its own because no entry has a key named "+".
        wanted = set(requested)
        wanted.update(self.operAttrs)

        def filterAttrs(obj):
            if wantAll:
                return obj
            # Build a new dict so as not to affect the original
            return dict((key, value) for key, value in obj.items() if key in wanted)

        usersSuffix = "," + self._usersDn
        groupsSuffix = "," + self._groupsDn
        dn = params['base']
        scope = params['scope']
        isDynamic = dn.endswith(usersSuffix) or dn.endswith(groupsSuffix)
        # Work list of DNs, plus the special tokens "users"/"groups" that
        # stand for "every row of that table".
        pending = []
        if scope == "0":
            # Fetching a single object
            if dn in self.objects or isDynamic:
                pending.append(dn)
        elif scope == "1":
            # Fetching immediate children
            if dn == self._usersDn:
                pending.append("users")
            elif dn == self._groupsDn:
                pending.append("groups")
            else:
                pending.extend(self.find_subordinates(dn))
        elif scope == "2":
            # Complete subtree
            if dn in self.objects:
                pending.append(dn)
                pending.extend(self.find_subordinates(dn))
                # A subtree containing an OU also contains that OU's
                # database-backed children (previously only when the OU
                # itself was the search base).
                if self._usersDn in pending:
                    pending.append("users")
                if self._groupsDn in pending:
                    pending.append("groups")
            elif isDynamic:
                pending.append(dn)

        limit = int(params['sizelimit'])
        for target in pending:
            if not limit:
                break
            if target == "users":
                entries = self.fetchUsers(limit, params['attrs'])
            elif target == "groups":
                entries = self.fetchGroups(limit, params['attrs'])
            elif target.endswith(usersSuffix):
                entries = [self.fetchSingleUser(target)]
            elif target.endswith(groupsSuffix):
                entries = [self.fetchSingleGroup(target)]
            else:
                # TODO: Fetch subordinate count for ou=groups|users
                entries = [self.objects[target]]
            for entry in entries:
                if entry is None:
                    # Unknown user/group: nothing to return for this DN.
                    continue
                yield filterAttrs(entry)
                # Got to keep the count accurate
                limit -= 1

    def bind(self, params):
        """Handle a bind request.

        NOTE(review): returns no entries without looking at ``params``, so
        handle() answers every bind with result code 0 -- no credentials
        are checked here.
        """
        return []
@command()
@option('--socket', '-s', default="/var/run/gazelle_ldap.sock", help="Socket location", show_default=True)
def serve(socket):
    """Serve the Gazelle LDAP backend on a Unix domain socket."""
    # Returns 1 when the socket path already exists, 0 after a clean
    # shutdown (SIGTERM or Ctrl-C).
    print("Starting socket server at {0}".format(socket))
    if os.path.exists(socket):
        print("Error: {0} already exists".format(socket))
        return 1
    server = ThreadedUnixStreamServer(socket, BtnLdapRequestHandler)
    # The accept loop runs in a daemon thread so that the main thread stays
    # free to receive signals.
    server_thread = threading.Thread(target=server.serve_forever)
    server_thread.daemon = True

    def term_handler(signum, frame):
        server.shutdown()

    signal.signal(signal.SIGTERM, term_handler)
    server_thread.start()
    try:
        # Join with a timeout so the main thread wakes up regularly instead
        # of blocking indefinitely.
        while server_thread.is_alive():
            server_thread.join(3)
    except KeyboardInterrupt:
        # Ctrl-C is a normal way to stop; any other error now propagates
        # (after cleanup) instead of being silently swallowed.
        pass
    finally:
        # shutdown() returns immediately if the loop has already stopped.
        server.shutdown()
        server.server_close()
        os.unlink(socket)
    return 0
# Script entry point: run the click command and exit with its status.
if __name__ == "__main__":
    sys.exit(serve())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment