Skip to content

Instantly share code, notes, and snippets.

@xtao
Forked from tarekziade/distribution.py
Created June 4, 2016 05:39
Show Gist options
  • Save xtao/b2e918c7d0dc092bbab24637ef7ae6bc to your computer and use it in GitHub Desktop.
Save xtao/b2e918c7d0dc092bbab24637ef7ae6bc to your computer and use it in GitHub Desktop.
Consistent Distribution of users across servers
""" Consistent load-balancing.
We have a few servers and we want a load-balancer to
distribute incoming requests across them in a deterministic
and consistent way - without keeping any counter to make the
decision.
Removing a backend server should not impact users on other
servers.
Adding a backend will trigger a redistribution of
users across the servers.
The goal is to come up with the best algorithm
for 1M users across 5 servers. Speed is a bonus.
Two known techniques:
- RendezVous : https://en.wikipedia.org/wiki/Rendezvous_hashing#Comparison_With_Consistent_Hashing
- Consistent Hashing: https://en.wikipedia.org/wiki/Consistent_hashing
Consistent Hashing implementation inspired by:
http://techspot.zzzeek.org/2012/07/07/the-absolutely-simplest-consistent-hashing-example
Also, good read on hashes:
http://programmers.stackexchange.com/questions/49550/which-hashing-algorithm-is-best-for-uniqueness-and-speed/145633#145633
"""
import hashlib
import bisect
from collections import defaultdict
import binascii
import time
from functools import wraps
class CollisionError(Exception):
    """Signal that a hash function gave one value to two different keys."""
# Registry of every hash value seen so far, kept separately for each
# decorated function: {hash function: {hash value: key that produced it}}.
_collisions = {}


def catch_collision(func):
    """Wrap a one-argument hash function so that collisions are detected.

    The wrapper remembers which key produced each hash value.  When the
    wrapped function returns an already-seen value for a *different* key,
    ``CollisionError`` is raised; hashing the same key again is fine.

    :param func: callable taking a key and returning a hashable value.
    :returns: a wrapper with the same signature and metadata as ``func``.
    :raises CollisionError: from the wrapper, when two distinct keys hash
        to the same value with this function.
    """
    # One table per function: identical values coming from two different
    # hash functions are unrelated and must not count as a collision.
    seen = _collisions.setdefault(func, {})

    @wraps(func)
    def _catch(key):
        res = func(key)
        # setdefault records the first key for this value and hands back
        # whichever key is on record.
        previous = seen.setdefault(res, key)
        if previous != key:
            raise CollisionError('%s and %s with %s' % (key, previous, func))
        return res
    return _catch
@catch_collision
def fnv32a(key):
    """Return the 32-bit FNV-1a hash of ``key`` as an unsigned integer.

    :param key: byte string, or text which is hashed as its UTF-8 bytes.
        For ASCII keys the result is the same either way.
    :returns: an integer in ``[0, 2**32)``.
    :raises CollisionError: via the ``catch_collision`` decorator.
    """
    # FNV is defined over bytes, so normalise text first.  This also keeps
    # the loop working on Python 3, where iterating bytes yields ints.
    if not isinstance(key, bytes):
        key = key.encode('utf-8')
    hval = 0x811c9dc5          # FNV-1 32-bit offset basis
    fnv_32_prime = 0x01000193
    for byte in bytearray(key):
        # xor first, then multiply (the "1a" variant), truncated to 32 bits
        hval = ((hval ^ byte) * fnv_32_prime) & 0xffffffff
    return hval
@catch_collision
def sha512(key):
    """Return the SHA-512 digest of ``key`` as an integer.

    :param key: byte string, or text which is hashed as its UTF-8 bytes.
    :raises CollisionError: via the ``catch_collision`` decorator.
    """
    # hashlib only accepts bytes on Python 3.
    if not isinstance(key, bytes):
        key = key.encode('utf-8')
    # int() promotes to an arbitrary-precision integer on Python 2 as well,
    # so the Python-2-only ``long`` builtin is not needed.
    return int(hashlib.sha512(key).hexdigest(), 16)
@catch_collision
def sha256(key):
    """Return the SHA-256 digest of ``key`` as an integer.

    :param key: byte string, or text which is hashed as its UTF-8 bytes.
    :raises CollisionError: via the ``catch_collision`` decorator.
    """
    # hashlib only accepts bytes on Python 3.
    if not isinstance(key, bytes):
        key = key.encode('utf-8')
    # int() promotes to an arbitrary-precision integer on Python 2 as well,
    # so the Python-2-only ``long`` builtin is not needed.
    return int(hashlib.sha256(key).hexdigest(), 16)
@catch_collision
def md5(key):
    """Return the MD5 digest of ``key`` as an integer.

    :param key: byte string, or text which is hashed as its UTF-8 bytes.
    :raises CollisionError: via the ``catch_collision`` decorator.
    """
    # hashlib only accepts bytes on Python 3.
    if not isinstance(key, bytes):
        key = key.encode('utf-8')
    # int() promotes to an arbitrary-precision integer on Python 2 as well,
    # so the Python-2-only ``long`` builtin is not needed.
    return int(hashlib.md5(key).hexdigest(), 16)
class RendezVous(object):
    """Rendezvous (highest score wins) selection over a list of servers.

    For a given key every server is scored by hashing ``"<ip>-<key>"``;
    the server with the highest score is selected.

    :param ips: optional iterable of initial servers.
    :param hash: callable mapping a string to a comparable score.
    """

    def __init__(self, ips=None, hash=md5):
        # Copy, so add()/remove() never mutate a list owned by the caller.
        self.ips = [] if ips is None else list(ips)
        self._hash = hash

    def __str__(self):
        return '<RendezVous with %s hash>' % self._hash

    def add(self, ip):
        """Register a new server."""
        self.ips.append(ip)

    def remove(self, ip):
        """Unregister a server; raises ``ValueError`` if it is unknown."""
        self.ips.remove(ip)

    def select(self, key):
        """Return the server chosen for ``key``, or ``None`` without servers."""
        winner = None
        # Start from "no score yet" rather than -1: some hashes return
        # negative values (binascii.crc32 is signed on Python 2) and a key
        # whose scores were all negative would otherwise get no server.
        high_score = None
        for ip in self.ips:
            score = self._hash("%s-%s" % (str(ip), str(key)))
            if high_score is None or score > high_score:
                high_score, winner = score, ip
            elif score == high_score and str(ip) > str(winner):
                # Equal scores: break the tie on the server's string form,
                # but keep returning the server object itself.
                winner = ip
        return winner
def _repl(name, index):
return '%s:%d' % (name, index)
class ConsistentHashing(object):
    """Consistent hashing of keys onto servers using replicated points.

    Every server is hashed ``replicas`` times (as ``"<ip>:<n>"``) and the
    resulting points are kept in a sorted list.  A key is served by the
    server owning the point found by bisecting that list with the key's
    hash.

    :param ips: optional iterable of initial servers.
    :param replicas: number of points created per server.
    :param hash: callable mapping a string to a sortable value.
    """

    def __init__(self, ips=None, replicas=200, hash=md5):
        self._ips = {}          # point hash -> server owning that point
        self._hashed_ips = []   # every point hash, kept sorted
        self.replicas = replicas
        self._hash = hash
        for ip in ips or ():
            self.add(ip)

    def __str__(self):
        return '<ConsistentHashing with %s hash>' % self._hash

    def add(self, ip):
        """Insert the ``replicas`` points of a server."""
        for i in range(self.replicas):
            hashed = self._hash(_repl(ip, i))
            # Store the server itself rather than its replica label, so
            # select() never has to parse the label back; splitting on ':'
            # truncated names such as 'host:5432'.
            self._ips[hashed] = ip
            bisect.insort(self._hashed_ips, hashed)

    def remove(self, ip):
        """Delete the points of a server; ``KeyError`` if it is unknown."""
        for i in range(self.replicas):
            hashed = self._hash(_repl(ip, i))
            del self._ips[hashed]
            index = bisect.bisect_left(self._hashed_ips, hashed)
            del self._hashed_ips[index]

    def select(self, username):
        """Return the server chosen for ``username``, or ``None`` if empty."""
        if not self._hashed_ips:
            return None
        hashed = self._hash(username)
        # hi=len-1 clamps the result to a valid index: keys hashing past
        # the second-to-last point all land on the last point (no
        # wrap-around to the first one).
        start = bisect.bisect(self._hashed_ips, hashed,
                              hi=len(self._hashed_ips) - 1)
        return self._ips[self._hashed_ips[start]]
# Number of simulated users generated by the benchmark (one million).
NUM_USERS = 10 ** 6
def _count_users(servers, users):
    """Return {server: number of users that ``servers.select`` sends there}."""
    counts = defaultdict(int)
    for user in users:
        counts[servers.select(user)] += 1
    return counts


def _print_distribution(counts):
    """Print one line per server plus the biggest-minus-smallest span."""
    for db in counts:
        print('%d users in %s' % (counts[db], db))
    sizes = list(counts.values())
    # Derived from the data itself, so the figure stays right whatever the
    # number of users; an empty run reports a span of 0.
    span = max(sizes) - min(sizes) if sizes else 0
    print('span: %d' % span)


def run_test(servers, users, removed='postgres2'):
    """Print how users spread over the servers, before and after a removal.

    :param servers: object exposing ``select(key)`` and ``remove(ip)``.
    :param users: iterable of user keys; it is walked twice.
    :param removed: server taken out before the second measurement.
    """
    counts = _count_users(servers, users)
    print('====')
    print('Distribution')
    _print_distribution(counts)

    # take one server out and measure again
    servers.remove(removed)
    print('====')
    _print_distribution(_count_users(servers, users))
# Benchmark driver: run every (algorithm, hash function) pair over
# NUM_USERS generated user ids and print distribution and timing.
if __name__ == '__main__':
    def crc32(key):
        """binascii.crc32 that also accepts text (hashed as UTF-8 bytes)."""
        # binascii.crc32 rejects text on Python 3; byte strings go
        # through untouched.
        if not isinstance(key, bytes):
            key = key.encode('utf-8')
        return binascii.crc32(key)

    users = ['%06d' % i for i in range(NUM_USERS)]
    servers = ['postgres5', 'postgres2', 'postgres3', 'postgres4',
               'postgres1']

    for klass in (ConsistentHashing, RendezVous):
        # 'hash_func' rather than 'hash': the loop variable is a module
        # global and would otherwise shadow the builtin.
        for hash_func in (md5, sha256, crc32, fnv32a, sha512):
            try:
                # each cluster gets its own copy of the server list
                cluster = klass(list(servers), hash=hash_func)
            except CollisionError:
                print('Collision error with hash %s' % hash_func)
                continue

            print(cluster)
            start = time.time()
            try:
                run_test(cluster, users)
            except CollisionError:
                print('Collision error..')
            print('Took %d seconds' % (time.time() - start))
            # print('') emits a blank line on Python 2 and 3 alike; a bare
            # 'print' does nothing on Python 3.
            print('')
            print('')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment