# Implementation of HyperLogLog probabilistic counting methods in Lua inside Redis, via Python
# Lua routines for use inside the Redis datastore
# Hyperloglog cardinality estimation
# ported from http://stackoverflow.com/questions/5990713/loglog-algorithm-for-counting-of-large-cardinalities
#
# Dan Brown, 2012. https://github.com/dbro
#
# note that lua needs to have the bitlib and murmur3 modules built in, and loaded by redis
#
# suitable for counting unique items from 0 to billions
# choose a k value to balance storage and precision objectives
#
# init complexity proportional to number of registers ( = 2^k ~ 1/p^2)
# update complexity is constant
# count complexity is constant for a single set (due to accumulated stats)
# count complexity for union of multiple sets is proportional to N/p^2
# there is a half-speed slowdown when cardinality is less than registers*2.5
#
# k    std_error   storage bits   registers   redis_strlen (bytes)
# 4    0.26        80             16          10
# 6    0.13        320            64          40
# 8    0.065       1280           256         160
# 10   0.0325      5120           1024        640
# 12   0.01625     20480          4096        2560
# 14   0.008125    81920          16384       10240
# 16   0.0040625   327680         65536       40960
#
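# worked example of choosing k (derived from the formula used in init_script below):
#   a requested std_error of 0.01 gives k = ceil(log2((1.03896 / 0.01)^2)) = 14,
#   so m = 2^14 = 16384 registers of 5 bits each = 10240 bytes in redis,
#   and the actual standard error returned is 1.03896 / sqrt(16384) ~= 0.0081
#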
import redis
# eval """__the_lines_below__""" 1 set_name std_error
# eval """__the_lines_below__""" 1 set_name std_error expires_seconds
init_script = """
local std_error = tonumber(ARGV[1]);
assert((0 < std_error) and (std_error < 1),
'ERROR. standard error of ' .. std_error .. ' is outside acceptable range of 0 < s < 1');
local k = math.ceil(math.log(math.pow(1.03896 / std_error, 2)) / math.log(2))
k = (k < 4) and 4 or (k > 16) and 16 or k;
local m = math.pow(2, k);
local bitset_key = 'bitset:' .. KEYS[1];
redis.call('set', bitset_key, '');
redis.call('setbit', bitset_key, 8 * math.ceil(m * 0.625) - 1, 0);
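-- m registers x 5 bits each = m * 0.625 bytes; the setbit call above touches the
-- last bit, so redis pre-allocates the whole zero-filled bitset in one operation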
local metadata_key = 'metadata:' .. KEYS[1];
redis.call('hmset', metadata_key,
'k', k, -- constant
'm', m, -- constant
'total_c', m, -- for use in single-set count optimization
'total_v', m, -- for use in single-set count optimization
'update_count', 0); -- not necessary, but nice to have
local expires_seconds = tonumber(ARGV[2]) or 0  -- ARGV[2] is optional (see the two eval forms above)
if expires_seconds > 0 then
redis.call('expire', bitset_key, math.ceil(expires_seconds));
redis.call('expire', metadata_key, math.ceil(expires_seconds));
end
return string.format("%.4f", 1.03896 / math.sqrt(m));
""" # init_script
# update script accepts text strings as input
# this routine always hashes the input
# eval """__the_lines_below__""" 1 set_name item
# also the function can accept multiple items
# eval """__the_lines_below__""" 1 set_name item1 item2 ... itemN
# remember, inside the script Lua tables (KEYS, ARGV) are 1-indexed, while Redis string/bit offsets (getrange, setrange, setbit) are 0-based
update_script = """
local band = bit.band;
local rshift = bit.rshift;
local lshift = bit.lshift;
local item_count = #ARGV;
local bitset_key = 'bitset:' .. KEYS[1];
local set_exists = redis.call('exists', bitset_key);
if set_exists == 0 then return 0; end -- 'exists' returns 0/1, and 0 is truthy in lua, so compare explicitly
local metadata_key = 'metadata:' .. KEYS[1];
--local metadata = redis.call('hgetall', metadata_key);
local k = tonumber(redis.call('hget', metadata_key, 'k'));
local k_comp = 32 - k;
-- use lookup tables to speed things up and check for expected data
-- we could make these lookup tables constant to further improve performance
local powersofhalf = {
1/2,1/4,1/8,1/16,1/32,1/64,1/128,1/256,1/512,1/1024,1/2048,1/4096,1/8192,
1/16384,1/32768,1/65536,1/131072,1/262144,1/524288,1/1048576,1/2097152,
1/4194304,1/8388608,1/16777216,1/33554432,1/67108864,1/134217728,
1/268435456,1/536870912,1/1073741824,1/2147483648,1/4294967296};
powersofhalf[0] = 1;
--local p = 1;
--for i=0,32 do -- need to go up to k_comp, which will never be > 32
-- powersofhalf[i] = 1 / p;
-- p = p*2;
--end
-- functions to unpack 5-bit integer registers from the bytestring,
-- in chunks of 40 bits (5 bytes is exactly 8 registers)
-- all bitset sizes are multiples of 40 bits, so this is clean
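-- for reference, the packing produced/consumed by the two helpers below
-- (r1..r8 are the 5-bit register values, [a..b] are bit positions within a register):
--   byte1 = r1[4..0] r2[4..2]
--   byte2 = r2[1..0] r3[4..0] r4[4]
--   byte3 = r4[3..0] r5[4..1]
--   byte4 = r5[0]    r6[4..0] r7[4..3]
--   byte5 = r7[2..0] r8[4..0]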
local function assemble_bytestring(r)
-- input should be an array of 8 registers, each with 5 bits
-- returns a 5-byte string
local b1 = string.char(lshift(r[1],3) + rshift(r[2],2));
local b2 = string.char(band(0xff, lshift(r[2],6) + lshift(r[3],1) + rshift(r[4],4)));
local b3 = string.char(band(0xff, lshift(r[4],4) + rshift(r[5],1)));
local b4 = string.char(band(0xff, lshift(r[5],7) + lshift(r[6],2) + rshift(r[7],3)));
local b5 = string.char(band(0xff, lshift(r[7],5) + r[8]));
return b1 .. b2 .. b3 .. b4 .. b5;
end
local function unpack_bytestring(bytestring)
-- bytestring should have length 40 bits / 5 bytes
-- returns an array of 8 registers, each with 5 bits
local b1 = bytestring:byte(1);
local b2 = bytestring:byte(2);
local b3 = bytestring:byte(3);
local b4 = bytestring:byte(4);
local b5 = bytestring:byte(5);
local r1 = rshift(band(b1, 0xf8), 3);
local r2 = lshift(band(b1, 0x07), 2) + rshift(band(b2, 0xc0), 6);
local r3 = rshift(band(b2, 0x3e), 1);
local r4 = lshift(band(b2, 0x01), 4) + rshift(band(b3, 0xf0), 4);
local r5 = lshift(band(b3, 0x0f), 1) + rshift(band(b4, 0x80), 7);
local r6 = rshift(band(b4, 0x7c), 2);
local r7 = lshift(band(b4, 0x03), 3) + rshift(band(b5, 0xe0), 5);
local r8 = band(b5, 0x1f);
return {r1,r2,r3,r4,r5,r6,r7,r8};
end
local function rank(hash)
-- position of the leftmost bit of the hash
-- assumes a 32-bit hash
local local_hash = lshift(hash, k);
local r = 1;
while (band(local_hash, 0x80000000) == 0 and r <= k_comp) do
r = r + 1;
local_hash = lshift(local_hash, 1);
end
return r;
end
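-- e.g. with k = 4 and a hash of 0x12345678, lshift by 4 gives 0x23456780
-- (binary 0010 0011 ...), whose leftmost set bit is at position 3, so rank() returns 3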
local v_increment = 0;
local c_increment = 0.0;
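-- total_c is the running sum over all registers of 2^-M[j] and total_v the number
-- of zero-valued registers; maintaining them incrementally here is what makes a
-- single-set count constant-time instead of requiring a scan of the whole bitset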
for i=1,item_count do
-- old method that assumed the input was already a valid 32-bit hash
-- in the case of the input already being a hash, we can work with that directly.
--local item_hash = bit.tobit('0x' .. ARGV[1]); -- for hash hex digest input ("abcdef12")
-- murmurhash appears to be very lightweight, so we can hash everything
-- that's coming in without much performance impact.
local item_hash = murmur3.murmur3(ARGV[i]);
local j = rshift(item_hash, k_comp);
-- read the whole chunk that contains the j'th register
local firstbit = j * 5;
local blockstart_bit = math.floor(firstbit / 40) * 40;
local blockstart_byte = blockstart_bit / 8;
local bytestring = redis.call('getrange', bitset_key,
blockstart_byte, blockstart_byte + 4);
local registers = unpack_bytestring(bytestring);
-- read the value at the j'th register
local register_index = (j % 8) + 1; -- lua is 1-indexed
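-- e.g. with j = 13: firstbit = 65, blockstart_bit = 40, blockstart_byte = 5, so
-- bytes 5..9 hold the chunk and register_index = (13 % 8) + 1 = 6 within it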
local existing_value = registers[register_index];
local current_value = rank(item_hash);
--print("update hash: " .. bitset_key .. "\tk = " .. k .. "\tj = " .. j .. "\thash = " .. bit.tohex(item_hash) .. "\trank = " .. current_value);
--print("\texisting_value = " .. existing_value .. "\tcurrent_value = " .. current_value);
-- set the new register value to be the maximum
if current_value > existing_value then
if existing_value == 0 then
-- one fewer zero valued register
--print('decrementing v for ' .. metadata_key);
v_increment = v_increment - 1;
end
c_increment = c_increment + powersofhalf[current_value] - powersofhalf[existing_value];
--print('incrementing c of ' .. metadata_key .. ' by ' .. c_increment);
-- write the new data to the register
registers[register_index] = current_value;
redis.call('setrange', bitset_key, blockstart_byte,
assemble_bytestring(registers));
--print("\tchanged register to " .. current_value);
end
end
-- update metadata for cumulative stats
redis.call('hincrby', metadata_key, 'total_v', v_increment);
redis.call('hincrbyfloat', metadata_key, 'total_c', c_increment);
redis.call('hincrby', metadata_key, 'update_count', item_count); -- track number of updates
return item_count;
""" # update_script
# counting a single set is constant-time thanks to the accumulated stats; counting
# a union of sets is a blocking routine that can take a while, depending on the
# bitset size, which depends on the precision
# (not the cardinality estimate or number of updates)
# eval """__the_lines_below__""" 1 set_name
# to estimate cardinality among an aggregation of sets, provide multiple
# arguments. Note that all of the bit sets must be the same size (i.e. the same precision).
# eval """__the_lines_below__""" N set_name_1 set_name_2 ... set_name_N
# to materialize the united set as a new set, include the name as an extra argument:
# eval """__the_lines_below__""" N set_name_1 set_name_2 ... set_name_N union_set_name union_set_expires_seconds
count_script = """
local keys_count = #KEYS;
local bitset_key = 'bitset:' .. KEYS[1];
local metadata_key = 'metadata:' .. KEYS[1];
local metadata_list = redis.call('hgetall', metadata_key);
local metadata = {};
for i=1,#metadata_list,2 do
metadata[metadata_list[i]] = metadata_list[i+1];
end
local k = tonumber(metadata['k']);
local m = tonumber(metadata['m']);
local update_count = tonumber(metadata['update_count']);
local materialize_union = (keys_count > 1) and ARGV[1] or nil;
--print('count : k = ' .. k .. '\tm = ' .. m .. '\tupdate_count = ' .. update_count);
local function precision(set_k)
return 1.04 / math.sqrt(math.pow(2, set_k));
end
local c = 0.0;
local V = 0; -- used in case of small-cardinality adjustment
if keys_count == 1 then
c = tonumber(metadata['total_c']);
V = tonumber(metadata['total_v']);
--print('one bitset optimization. c = ' .. c .. '\tV = ' .. V);
else -- union of multiple sets
local band = bit.band;
local rshift = bit.rshift;
local lshift = bit.lshift;
-- check that all bit sets are the same size
local addl_metadata_key;
local addl_metadata_list;
local addl_metadata;
local addl_k;
for i=2,keys_count do
addl_metadata_key = 'metadata:' .. KEYS[i];
addl_metadata_list = redis.call('hgetall', addl_metadata_key);
addl_metadata = {};
for j=1,#addl_metadata_list,2 do
addl_metadata[addl_metadata_list[j]] = addl_metadata_list[j+1];
end
addl_k = tonumber(addl_metadata['k']);
assert(addl_k == k,
'ERROR. Aggregation among sets requires all sets to use the same precision. Original set is ' .. precision(k) .. ' and set # ' .. i .. ' ( ' .. KEYS[i] .. ' ) is\t' .. precision(addl_k));
update_count = update_count + tonumber(addl_metadata['update_count']);
end
-- collect cardinality information from all sets
local bitset = redis.call('get', bitset_key); -- get the whole first bitset
-- functions to unpack 5-bit integer registers from the bytestring,
-- in chunks of 40 bits (5 bytes is exactly 8 registers)
-- all bitset sizes are multiples of 40 bits, so this is clean
local function assemble_bytestring(r)
-- input should be an array of 8 registers, each with 5 bits
-- returns a 5-byte string
local b1 = string.char(lshift(r[1],3) + rshift(r[2],2));
local b2 = string.char(band(0xff, lshift(r[2],6) + lshift(r[3],1) + rshift(r[4],4)));
local b3 = string.char(band(0xff, lshift(r[4],4) + rshift(r[5],1)));
local b4 = string.char(band(0xff, lshift(r[5],7) + lshift(r[6],2) + rshift(r[7],3)));
local b5 = string.char(band(0xff, lshift(r[7],5) + r[8]));
return b1 .. b2 .. b3 .. b4 .. b5;
end
local function unpack_bytestring(bytestring)
-- bytestring should have length 40 bits / 5 bytes
-- returns an array of 8 registers, each with 5 bits
--print('bytestring (' .. bytestring:len() .. ' bytes)' .. bytestring);
local b1 = bytestring:byte(1);
local b2 = bytestring:byte(2);
local b3 = bytestring:byte(3);
local b4 = bytestring:byte(4);
local b5 = bytestring:byte(5);
local r1 = rshift(band(b1, 0xf8), 3);
local r2 = lshift(band(b1, 0x07), 2) + rshift(band(b2, 0xc0), 6);
local r3 = rshift(band(b2, 0x3e), 1);
local r4 = lshift(band(b2, 0x01), 4) + rshift(band(b3, 0xf0), 4);
local r5 = lshift(band(b3, 0x0f), 1) + rshift(band(b4, 0x80), 7);
local r6 = rshift(band(b4, 0x7c), 2);
local r7 = lshift(band(b4, 0x03), 3) + rshift(band(b5, 0xe0), 5);
local r8 = band(b5, 0x1f);
return {r1,r2,r3,r4,r5,r6,r7,r8};
end
local bytestring; -- 5 bytes long. for the original set
local registers; -- array of 8 registers, 5 bits each
local addl_bytestring; -- for the additional set
local addl_registers; -- for the additional set
local addl_bitsets = {};
local addl_bitset_key;
local addl_update_count_key;
for i=2,keys_count do
-- read all of the sets to be united ahead of time
-- this assumes there are a reasonable number of sets (<100?)
addl_bitset_key = 'bitset:' .. KEYS[i];
addl_bitsets[i] = redis.call('get', addl_bitset_key);
end
-- use a lookup table to speed things up and check for expected data
local p = 1;
local powersofhalf = {};
for i=0,32 do -- need to go up to k_comp, which will never be > 32
powersofhalf[i] = 1 / p;
p = p*2;
end
local union_chunks = {};
local bitset_length_bytes = m * 0.625;
for i=1,bitset_length_bytes - 1,5 do
-- read a chunk of 40 bits of the original set
bytestring = bitset:sub(i,i+4);
registers = unpack_bytestring(bytestring);
-- get maximum of all the united sets
for j=2,keys_count do
addl_bytestring = addl_bitsets[j]:sub(i,i+4);
addl_registers = unpack_bytestring(addl_bytestring);
for r=1,8 do -- a distinct loop variable, so the outer k (precision) is not shadowed
registers[r] = math.max(registers[r], addl_registers[r]);
end
end
-- update raw estimate calculation
for j=1,8 do
c = c + powersofhalf[registers[j]];
if registers[j] == 0 then V = V+1; end -- used for small-cardinality
end
if materialize_union ~= nil then
-- append to the accumulating bitset table
table.insert(union_chunks, assemble_bytestring(registers));
end
end
--print('union of ' .. keys_count .. ' sets. c = ' .. c .. '\tV = ' .. V);
if materialize_union ~= nil then
-- save results of union into a new set
local union_bitset = table.concat(union_chunks);
local union_bitset_key = 'bitset:' .. materialize_union;
redis.call('set', union_bitset_key, union_bitset);
local union_metadata_key = 'metadata:' .. materialize_union;
redis.call('hmset', union_metadata_key,
'k', k, -- constant
'm', m, -- constant
'total_c', c, -- for use in single-set count optimization
'total_v', V, -- for use in single-set count optimization
'update_count', update_count); -- not necessary, but nice to have
local expires_seconds = tonumber(ARGV[2]) or 0; -- the expiry argument is optional
if expires_seconds > 0 then
redis.call('expire', union_bitset_key, math.ceil(expires_seconds));
redis.call('expire', union_metadata_key, math.ceil(expires_seconds));
end
-- point to union set from each source set so updates will propagate
for i=1,keys_count do
local subscribers_key = 'subscribers:' .. KEYS[i];
redis.call('sadd', subscribers_key, materialize_union);
local bitset_key = 'bitset:' .. KEYS[i];
local source_ttl = redis.call('ttl', bitset_key);
if source_ttl > 0 then
redis.call('expire', subscribers_key, source_ttl);
end
end
end
end
-- calculate raw estimate
local alpha_m = (m == 16) and 0.673
or (m == 32) and 0.697
or (m == 64) and 0.709
or 0.7213 / (1 + 1.079 / m); -- for m >= 128
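-- alpha_m is the usual HyperLogLog bias-correction constant: tabulated values for
-- m = 16, 32 and 64, and the closed-form approximation above for larger m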
local E = alpha_m * m * m / c;
-- make corrections if needed
local pow_2_32 = 4294967296;
if (E <= (5/2 * m)) then
--print("\tsmall case correction: E_orig = " .. E .. "\tV = " .. V);
if (V > 0) then E = m * math.log(m/V); end
elseif (E > 1/30 * pow_2_32) then
--print("\tlarge case correction: E_orig = " .. E);
E = -1 * pow_2_32 * math.log(1 - E / pow_2_32);
end
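-- the small-range branch is linear counting (m * ln(m/V), using the count of
-- zero-valued registers); the large-range branch compensates for 32-bit hash
-- collisions as the cardinality approaches 2^32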
local function round(x)
return math.floor(x + 0.5);
end
local prec = precision(k);
return {
"count", round(E),
"standard_error", string.format("%.4f", prec),
"range_95_pct", string.format("%d - %d", round(E*(1 - 2*prec)), round(E*(1 + 2*prec))),
"update_count", update_count,
"bytes_scanned", keys_count * m * 0.625,
"new_set_created", materialize_union
};
""" # count_script
class UniqueCounter:
def __init__(self):
self._r = redis.Redis()
def init(self, set_name, standard_error=0.01, expires=0):
# expires is time in seconds into the future
# returns actual standard_error, which may be more precise than requested
return self._r.eval(init_script, 1, set_name, standard_error, expires)
def _update(self, set_name, *items):
# returns number of items updated (same as input)
# recursively propagates updates through the aggregation
# tree of materialized union sets
# beware! no checks are made to avoid circular references
subscribers = list(self._r.smembers('subscribers:%s' % set_name))
return self._r.eval(update_script, 1, set_name, *items) + \
sum([self._update(x, *items) for x in subscribers])
def _exists(self, set_name):
return self._r.exists('bitset:' + set_name)
def update(self, set_name, *items):
if not self._exists(set_name):
self.init(set_name)
return self._update(set_name, *items)
def count(self, *set_names, **kwargs):
for s in set_names:
if not self._exists(s):
raise NameError('set named %s not found' % s)
if 'save' in kwargs and kwargs['save']:
if 'expires' in kwargs and kwargs['expires'] > 0:
raw = self._r.eval(count_script, len(set_names),
*(list(set_names) + [kwargs['save'], kwargs['expires']]))
else:
raw = self._r.eval(count_script, len(set_names),
*(list(set_names) + [kwargs['save'], 0]))
else:
raw = self._r.eval(count_script, len(set_names), *set_names)
return dict([(k, v) for k,v in zip (raw[::2], raw[1::2])])
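# typical usage, as a sketch (assumes a local redis server with the bitlib and
# murmur3 lua modules available, as noted at the top of this file; the set names
# below are made up for illustration):
#   uc = UniqueCounter()
#   uc.init('visitors', standard_error=0.01)          # ~10kb of registers
#   uc.update('visitors', 'user_42', 'user_99', 'user_42')
#   stats = uc.count('visitors')                      # {'count': 2, ...} (approximate)
#   both = uc.count('visitors', 'buyers', save='either_set', expires=3600)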
def test_init(uc, s):
uc.init(s)
def test_update(random, uc, s):
uc.update(s, 'abc123xyz_%d' % random.randint(0,0xfffffff))
def test_count(uc, *s, **k):
uc.count(*s, **k)
def test_speed():
# some results based on running speed tests on a lightweight netbook:
# standard error = 0.01, 10kb of registers
# < 1 ms per init()
# < 1 ms per update()
# < 1 ms per count() for single set
# ~30 ms per set for count() of union of sets
import timeit
icount = 500
t = timeit.Timer("test_init(uc, setname)", "from __main__ import test_init, UniqueCounter; uc = UniqueCounter(); setname='test1'")
print "init() test:\t(%d iterations)\t%.6f / run\t%.3f total time" % (icount, t.timeit(number=icount) / float(icount), t.timeit(number=icount))
t = timeit.Timer("test_update(random, uc, setname)", "import random; from __main__ import test_update, UniqueCounter; uc = UniqueCounter(); setname='test1'; uc.init(setname)")
print "update() test:\t(%d iterations)\t%.6f / run\t%.3f total time" % (icount, t.timeit(number=icount) / float(icount), t.timeit(number=icount))
t = timeit.Timer("test_count(uc, setname)", "from __main__ import test_count, UniqueCounter; uc = UniqueCounter(); setname='test1'; uc.init(setname)")
print "count() test:\t(%d iterations)\t%.6f / run\t%.3f total time" % (icount, t.timeit(number=icount) / float(icount), t.timeit(number=icount))
icount = 10
scount = 2
setname = 'test1'
setnames = repr([setname] * scount)[1:-1]
t = timeit.Timer("test_count(uc, %s)" % setnames, "from __main__ import test_count, UniqueCounter; uc = UniqueCounter(); uc.init(%s)" % repr(setname))
testtime = t.timeit(number=icount)
print "count(union) test:\t(%d iterations, %d sets)\t%.6f / run\t%.6f / set\t%.3f total time" % (icount, scount, testtime / float(icount), testtime / float(icount) / float(scount), testtime)
t = timeit.Timer("test_count(uc, %s, save=set_union_name)" % setnames, "from __main__ import test_count, UniqueCounter; uc = UniqueCounter(); uc.init(%s); set_union_name='testunion1'" % repr(setname))
testtime = t.timeit(number=icount)
print "count(union) save:\t(%d iterations, %d sets)\t%.6f / run\t%.6f / set\t%.3f total time" % (icount, scount, testtime / float(icount), testtime / float(icount) / float(scount), testtime)
def test_accuracy(enable_set=True, enable_bit=True, enable_hll=True):
# compares the accuracy of 3 methods for counting uniques
# some conclusions based on usage of this test:
# ACCURACY
# set and bitarray are exact
# hyperloglog precision claims appear to be valid
# PERFORMANCE
# updates to set and bitarray are very fast compared to hll
# approx. 1.5 seconds for 100,000 updates to a set and/or bitarray
# approx. 90 seconds for 100,000 updates to a hyperloglog counter
# approx. 9 seconds for 10,000 updates to a hyperloglog counter
# SCALE
# sets need about 10 bytes per unique item
# bitarray needs about 0.1 bytes per possible unique item
# a 1% std error hll set needs 10kb always
# so when counting less than 100,000 possible unique items, prefer a bitarray
# when counting less than 1,000 unique items, prefer a set
# when counting more than 1,000 unique items,
# choose an option with an appropriate balance between
# storage, performance, and precision
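# illustrative numbers for the trade-off above: counting 200,000 uniques drawn
# from a universe of 1,000,000 possible items needs roughly
#   set       ~ 200,000 * 10 bytes  = ~2 MB
#   bitarray  ~ 1,000,000 / 8 bytes = ~122 KB
#   hll (1%)  ~ 10 KB, independent of cardinality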
import sys
from bitarray import bitarray
from random import randint
from math import ceil
from time import time
icount = 10000 # number of updates
collisionfactor = 0.1 # portion of updates expected to collide
randomceiling = icount * (1 - collisionfactor)
testbitarray = bitarray(icount)
testbitarray.setall(0)
testset = set()
testuc = UniqueCounter()
testsetname = 'test_accuracy'
testuc.init(testsetname, 0.01) # 1% standard error
print "accuracy test:\t%d updates with %.2f collision factor" % (icount, collisionfactor)
print "intermediate counts:\n\ti\tset\tbit\thll\terr\t%err"
starttime = time()
updatetime = starttime + 5
records = []
record_interval = icount / 20
for i in xrange(icount):
if (i % record_interval == 0) or (i % 1000 == 0 and time() > updatetime):
r = [i, len(testset), testbitarray.count(), testuc.count(testsetname)['count']]
r.append(r[3] - r[2])
r.append(0 if r[2] == 0 else r[4] / float(r[2]))
print "\t%d\t%d\t%d\t%d\t%d\t%+0.4f" % tuple(r)
updatetime = time() + 5
itemnumber = randint(0, randomceiling)
itemstring = "item number %d" % itemnumber
if enable_bit:
testbitarray[itemnumber] = 1
if enable_set:
testset.add(itemstring)
if enable_hll:
testuc.update(testsetname, itemstring)
r = [icount, len(testset), testbitarray.count(), testuc.count(testsetname)['count']]
r.append(r[3] - r[2])
r.append(0 if r[2] == 0 else r[4] / float(r[2]))
print "\t%d\t%d\t%d\t%d\t%d\t%+0.4f" % tuple(r)
print
print "\tset:\t\tcount =\t%d\tbytes =\t%d" % (len(testset), sys.getsizeof(testset))
print "\tbitarray:\tcount =\t%d\tbytes =\t%d" % (testbitarray.count(), ceil(testbitarray.length() / 8.0))
uc_stats = testuc.count(testsetname)
print "\thll:\t\tcount =\t%d\tbytes =\t%d" % (uc_stats['count'], uc_stats['bytes_scanned'])
actual_count = testbitarray.count()
hll_error = uc_stats['count'] - actual_count
print "\t\t\t95%% range = [ %s ]" % uc_stats['range_95_pct']
print "\t\t\terror =\t%d\t(%.1f * std_dev)" % \
(hll_error, abs(hll_error / float(actual_count) / float(uc_stats['standard_error'])))
print
print "\tdone in %.2f seconds\t(%.6f sec/update)" % (time() - starttime, (time() - starttime) / float(icount))
if __name__ == "__main__":
test_speed()
test_accuracy()