Last active
June 5, 2024 21:18
-
-
Save xhluca/bbe800328002b360742edc2513cf4eb5 to your computer and use it in GitHub Desktop.
Benchmark of jax.lax.top_k vs np.argpartition for selecting top-k results. Thanks to Andreas Madsen for the original code and JAX suggestions.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.sparse as sp | |
import jax.lax | |
import timeit | |
def random_sparse_matrix(n_rows, n_cols, num_elements, seed=0):
    """
    Build a random scipy CSC matrix of shape (n_rows, n_cols).

    `num_elements` values are drawn uniformly and placed at randomly chosen
    (row, col) coordinates. The global NumPy RNG is seeded with `seed` first,
    so the same arguments always produce the same matrix.
    """
    np.random.seed(seed)

    # The draw order (values, then rows, then columns) fixes the result for a seed.
    values = np.random.uniform(size=num_elements)
    row_idx = np.random.randint(0, n_rows, size=num_elements)
    col_idx = np.random.randint(0, n_cols, size=num_elements)

    coordinates = (row_idx, col_idx)
    return sp.csc_matrix((values, coordinates), shape=(n_rows, n_cols))
def topk(scores, k, sorted=True):
    """
    Select the `k` largest entries of a 1-dimensional array of scores
    (the results for a single query).

    Returns a `(scores, indices)` pair. When `sorted` is true the pair is
    ordered from the highest score to the lowest; otherwise the entries are
    returned in whatever order the partition step left them.
    """
    # Partition so the k largest entries occupy the last k positions, then keep those.
    candidate_ind = np.argpartition(scores, -k).take(indices=range(-k, 0))
    candidate_scores = np.take(scores, candidate_ind)

    if not sorted:
        return candidate_scores, candidate_ind

    # Only the k candidates are sorted; reverse the ascending order to get descending.
    descending = np.argsort(candidate_scores)[::-1]
    return candidate_scores[descending], candidate_ind[descending]
def get_scores(mat, query_tokens):
    """
    Sum, for every row of `mat`, the entries in the columns selected by
    `query_tokens`, and return the totals as a squeezed NumPy array.
    """
    selected_cols = mat[:, query_tokens]
    row_totals = selected_cols.sum(axis=1)
    # Convert the summed result to a plain ndarray and drop length-1 axes.
    return np.squeeze(np.asarray(row_totals))
# --- Benchmark configuration -------------------------------------------------
n_docs = 5_000_000
n_vocab = 100_000
num_elements = 150_000_000
num_query_tokens = 200
n_runs = 50  # number of repetitions passed to every timeit call below

# --- Build the data: a random sparse matrix and the per-row query scores ------
mat = random_sparse_matrix(n_docs, n_vocab, num_elements)
query_tokens = np.random.choice(n_vocab, size=num_query_tokens, replace=False)
col_sum = get_scores(mat, query_tokens)

print(f"mat sparsity: {1 - mat.nnz / (n_docs * n_vocab)}")
print(f"sum sparsity: {1 - np.count_nonzero(col_sum) / n_docs}")
print("sum time: ", timeit.timeit(lambda: get_scores(mat, query_tokens), number=n_runs))

# --- NumPy-based topk helper ----------------------------------------------------
for k in [100, 1000]:
    print(f"topk [k={k}]: ", timeit.timeit(lambda: topk(col_sum, k), number=n_runs))

# --- jax.lax.top_k ----------------------------------------------------------------
for k in [100, 1000]:

    def _run_jax_top_k():
        # JAX dispatches work asynchronously, so wait for both outputs;
        # otherwise only the dispatch overhead would be timed.
        for out in jax.lax.top_k(col_sum, k):
            out.block_until_ready()

    # Warm-up call keeps one-time compilation cost out of the measurement.
    _run_jax_top_k()
    print(
        f"jax.lax.top_k [k={k}]: ",
        timeit.timeit(_run_jax_top_k, number=n_runs),
    )

# --- Raw np.argpartition, partitioning around the k-th smallest element ----------
for k in [100, 1000]:
    print(
        f"argpartition [k={k}]: ",
        timeit.timeit(lambda: np.argpartition(col_sum, k), number=n_runs),
    )

# --- Raw np.argpartition, partitioning so the k largest elements come last -------
for k in [100, 1000]:
    # argpartition's kth counts from the smallest element, so selecting the
    # k largest elements means partitioning at position len(col_sum) - k.
    tk = len(col_sum) - k
    print(
        f"argpartition [k={tk}]: ",
        timeit.timeit(lambda: np.argpartition(col_sum, tk), number=n_runs),
    )

# --- Sanity check: both implementations must agree on values and indices ---------
# NOTE(review): JAX uses 32-bit floats unless jax_enable_x64 is set, while
# col_sum is presumably float64; scores that collide after rounding could be
# ordered differently by the two implementations — confirm if this ever fails.
jxval, jxinds = jax.lax.top_k(col_sum, 1000)
npval, npinds = topk(col_sum, 1000)
assert np.all(jxval == npval)
assert np.all(jxinds == npinds)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment