Skip to content

Instantly share code, notes, and snippets.

@stevesimmons
Last active October 25, 2020 21:54
Show Gist options
  • Save stevesimmons/58b5e113a41c5c23775d17cc83929d88 to your computer and use it in GitHub Desktop.
Save stevesimmons/58b5e113a41c5c23775d17cc83929d88 to your computer and use it in GitHub Desktop.
Python bloomfilter accelerated by numba
import base64
import math
import random
import timeit
import uuid
from typing import List, Optional, Union

import numba
import numpy as np
class BloomFilter:
    """
    Bloom filter backed by a flat uint8 numpy array, with numba-compiled
    insert and lookup.

    Every key is hashed once per seed with ``murmurhash``; each hash value
    selects one bit of the array. Adding a key sets those bits, and a lookup
    reports True only if all of them are set.
    """

    def __init__(self, seeds: Optional[List[int]] = None, arr: Optional[bytearray] = None,
                 size_bytes: Optional[int] = None, num_hashes: Optional[int] = None):
        """
        Create a new bloomfilter, or recreate one from a list of uint32 seeds
        and an array of bytes.

        Args:
            seeds: Hash seeds (one per hash function). A list or numpy array.
            arr: Existing bit array as bytes, bytearray, memoryview or any
                array-like of byte values. The data is copied.
            size_bytes: Size of a new, empty bit array. Used when ``arr`` is
                not given.
            num_hashes: Number of random 32-bit seeds to generate. Used when
                ``seeds`` is not given.

        Raises:
            ValueError: If neither ``arr`` nor ``size_bytes`` is given, if
                neither ``seeds`` nor ``num_hashes`` yields at least one seed,
                or if the resulting bit array is empty.
        """
        if arr is not None:
            if isinstance(arr, (bytes, bytearray, memoryview)):
                # np.array() cannot convert a bytes object element-wise, so
                # read the buffer directly; copy() makes it writable and
                # independent of the caller's object.
                self.arr = np.frombuffer(arr, dtype='uint8').copy()
            else:
                self.arr = np.array(arr, dtype='uint8')
        elif size_bytes is not None:
            # Always hold a numpy array so clear(), copy() and the numba
            # kernels see the same type whichever way the filter was built.
            self.arr = np.zeros(size_bytes, dtype='uint8')
        else:
            raise ValueError("Specify either arr or size_bytes")
        if len(self.arr) == 0:
            # Bit positions are computed modulo the number of bits.
            raise ValueError("Bloomfilter array must contain at least one byte")

        if seeds is None and num_hashes is not None:
            seeds = [random.getrandbits(32) for _ in range(num_hashes)]
        # Explicit None/length test: truth-testing a multi-element numpy
        # array (as passed by copy() and similar_copy()) raises ValueError.
        if seeds is None or len(seeds) == 0:
            raise ValueError("Specify either seeds or num_hashes")
        self.seeds = np.array(seeds, dtype='uint32')

    def capacity(self, false_positive_rate: float) -> int:
        """
        Max number of items in the filter to be within a designated false
        positive rate.

        Raises:
            ValueError: If ``false_positive_rate`` is not strictly between
                0 and 1.
        """
        if not 0.0 < false_positive_rate < 1.0:
            raise ValueError("false_positive_rate must be between 0 and 1 (exclusive)")
        bits_in_array = len(self.arr) * 8
        return int(bits_in_array * (math.log(2) ** 2) / abs(math.log(false_positive_rate)))

    def false_positive_rate(self, items_in_array: int) -> float:
        """
        Estimated false positive probability once ``items_in_array`` distinct
        items have been added.

        Uses the usual approximation ``(1 - exp(-k * n / m)) ** k`` for
        ``k`` seeds, ``n`` items and ``m`` bits. Returns 0.0 for an empty
        filter (``items_in_array <= 0``).
        """
        if items_in_array <= 0:
            return 0.0
        bits_in_array = len(self.arr) * 8
        num_hashes = len(self.seeds)
        # -expm1(-x) == 1 - exp(-x), but stays accurate when x is tiny.
        fill_fraction = -math.expm1(-num_hashes * items_in_array / bits_in_array)
        return fill_fraction ** num_hashes

    def __repr__(self):
        return f"<BloomFilter size {len(self.arr)} with {len(self.seeds)} seeds. {self.count_bits()} bits set>"

    @property
    def base64(self) -> str:
        "The bit array encoded as a base64 text string"
        # b64encode returns bytes; decode so the result matches the
        # annotated str return type.
        return base64.b64encode(self.arr.tobytes()).decode('ascii')

    def add_string(self, s: str):
        "Add a string to the bloomfilter"
        return self._add(self.arr, s.encode(), self.seeds)

    def add_bytes(self, b: bytes):
        "Add a bytes string to the bloomfilter array"
        return self._add(self.arr, b, self.seeds)

    def add_uuid(self, u: Union[str, uuid.UUID]):
        "Add a UUID string of format 'xxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx' to the bloomfilter array"
        # The lower-cased text form is hashed rather than the 16 raw bytes:
        # simple string conversion is 2.5x faster than via UUID object.
        return self._add(self.arr, str(u).lower().encode(), self.seeds)

    def has_string(self, s: str) -> bool:
        "True if the given string is in the bloomfilter"
        return self._contains(self.arr, s.encode(), self.seeds)

    def has_bytes(self, b: bytes) -> bool:
        "True if the given bytes object is in the bloomfilter"
        return self._contains(self.arr, b, self.seeds)

    def has_uuid(self, u: Union[str, uuid.UUID]) -> bool:
        "True if the given UUID string (independent of upper/lower case) is in the bloomfilter"
        # Must mirror add_uuid: hash the lower-cased text form.
        return self._contains(self.arr, str(u).lower().encode(), self.seeds)

    # Strings are the default key type for `bf.add(x)` and `x in bf`.
    add = add_string
    __contains__ = has_string

    def count_bits(self) -> int:
        "Number of bits in the array that are set"
        # Histogram the byte values first, so the popcount loop runs over at
        # most 256 distinct values instead of every byte in the array.
        byte_counts = np.bincount(self.arr, minlength=256)
        return self._countbits(byte_counts)

    def clear(self):
        "Empty the bloomfilter array"
        self.arr.fill(0)

    def similar_copy(self):
        "A new Bloomfilter with the same size array and same seeds"
        return self.__class__(self.seeds.copy(), np.zeros_like(self.arr))

    def copy(self):
        "A copy of this Bloomfilter with the array initialised with the same values"
        return self.__class__(self.seeds.copy(), self.arr.copy())

    @staticmethod
    @numba.njit
    def _add(arr, key, seeds):
        # Set one bit per seed, at position murmurhash(key, seed) mod the
        # number of bits in arr.
        num_bits = len(arr) * 8
        for seed in seeds:
            loc = murmurhash(key, seed) % num_bits
            offset, shift = divmod(loc, 8)
            arr[offset] |= (1 << shift)

    @staticmethod
    @numba.njit
    def _contains(arr, key, seeds):
        # True only if every bit selected by the seeds is set; stops at the
        # first unset bit.
        num_bits = len(arr) * 8
        for seed in seeds:
            loc = murmurhash(key, seed) % num_bits
            offset, shift = divmod(loc, 8)
            if not (arr[offset] & (1 << shift)):
                return False
        return True

    @staticmethod
    @numba.njit
    def _countbits(byte_counts) -> int:
        # byte_counts[v] is how many bytes have value v; add that count once
        # for every set bit of v.
        num_bits = 0
        for value, cnt in enumerate(byte_counts):
            if cnt > 0:
                remaining = value
                while remaining:
                    remaining &= remaining - 1  # clear the lowest set bit
                    num_bits += cnt
        return num_bits
@numba.njit
def murmurhash(key, seed) -> int:
    """
    Numba-accelerated 32-bit murmurhash.

    ``key`` is an indexable sequence of byte values and ``seed`` the initial
    hash state. Whole 4-byte blocks are packed most-significant byte first,
    and a 1-3 byte tail is packed into the top of a 24-bit word.

    NOTE(review): because of this byte order the result may differ from
    other MurmurHash3 implementations -- confirm before relying on
    interoperability.
    """
    length = len(key)
    num_blocks = length // 4
    tail_len = length % 4
    tail_start = 4 * num_blocks
    c1 = 0xcc9e2d51
    c2 = 0x1b873593
    h = seed

    # Body: mix each complete 4-byte block into the running hash
    for base in range(0, tail_start, 4):
        block = (key[base] << 24) + (key[base + 1] << 16) + (key[base + 2] << 8) + key[base + 3]
        block = (block * c1) & 0xFFFFFFFF
        block = ((block << 15) | (block >> 17)) & 0xFFFFFFFF  # rotate left by 15
        h ^= (block * c2) & 0xFFFFFFFF
        h = ((h << 13) | (h >> 19)) & 0xFFFFFFFF  # rotate left by 13
        h = (h * 5 + 0xe6546b64) & 0xFFFFFFFF

    # Tail: the 1-3 bytes left over after the last complete block, if any
    if tail_len > 0:
        tail = (key[tail_start] << 16)
        if tail_len > 1:
            tail += key[tail_start + 1] << 8
            if tail_len > 2:
                tail += key[tail_start + 2]
        tail = (tail * c1) & 0xFFFFFFFF
        tail = ((tail << 15) | (tail >> 17)) & 0xFFFFFFFF  # rotate left by 15
        tail = (tail * c2) & 0xFFFFFFFF
        h ^= tail

    # Folding in the length makes tails of 1-3 zero bytes hash differently
    h ^= length

    # Finalisation: mix the bits of the 32-bit state
    mixed = h
    mixed ^= (mixed >> 16)
    mixed = (mixed * 0x85ebca6b) & 0xFFFFFFFF
    mixed ^= (mixed >> 13)
    mixed = (mixed * 0xc2b2ae35) & 0xFFFFFFFF
    mixed ^= (mixed >> 16)
    return mixed
def test_bloomfilter():
    """
    Smoke-test BloomFilter with string, bytes and UUID keys, then print
    rough per-call timings of the common operations.
    """
    text = 'hello'
    raw = b'hello'
    other_raw = b'hello1'
    uuid_text = '12345678-1234-1234-1234-abcdefABCDEF'
    same_uuid_other_case = '12345678-1234-1234-1234-aBcDeFaBcDeF'
    different_uuid = '02345678-1234-1234-1234-abcdefABCDEF'

    # Membership: everything added is found, the two unseen keys are not
    bf = BloomFilter(num_hashes=6, size_bytes=1000)
    print(repr(bf))
    bf.add_string(text)
    bf.add_bytes(raw)
    bf.add_uuid(uuid_text)
    bf.add_uuid(same_uuid_other_case)
    res = (
        bf.has_string(text),
        bf.has_bytes(raw),
        bf.has_bytes(other_raw),
        bf.has_uuid(uuid_text),
        bf.has_uuid(same_uuid_other_case),
        bf.has_uuid(different_uuid),
    )
    assert res == (True, True, False, True, True, False)

    # Serialisation: show the base64 form of a small filter
    small = BloomFilter(num_hashes=6, size_bytes=50)
    small.add('hello')
    print(small.base64)

    # Timing of common operations
    # Approx 1/3 speed of pybloomfiltermmap3, which is written in C and Cython.
    # Which shows how good numba is!
    timed = BloomFilter(num_hashes=6, size_bytes=100000)
    # The keys are the names used inside the timed statement strings
    namespace = {
        'bf': timed,
        's': text,
        'b': raw,
        'u': uuid_text,
        'is_u': same_uuid_other_case,
        'not_u': different_uuid,
    }
    for stmt in (
        'bf.add_string(s)',
        'bf.add_bytes(b)',
        'bf.add_uuid(u)',
        'bf.has_string(s)',
        'bf.has_bytes(b)',
        'bf.has_uuid(u)',
        'bf.has_uuid(is_u)',
        'bf.has_uuid(not_u)',
    ):
        # Warm-up run so JIT compilation time is not counted
        timeit.timeit(stmt, number=10, globals=namespace)
        secs = timeit.timeit(stmt, number=10000, globals=namespace)
        # secs covers 10000 calls, so secs * 100 is microseconds per call
        print(f"{stmt:30} -> {secs * 100:0.3f} µs")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment