Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import numpy as np
import faiss
def search_knn(xq, xb, k, distance_type=faiss.METRIC_L2):
    """Brute-force k-nearest-neighbor search without building a faiss index.

    Thin wrapper around faiss' low-level ``knn_L2sqr`` and
    ``knn_inner_product`` routines, which write their results into a
    caller-provided heap array backed by the numpy output buffers.

    Args:
        xq: query vectors, a 2-D array of shape (nq, d).
        xb: database vectors, a 2-D array of shape (nb, d).
        k: number of neighbors to return per query.
        distance_type: ``faiss.METRIC_L2`` (keeps the k smallest squared-L2
            distances) or ``faiss.METRIC_INNER_PRODUCT`` (keeps the k largest
            inner products).

    Returns:
        (D, I): float32 distances and int64 database row ids, each of shape
        (nq, k).

    Raises:
        ValueError: if ``xq`` and ``xb`` have different dimensions, or if
            ``distance_type`` is not one of the two supported metrics.
    """
    # faiss.swig_ptr hands the raw buffer to C++, which reads it as
    # C-contiguous float32; convert (a no-op when the input already is).
    xq = np.ascontiguousarray(xq, dtype='float32')
    xb = np.ascontiguousarray(xb, dtype='float32')
    nq, d = xq.shape
    nb, d2 = xb.shape
    if d != d2:
        raise ValueError(
            f"dimension mismatch: xq has d={d}, xb has d={d2}"
        )

    if distance_type == faiss.METRIC_L2:
        # A max-heap evicts the largest entry, so it retains the k smallest
        # distances.
        heaps = faiss.float_maxheap_array_t()
        knn_fn = faiss.knn_L2sqr
    elif distance_type == faiss.METRIC_INNER_PRODUCT:
        # A min-heap evicts the smallest entry, so it retains the k largest
        # similarities.
        heaps = faiss.float_minheap_array_t()
        knn_fn = faiss.knn_inner_product
    else:
        # Without this, the uninitialized np.empty buffers below would be
        # returned as if they were search results.
        raise ValueError(f"unsupported distance_type: {distance_type!r}")

    I = np.empty((nq, k), dtype='int64')
    D = np.empty((nq, k), dtype='float32')
    # The heap array only stores raw pointers into D and I; both arrays stay
    # referenced by this frame until after knn_fn has filled them.
    heaps.k = k
    heaps.nh = nq
    heaps.val = faiss.swig_ptr(D)
    heaps.ids = faiss.swig_ptr(I)
    knn_fn(faiss.swig_ptr(xq), faiss.swig_ptr(xb), d, nq, nb, heaps)
    return D, I
# Self-check: search_knn must reproduce the results of faiss' flat indexes.
def _check_search_knn_against_flat_index(seed=1234):
    """Compare ``search_knn`` with ``IndexFlatL2`` / ``IndexFlatIP``.

    Uses seeded random data so any failure is reproducible. Mismatches raise
    AssertionError via ``numpy.testing``; unlike bare ``assert`` statements,
    these checks are not stripped when Python runs with ``-O``.
    """
    rng = np.random.default_rng(seed)
    d, k = 32, 10
    xb = rng.random((200, d), dtype=np.float32)
    xq = rng.random((100, d), dtype=np.float32)
    for index_cls, metric in (
        (faiss.IndexFlatL2, faiss.METRIC_L2),
        (faiss.IndexFlatIP, faiss.METRIC_INNER_PRODUCT),
    ):
        index = index_cls(d)
        index.add(xb)
        Dref, Iref = index.search(xq, k)
        Dnew, Inew = search_knn(xq, xb, k, distance_type=metric)
        np.testing.assert_array_equal(Inew, Iref)
        # Same tolerances as the np.allclose defaults.
        np.testing.assert_allclose(Dnew, Dref, rtol=1e-5, atol=1e-8)


if __name__ == "__main__":
    _check_search_knn_against_flat_index()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment