import numpy as np | |
import faiss | |
def search_knn(xq, xb, k, distance_type=faiss.METRIC_L2):
    """Run a k-nearest-neighbour search with the faiss knn functions, without an index.

    Args:
        xq: query vectors, 2-D array-like of shape (nq, d).
        xb: database vectors, 2-D array-like of shape (nb, d).
        k: number of neighbours to return for each query.
        distance_type: ``faiss.METRIC_L2`` (default, dispatched to
            ``faiss.knn_L2sqr``) or ``faiss.METRIC_INNER_PRODUCT``
            (dispatched to ``faiss.knn_inner_product``).

    Returns:
        A tuple ``(D, I)``: ``D`` is a float32 array of shape (nq, k) holding
        the values written by the faiss routine, ``I`` is an int64 array of
        shape (nq, k) holding the matching row indices into ``xb``.

    Raises:
        AssertionError: if ``xq`` and ``xb`` do not have the same dimension.
        ValueError: if ``distance_type`` is not one of the two supported
            metrics.
    """
    # faiss only receives raw pointers, so the buffers must really be
    # C-contiguous float32. This is a no-op for arrays already in that layout.
    xq = np.ascontiguousarray(xq, dtype='float32')
    xb = np.ascontiguousarray(xb, dtype='float32')

    nq, d = xq.shape
    nb, d2 = xb.shape
    assert d == d2

    # Pick the heap type and the matching faiss routine up front, so an
    # unknown metric fails loudly instead of returning uninitialised buffers.
    if distance_type == faiss.METRIC_L2:
        heaps = faiss.float_maxheap_array_t()
        knn_function = faiss.knn_L2sqr
    elif distance_type == faiss.METRIC_INNER_PRODUCT:
        heaps = faiss.float_minheap_array_t()
        knn_function = faiss.knn_inner_product
    else:
        raise ValueError(
            "unsupported distance_type: %r" % (distance_type,)
        )

    # Output buffers; faiss writes the results into them through the heap
    # structure, which only stores pointers to their memory.
    I = np.empty((nq, k), dtype='int64')
    D = np.empty((nq, k), dtype='float32')

    heaps.k = k
    heaps.nh = nq
    heaps.val = faiss.swig_ptr(D)
    heaps.ids = faiss.swig_ptr(I)

    knn_function(
        faiss.swig_ptr(xq), faiss.swig_ptr(xb),
        d, nq, nb, heaps
    )
    return D, I
# Self-check: search_knn must return the same neighbours and distances as the
# corresponding flat faiss index, for both supported metrics.
xb = np.random.rand(200, 32).astype('float32')
xq = np.random.rand(100, 32).astype('float32')

for index_cls, metric in (
    (faiss.IndexFlatL2, faiss.METRIC_L2),
    (faiss.IndexFlatIP, faiss.METRIC_INNER_PRODUCT),
):
    # Reference result from the flat index.
    index = index_cls(32)
    index.add(xb)
    Dref, Iref = index.search(xq, 10)

    # Result from the index-free wrapper under test.
    Dnew, Inew = search_knn(xq, xb, 10, distance_type=metric)

    assert np.all(Inew == Iref)
    assert np.allclose(Dref, Dnew)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.