Skip to content

Instantly share code, notes, and snippets.

@rjzak
Created December 25, 2019 02:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rjzak/543a7228a3aa505d786cde9c64014c4b to your computer and use it in GitHub Desktop.
Save rjzak/543a7228a3aa505d786cde9c64014c4b to your computer and use it in GitHub Desktop.
A simple bloom filter with standard dependencies. Expects the input to be strings.
#!/usr/bin/python3
import os
import pickle
import tempfile
import unittest
import zlib

import numpy as np
# Multiplier (base) applied to the running value for every element hashed
# by rabin_hash.
P_B = 227
# Modulus that bounds rabin_hash results to the range [0, P_M).
# NOTE(review): 1000005 is divisible by 5, so it is not prime; changing it
# would alter the hashes of any filter that was already saved -- confirm
# before touching.
P_M = 1000005


def rabin_hash(ngram):
    """Return a polynomial rolling hash of ``ngram`` in ``range(P_M)``.

    The running value is multiplied by ``P_B``, the next element's numeric
    value is added, and the result is reduced modulo ``P_M``.

    Args:
        ngram: A ``str`` (each character contributes its code point) or a
            bytes-like object (each byte contributes its integer value).
            Accepting bytes keeps this function consistent with
            ``crc32_hash`` and ``adler32_hash``; for ASCII text the ``str``
            and ``bytes`` forms hash identically.

    Returns:
        A non-negative ``int`` strictly less than ``P_M`` (0 for empty
        input).
    """
    acc = 0
    for unit in ngram:
        # Iterating bytes yields ints already; ord() on an int would raise.
        code = unit if isinstance(unit, int) else ord(unit)
        # Reducing at every step keeps the intermediate value small; the
        # result is never negative, so no abs() is required.
        acc = (acc * P_B + code) % P_M
    return acc
def crc32_hash(ngram):
    """Return the CRC-32 checksum of ``ngram``.

    Args:
        ngram: A ``str`` (or ``str`` subclass), which is encoded with the
            default UTF-8 codec first, or a bytes-like object, which is
            passed to ``zlib.crc32`` unchanged.

    Returns:
        The checksum as a non-negative ``int``.
    """
    # isinstance (rather than an exact type comparison) also covers str
    # subclasses, which would otherwise reach zlib unencoded and raise.
    if isinstance(ngram, str):
        ngram = ngram.encode()
    # Python 3 always returns an unsigned value here, so abs() is not needed.
    return zlib.crc32(ngram)
def adler32_hash(ngram):
    """Return the Adler-32 checksum of ``ngram``.

    Args:
        ngram: A ``str`` (or ``str`` subclass), which is encoded with the
            default UTF-8 codec first, or a bytes-like object, which is
            passed to ``zlib.adler32`` unchanged.

    Returns:
        The checksum as a non-negative ``int``.
    """
    # isinstance (rather than an exact type comparison) also covers str
    # subclasses, which would otherwise reach zlib unencoded and raise.
    if isinstance(ngram, str):
        ngram = ngram.encode()
    # Python 3 always returns an unsigned value here, so abs() is not needed.
    return zlib.adler32(ngram)
class BloomFilter:
    """Fixed-size Bloom filter backed by a NumPy boolean array.

    Inserting a value sets one bit per hash function; a lookup reports the
    value as present only when every one of those bits is set. Lookups can
    therefore give false positives but never false negatives.
    """

    def __init__(self, length=0, fname=None):
        """Create an empty filter, or restore one written by ``save``.

        Args:
            length: Number of bits for a new filter. Values below 10 are
                replaced by 1000 and a notice is printed. Ignored when
                ``fname`` is given.
            fname: Optional path of a pickle file produced by ``save``;
                when given, ``length`` and the bit array come from it.
        """
        if fname is not None:
            # SECURITY: unpickling can execute arbitrary code embedded in
            # the file. Only load filter files from a trusted source.
            with open(fname, "rb") as f:
                state = pickle.load(f)
            self.length = state["length"]
            self.bits = state["bits"]
        else:
            if length < 10:
                print("Bloom Filter size of %d does't make sense, chaned to 1000" % length)
                length = 1000
            self.length = length
            # The `np.bool` alias warns on NumPy 1.20-1.23 and raises
            # AttributeError on 1.24-1.26; the builtin `bool` selects the
            # same dtype on every NumPy version.
            self.bits = np.zeros(length, dtype=bool)
        # Every hash is reduced modulo `length` to pick a bit position.
        self.hash_functions = (rabin_hash, crc32_hash, adler32_hash)

    def insert(self, data):
        """Add ``data`` to the filter.

        Raises:
            AssertionError: If, once the bits for ``data`` are set, a
                quarter or more of all bits in the filter are set.
        """
        for func in self.hash_functions:
            self.bits[func(data) % self.length] = True
        # Raised explicitly instead of via `assert` so the capacity check
        # still runs under `python -O`; the exception type is unchanged for
        # callers that already catch AssertionError.
        if not np.sum(self.bits) < self.length / 4:
            raise AssertionError("Bloom Filter exhausted")

    def contains(self, data):
        """Return True if ``data`` may have been inserted, False if it was not."""
        # all() stops at the first unset bit, like an early `return False`.
        return all(
            self.bits[func(data) % self.length] for func in self.hash_functions
        )

    def isempty(self):
        """Return whether no bit in the filter is set."""
        return np.sum(self.bits) == 0

    def save(self, fname):
        """Pickle the filter's length and bit array to the file ``fname``."""
        state = {"length": self.length, "bits": self.bits}
        with open(fname, "wb") as f:
            pickle.dump(state, f)
class BloomFilterTest(unittest.TestCase):
    """Unit tests for BloomFilter, each run against a fresh filter."""

    def setUp(self):
        # One million bits, so the few values inserted per test stay far
        # below the filter's capacity check.
        self.filterSize = int(1e6)

    def test_empty_is_empty(self):
        """A freshly created filter reports itself as empty."""
        bf = BloomFilter(self.filterSize)
        self.assertTrue(bf.isempty())

    def test_not_empty(self):
        """A filter is no longer empty after one insertion."""
        bf = BloomFilter(self.filterSize)
        bf.insert("1234")
        self.assertFalse(bf.isempty())

    def test_did_insert(self):
        """An inserted value is reported as contained."""
        bf = BloomFilter(self.filterSize)
        tempVal = "1234999"
        bf.insert(tempVal)
        self.assertTrue(bf.contains(tempVal))

    def test_not_contains(self):
        """A value that was never inserted is reported as absent."""
        bf = BloomFilter(self.filterSize)
        bf.insert("1234")
        bf.insert("abc123")
        bf.insert("xyz000")
        bf.insert("qwertyuiiop")
        self.assertFalse(bf.contains("9999999999"))

    def test_symbols(self):
        """Punctuation-only strings can be inserted and queried."""
        bf = BloomFilter(self.filterSize)
        # The backslash is escaped explicitly: "\;" is an invalid escape
        # sequence (a SyntaxWarning on Python 3.12+). The string value is
        # the same as before.
        bf.insert("{}|[]\\;':&*(&)")
        bf.insert("#$$^%")
        self.assertFalse(bf.contains("?>?<>:"))

    def test_serialization(self):
        """A filter restored from disk matches the one that was saved."""
        testValue = "qazwsxedcrfvtgbyhnujmik,il."
        bf = BloomFilter(self.filterSize)
        bf.insert(testValue)
        # A temporary directory keeps the working directory clean and is
        # removed even if saving or loading raises.
        with tempfile.TemporaryDirectory() as tmpdir:
            fname = os.path.join(tmpdir, "bloom_test.pkl")
            bf.save(fname)
            newBF = BloomFilter(fname=fname)
        self.assertTrue(newBF.contains(testValue))
        self.assertEqual(bf.length, newBF.length)
        self.assertTrue(np.array_equal(bf.bits, newBF.bits))
# Run the unit tests above when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment