Skip to content

Instantly share code, notes, and snippets.

@xhluca
Last active June 5, 2024 21:18
Show Gist options
  • Save xhluca/bbe800328002b360742edc2513cf4eb5 to your computer and use it in GitHub Desktop.
Save xhluca/bbe800328002b360742edc2513cf4eb5 to your computer and use it in GitHub Desktop.
Benchmark of jax.lax.top_k vs. np.argpartition for selecting top-k results. Thanks to Andreas Madsen for the original code and JAX suggestions.
import numpy as np
import scipy.sparse as sp
import jax.lax
import timeit
def random_sparse_matrix(n_rows, n_cols, num_elements, seed=0):
    """Build a random SciPy CSC matrix from randomly placed uniform values.

    Seeds NumPy's *global* random state with ``seed``, then draws
    ``num_elements`` values with ``np.random.uniform`` followed by the same
    number of random row indices and random column indices, and assembles
    them into a ``scipy.sparse.csc_matrix`` of shape ``(n_rows, n_cols)``.
    """
    np.random.seed(seed)
    extents = (n_rows, n_cols)
    # Draw order matters for reproducibility: values, then rows, then columns.
    values = np.random.uniform(size=num_elements)
    row_idx, col_idx = (
        np.random.randint(extent, size=num_elements) for extent in extents
    )
    return sp.csc_matrix((values, (row_idx, col_idx)), shape=extents)
def topk(scores, k, sorted=True):
    """
    Retrieve the top-k results for a single query.

    Only works on a 1-dimensional array of scores.

    Parameters
    ----------
    scores : array-like, 1-dimensional
        Score of every candidate.
    k : int
        Number of results to return. Values larger than ``len(scores)`` are
        clamped, so at most ``len(scores)`` results come back; ``k == 0``
        yields two empty arrays.
    sorted : bool, default True
        When true, results are ordered by descending score; otherwise the
        k largest entries are returned in the arbitrary order produced by
        the partition. (The name shadows the builtin inside this function;
        it is kept because it is part of the public signature.)

    Returns
    -------
    (scores, ind) : tuple of np.ndarray
        The selected scores and their positions in the input.

    Raises
    ------
    ValueError
        If ``scores`` is not 1-dimensional or ``k`` is negative.
    """
    scores = np.asarray(scores)
    if scores.ndim != 1:
        raise ValueError(
            f"topk expects a 1-dimensional array, got {scores.ndim} dimension(s)"
        )
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Never request more results than there are candidates.
    k = min(k, scores.shape[0])
    if k == 0:
        # Guarded separately: a `[-0:]` slice would select everything.
        return scores[:0].copy(), np.empty(0, dtype=np.intp)

    # argpartition places the k largest entries (unordered) in the last k slots.
    top_ind = np.argpartition(scores, -k)[-k:]
    top_scores = scores[top_ind]

    if sorted:
        # Only the k survivors are sorted; reverse ascending order to descending.
        order = np.argsort(top_scores)[::-1]
        top_ind = top_ind[order]
        top_scores = top_scores[order]

    return top_scores, top_ind
def get_scores(mat, query_tokens):
    """Return, for every row of ``mat``, the sum over the selected columns.

    Parameters
    ----------
    mat : sparse matrix
        Matrix indexed as ``mat[:, query_tokens]``; the script passes a
        ``scipy.sparse.csc_matrix``.
    query_tokens : array-like of int
        Column indices to include in the sum.

    Returns
    -------
    np.ndarray
        1-dimensional array with one summed score per row of ``mat``.
    """
    selected = mat[:, query_tokens]
    row_totals = selected.sum(axis=1)
    # ravel() rather than squeeze(): squeeze would collapse a single-row
    # result to a 0-d array, whereas callers expect a 1-d score vector.
    return np.asarray(row_totals).ravel()
# Benchmark configuration: size of the synthetic document-term matrix and
# number of vocabulary entries sampled as the query.
n_docs = 5_000_000
n_vocab = 100_000
num_elements = 150_000_000
num_query_tokens = 200

mat = random_sparse_matrix(n_docs, n_vocab, num_elements)
# random_sparse_matrix seeds NumPy's global RNG, so this draw follows on from
# that seeded state.
query_tokens = np.random.choice(n_vocab, size=num_query_tokens, replace=False)
col_sum = get_scores(mat, query_tokens)

print(f"mat sparsity: {1 - mat.nnz / (n_docs * n_vocab)}")
print(f"sum sparsity: {1 - np.count_nonzero(col_sum) / n_docs}")
print("sum time: ", timeit.timeit(lambda: get_scores(mat, query_tokens), number=50))

# NumPy-based top-k: argpartition followed by a sort of the k survivors.
for k in [100, 1000]:
    print(f"topk [k={k}]: ", timeit.timeit(lambda: topk(col_sum, k), number=50))

# JAX top-k. JAX dispatches work asynchronously, so the timed callable must
# wait on the result with block_until_ready(); otherwise only the dispatch is
# measured. One untimed warm-up call per k keeps compilation out of the timing.
# NOTE(review): col_sum is a NumPy array, so each call presumably also pays
# for converting it to a JAX array — confirm whether that should be included.
for k in [100, 1000]:
    jax.lax.top_k(col_sum, k)[0].block_until_ready()
    print(
        f"jax.lax.top_k [k={k}]: ",
        timeit.timeit(
            lambda: jax.lax.top_k(col_sum, k)[0].block_until_ready(),
            number=50,
        ),
    )

# Raw argpartition cost when partitioning around the k-th smallest element.
for k in [100, 1000]:
    print(
        f"argpartition [k={k}]: ",
        timeit.timeit(lambda: np.argpartition(col_sum, k), number=50),
    )

for k in [100, 1000]:
    # topk is different from k, since k is the k-th smallest element, while
    # topk is the k largest elements
    tk = len(col_sum) - k
    print(
        f"argpartition [k={tk}]: ",
        timeit.timeit(lambda: np.argpartition(col_sum, tk), number=50),
    )

# compare: both implementations should agree on the top-1000 values and indices.
# NOTE(review): exact index equality assumes no ties among the top scores,
# since tie order may differ between the two implementations — confirm.
jxval, jxinds = jax.lax.top_k(col_sum, 1000)
npval, npinds = topk(col_sum, 1000)
assert np.all(jxval == npval)
assert np.all(jxinds == npinds)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment