@schaunwheeler
Last active June 8, 2023 23:22
Use MinHash to get Jaccard Similarity in Pyspark
from numpy.random import RandomState
import pyspark.sql.functions as f
from pyspark import StorageLevel


def hashmin_jaccard_spark(
        sdf, node_col, edge_basis_col, suffixes=('A', 'B'),
        n_draws=100, storage_level=None, seed=42, verbose=False):
"""
Calculate a sparse Jaccard similarity matrix using MinHash.
Parameters
sdf (pyspark.sql.DataFrame): A Dataframe containing at least two columns:
one defining the nodes (similarity between which is to be calculated)
and one defining the edges (the basis for node comparisons).
node_col (str): the name of the DataFrame column containing node labels
edge_basis_col: the name of the DataFrame columns containing the edge labels
suffixes (tuple): A tuple of length 2 contining the suffixes to be appeneded
to `node_col` in the output
n_draws (int): the number of permutations to do; this determines the precision
of the Jaccard similarity (n_draws == 100, the default, results in
similarity precision up to 0.01.
storage_level (pyspark.StorageLevel): PySpark object indicating how to persist
the hashing stage of the process
seed (int): seed for random state generation
verbose (bool): if True, print some information about how many records get hashed
"""
    HASH_PRIME = 2038074743
    left_name = node_col + suffixes[0]
    right_name = node_col + suffixes[1]

    # Random coefficients for the family of linear hash permutations (a * x + b) % p.
    rs = RandomState(seed)
    shifts = rs.randint(0, HASH_PRIME - 1, size=n_draws)
    coefs = rs.randint(0, HASH_PRIME - 1, size=n_draws) + 1

    # For each node, take the minimum of each permuted hash over its edge values
    # (the MinHash signature), then group nodes that share a (hashIndex, minHash) pair.
    hash_sdf = (
        sdf
        .selectExpr(
            "*",
            *[
                f"((1L + hash({edge_basis_col})) * {a} + {b}) % {HASH_PRIME} as hash{n}"
                for n, (a, b) in enumerate(zip(coefs, shifts))
            ]
        )
        .groupBy(node_col)
        .agg(
            f.array(*[f.min(f"hash{n}") for n in range(n_draws)]).alias("minHash")
        )
        .select(
            node_col,
            f.posexplode(f.col('minHash')).alias('hashIndex', 'minHash')
        )
        .groupby('hashIndex', 'minHash')
        .agg(
            f.collect_list(f.col(node_col)).alias('nodeList'),
            f.collect_set(f.col(node_col)).alias('nodeSet')
        )
    )
    if storage_level is not None:
        hash_sdf = hash_sdf.persist(storage_level)
        hash_count = hash_sdf.count()
        if verbose:
            print('Hash dataframe count:', hash_count)
    # Join nodes that collide on at least one min-hash value; the fraction of the
    # n_draws permutations on which a pair collides estimates its Jaccard similarity.
    adj_sdf = (
        hash_sdf.alias('a')
        .join(hash_sdf.alias('b'), ['hashIndex', 'minHash'], 'inner')
        .select(
            f.col('minHash'),
            f.explode(f.col('a.nodeList')).alias(left_name),
            f.col('b.nodeSet')
        )
        .select(
            f.col('minHash'),
            f.col(left_name),
            f.explode(f.col('nodeSet')).alias(right_name),
        )
        .groupby(left_name, right_name)
        .agg((f.count('*') / n_draws).alias('jaccardSimilarity'))
    )

    return adj_sdf
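
A minimal usage sketch (not part of the original gist; the column names and toy data are assumptions): build a DataFrame where `id` is the node column and `word` is the edge basis column, then call the function. The result has columns idA, idB, and jaccardSimilarity.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Toy data: one row per (node, edge-basis) pair, e.g. a document id and a word it contains.
rows = [
    ("doc1", "apple"), ("doc1", "banana"), ("doc1", "cherry"),
    ("doc2", "apple"), ("doc2", "banana"), ("doc2", "durian"),
    ("doc3", "elderberry"),
]
sdf = spark.createDataFrame(rows, ["id", "word"])

# Approximate pairwise Jaccard similarity between documents based on shared words.
similarity = hashmin_jaccard_spark(sdf, node_col="id", edge_basis_col="word", n_draws=100)
similarity.show()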
@schaunwheeler
Author

What do you hope to do with the words in sentence2? If you just want those words treated the same as the array of words in sentence for the similarity calculation, then add the word in sentence2 to the array in sentence, explode the array, and use that exploded column as your edge basis column and id as your node column (a sketch follows below). If you want sentence2 to affect the similarity in some other way, that's beyond the scope of this function.
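
A sketch of that suggestion (the column names id, sentence, and sentence2 are assumptions about the asker's data, not part of the gist): fold the sentence2 word into the sentence array, explode it, and pass the result to the function above.

import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Assumed shape of the data in question: an id, an array of words, and one extra word.
df = spark.createDataFrame(
    [("a", ["red", "green"], "blue"), ("b", ["red", "yellow"], "blue")],
    ["id", "sentence", "sentence2"],
)

# Add the sentence2 word to the sentence array, then explode so each row is one (id, word) pair.
exploded_sdf = (
    df
    .withColumn("allWords", f.concat(f.col("sentence"), f.array(f.col("sentence2"))))
    .select("id", f.explode("allWords").alias("word"))
)

# `word` becomes the edge basis column and `id` the node column.
similarity = hashmin_jaccard_spark(exploded_sdf, node_col="id", edge_basis_col="word")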

@jongbinjung

Thanks for the gist. I was writing some unit tests and noticed that the error bounds are out of whack. I think you need to change line 45 from f"((1L + hash({edge_basis_col})) * {a} + {b}) % {HASH_PRIME} as hash{n}" to something like f"((1L + abs(hash({edge_basis_col}) % {HASH_PRIME})) * {a} + {b}) % {HASH_PRIME} as hash{n}"?

From the source where you got the hash function permutations, they cite this paper as proof that this family of hash functions works. But a condition for the proof is that, in the linear permutations $a \cdot x + b$, we must have $x \in [p] = \{0, 1, \ldots, p-1\}$. In the current implementation, $x$ is effectively hash({edge_basis_col}), which can be any integer (even negative), so we need to force it to fall in $[p]$.

I don't know how Spark's hash() works, so I can't really check whether this change actually makes the implementation theoretically sound ... but at least it passes my unit tests for the theoretical error bounds.
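
One way to see why the folding matters (a standalone check, not from the thread): Spark's hash() returns a signed 32-bit integer, so it can be negative and can exceed HASH_PRIME in magnitude, while abs(hash(x) % p) always lands in [0, p - 1].

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

HASH_PRIME = 2038074743

# rawHash may be negative or larger than HASH_PRIME; foldedHash is forced into
# [0, HASH_PRIME - 1], the domain assumed by the (a * x + b) % p proof.
df = spark.createDataFrame([("apple",), ("banana",), ("cherry",)], ["word"])
checked = df.selectExpr(
    "word",
    "hash(word) as rawHash",
    f"abs(hash(word) % {HASH_PRIME}) as foldedHash",
)
checked.show()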
