"""
(Broken?) MinHash implementation, attempting to speed up the Jaccard similarity
computation for https://github.com/lemon24/reader/issues/202
(reader.entry_dedupe plugin).
---
Current state:
It kinda works, but no matter how much I increase LOOPS,
the result doesn't seem to converge to the real similarity
(e.g. at 1_000_000 loops, the weighted version is still off by up to 0.10,
not much better than at 1000; ???).
This is likely due to a bug, or to me misunderstanding the algorithm.
---
Even if we fix it:
The L5-Minhash.pdf doc below says that to reduce the error rate
to < 0.05 99% of the time we need ~1000 loops.
I don't think this would be an optimization for "online" use.
(Haven't really thought how document size affects things, though.)
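(For reference, ~1000 is consistent with Hoeffding's inequality:
each loop is a 0/1 trial with mean J, so
P(|estimate - J| >= eps) <= 2*exp(-2*k*eps^2),
which for eps=0.05, delta=0.01 gives
k >= ln(2/delta) / (2*eps^2) = ln(200) / 0.005 ~= 1060 loops.)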
*However*, if we use a known array of randoms and a stable hash function,
we can precompute min_hashes and store them for each entry "offline",
and then just compare the hashes to get the actual similarity.
We might also use this to bucket/group entries by similarity.
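A sketch of the offline idea, assuming datasketch's LeanMinHash and
MinHashLSH APIs (the num_perm / threshold values here are made up):

    from datasketch import MinHash, LeanMinHash, MinHashLSH

    def signature(words, n, num_perm=128):
        m = MinHash(num_perm=num_perm, seed=1)  # fixed seed -> stable permutations
        for t in _ngrams(words, n):
            m.update(' '.join(t).encode('utf-8'))
        return LeanMinHash(m)  # compact and serializable; store per entry

    # offline: compute and store signature(...) for each entry;
    # online: sig_one.jaccard(sig_two) estimates similarity without
    # re-hashing any text; for bucketing, LSH groups entries above a
    # similarity threshold:
    lsh = MinHashLSH(threshold=0.9, num_perm=128)
    # lsh.insert(entry_id, sig), then lsh.query(sig) -> candidate duplicates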
---
Obviously, I don't need to roll my own;
datasketch below seems to do a great job.
---
https://en.wikipedia.org/wiki/Jaccard_index
https://www.cs.utah.edu/~jeffp/teaching/cs5140-S15/cs5140/L4-Jaccard+nGram.pdf
https://en.wikipedia.org/wiki/MinHash
https://en.wikipedia.org/wiki/MinHash#Incorporating_weights
https://www.cs.utah.edu/~jeffp/teaching/cs5140-S15/cs5140/L5-Minhash.pdf
---
$ time python minhash.py
js mh dmh jsw mhw dmhn dmhw
0.50 0.41 0.48 0.95 0.94 0.94 0.99 one=40 one=39 two=1
1.00 1.00 1.00 0.95 0.94 0.93 0.00 one=40 one=38
0.14 0.17 0.14 0.72 0.72 0.72 0.92 one=40 one=37 two=3
0.17 0.13 0.16 0.79 0.73 0.76 0.00 one=50 one=47 two=1 three=1
0.12 0.18 0.12 0.82 0.80 0.83 0.00 one=50 one=55 two=5
0.11 0.13 0.12 0.75 0.71 0.73 0.00 one=70 one=63 two=5 three=1
0.12 0.15 0.12 0.68 0.59 0.68 0.88 one=70 one=60 two=10
0.00 0.02 0.13 0.00 0.23 0.14 0.25 times
python minhash.py 1.45s user 0.12s system 113% cpu 1.393 total
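(Note: dmhw prints 0.00 whenever the two inputs differ in length,
because datasketch_minhash_weighted bails out for those; see below.)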
"""
from collections import Counter
import sys
import random
import hashlib
import time
from itertools import groupby
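# assumes a reader checkout as the working directory: reader.plugins
# provides _ngrams, and tests/ (added to sys.path below) provides the
# sample duplicate data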
from reader.plugins.entry_dedupe import _ngrams
sys.path.append('tests')
import test_plugins_entry_dedupe
from datasketch import MinHash, WeightedMinHashGenerator
DATA = []
for one, two, _ in test_plugins_entry_dedupe.IS_DUPLICATE_DATA:
    if not one.summary or 'one' not in one.summary:
        continue
    DATA.append((one.summary.split(), two.summary.split()))
def to_str(value):
    return ' '.join(f'{k}={v}' for k, v in Counter(value).items())
def jaccard(one, two, n):
    one = set(_ngrams(one, n))
    two = set(_ngrams(two, n))
    return len(one & two) / len(one | two)
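# Weighted Jaccard on multisets: Counter & is elementwise min and Counter |
# is elementwise max, so this computes sum(min) / sum(max).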
def jaccard_weighted(one, two, n):
    one = Counter(_ngrams(one, n))
    two = Counter(_ngrams(two, n))
    return sum((one & two).values()) / sum((one | two).values())
LOOPS = 1000
# sys.maxsize because https://stackoverflow.com/a/19133757
HASH_MAX = sys.maxsize
HASH = hash
# at least 5x slower
#HASH_MAX = b'\xff' * hashlib.md5().digest_size
#def HASH(thing): return hashlib.md5(repr(thing).encode('utf-8')).digest()
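# NOTE: Python's built-in hash() is not designed to behave like a family of
# independent, uniformly distributed hash functions (e.g. hash() of a small
# int is the int itself, and tuple hashing is fixed deterministic mixing, not
# a fresh random function per salt r). If the estimates below are biased,
# this seems like a plausible culprit; the slower md5 variant above would
# sidestep it.
# Plain MinHash: for each of LOOPS random salts r, keep the minimum of
# HASH((r, t)) over each set's ngrams; the two minima match with probability
# equal to the Jaccard similarity, so the fraction of matching minima
# estimates it.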
def minhash(one, two, n):
    one = set(_ngrams(one, n))
    two = set(_ngrams(two, n))
    loops = LOOPS
    min_hashes = [[HASH_MAX] * 2 for _ in range(loops)]
    randoms = [random.random() for _ in range(loops)]
    for ic, counts in enumerate((one, two)):
        for t in counts:
            for ir, r in enumerate(randoms):
                h = HASH((r, t))
                if h < min_hashes[ir][ic]:
                    min_hashes[ir][ic] = h
    sim = sum(h_one == h_two for h_one, h_two in min_hashes) / loops
    return sim
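# Weighted variant via replication (the "Incorporating weights" link above):
# an ngram with count c contributes c distinct elements (r, 0, t) ...
# (r, c-1, t), so the matching-minima estimate targets
# sum(min(counts)) / sum(max(counts)), i.e. the same value as jaccard_weighted.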
def minhash_weighted(one, two, n):
    one = Counter(_ngrams(one, n))
    two = Counter(_ngrams(two, n))
    loops = LOOPS
    min_hashes = [[HASH_MAX] * 2 for _ in range(loops)]
    randoms = [random.random() for _ in range(loops)]
    for ic, counts in enumerate((one, two)):
        for t in counts:
            for ir, r in enumerate(randoms):
                for ix in range(counts[t]):
                    h = HASH((r, ix, t))
                    if h < min_hashes[ir][ic]:
                        min_hashes[ir][ic] = h
    sim = sum(h_one == h_two for h_one, h_two in min_hashes) / loops
    return sim
def datasketch_minhash(one, two, n):
    m_one = MinHash(num_perm=LOOPS)
    for t in _ngrams(one, n):
        m_one.update(' '.join(t).encode('utf-8'))
    m_two = MinHash(num_perm=LOOPS)
    for t in _ngrams(two, n):
        m_two.update(' '.join(t).encode('utf-8'))
    return m_one.jaccard(m_two)
def datasketch_minhash_weighted(one, two, n):
    # WeightedMinHashGenerator expects fixed-dimension weight vectors,
    # so this hack bails out (returns 0) for different-size inputs; it also
    # feeds raw hash values in as the "weights", which is probably not what
    # the generator expects -- see the vocabulary-based sketch below.
    one = list(one)
    two = list(two)
    if len(one) != len(two):
        return 0
    hashfunc = MinHash().hashfunc
    one = list(hashfunc(' '.join(t).encode('utf-8')) for t in _ngrams(one, n))
    two = list(hashfunc(' '.join(t).encode('utf-8')) for t in _ngrams(two, n))
    gen = WeightedMinHashGenerator(len(one), LOOPS)
    m_one = gen.minhash(one)
    m_two = gen.minhash(two)
    return m_one.jaccard(m_two)
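# A (hopefully) more faithful use of WeightedMinHashGenerator -- a sketch,
# assuming its documented minhash(weight-vector) API: build a shared ngram
# vocabulary and use the counts as fixed-dimension weight vectors. Building
# a generator per pair is wasteful, but fine for illustration; this is not
# wired into impls below.
def datasketch_minhash_weighted_vocab(one, two, n):
    one = Counter(_ngrams(one, n))
    two = Counter(_ngrams(two, n))
    vocab = sorted(set(one) | set(two))
    gen = WeightedMinHashGenerator(len(vocab), sample_size=LOOPS)
    m_one = gen.minhash([one[t] for t in vocab])
    m_two = gen.minhash([two[t] for t in vocab])
    return m_one.jaccard(m_two)
# enumerated_ngrams below replicates each repeated ngram with a counter
# suffix -- the same trick as minhash_weighted -- so that plain MinHash can
# approximate the weighted similarity.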
def enumerated_ngrams(it, n):
    for _, group in groupby(sorted(_ngrams(it, n))):
        for i, t in enumerate(group):
            yield t + (str(i),)
def datasketch_minhash_weighted_naive(one, two, n):
    m_one = MinHash(num_perm=LOOPS)
    for t in enumerated_ngrams(one, n):
        m_one.update(' '.join(t).encode('utf-8'))
    m_two = MinHash(num_perm=LOOPS)
    for t in enumerated_ngrams(two, n):
        m_two.update(' '.join(t).encode('utf-8'))
    return m_one.jaccard(m_two)
impls = {
    'js': jaccard,
    'mh': minhash,
    'dmh': datasketch_minhash,
    'jsw': jaccard_weighted,
    'mhw': minhash_weighted,
    'dmhn': datasketch_minhash_weighted_naive,
    'dmhw': datasketch_minhash_weighted,
}
print(''.join(f'{l:>4} ' for l in impls))
times = {}
for one, two in DATA:
    sims = []
    for n, fn in impls.items():
        start = time.perf_counter()
        val = fn(one, two, 4)
        end = time.perf_counter()
        sims.append(val)
        times[n] = times.get(n, 0) + end - start
    print(
        f"{''.join(f'{s:.2f} ' for s in sims)}"
        f"{to_str(one):2} "
        f"{to_str(two):2}"
    )
print(
    f"{''.join(f'{s:.2f} ' for s in times.values())}"
    "times"
)