@snakers4, created August 26, 2018
Use faiss to calculate a KNN graph on data
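The snippet below builds an approximate 100-NN graph over vectors stored in a bcolz array: it trains an IVF+PQ index on the GPU, adds all the vectors, then queries the index in batches, streaming the neighbour ids and distances into two on-disk bcolz arrays.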
import gc
import bcolz
import faiss
import numpy as np
from tqdm import tqdm
# open the stored bcolz array
# note that these vectors have to be 280-dimensional to be compatible
# with the faiss GPU PQ indexing used below, see
# https://github.com/facebookresearch/faiss/wiki/Troubleshooting#gpu-precomputed-table-error
bc_path = 'your_vectors.bc'
bc_vectors = bcolz.open(rootdir=bc_path)
# faiss expects a contiguous float32 array
vectors = np.ascontiguousarray(bc_vectors[:, :], dtype=np.float32)
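# PQ56 below splits each vector into 56 sub-vectors,
# so the dimensionality must be divisible by 56 (280 / 56 = 5 dims each)
assert vectors.shape[1] % 56 == 0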
top_k = 100  # number of nearest neighbours to compute per vector
# create on-disk bcolz arrays for the knn graph and for the distances;
# bcolz.carray needs an initial array to infer dtype and shape,
# and faiss returns int64 neighbour ids and float32 distances
knn_bc_path = 'knn.bc'
knn_bc = bcolz.carray(np.empty((0, top_k), dtype=np.int64),
                      rootdir=knn_bc_path, mode='w')
knn_bc.flush()
knn_dist_bc_path = 'distances.bc'
knn_dist_bc = bcolz.carray(np.empty((0, top_k), dtype=np.float32),
                           rootdir=knn_dist_bc_path, mode='w')
knn_dist_bc.flush()
res = faiss.StandardGpuResources()
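# IVF4096: inverted file with 4096 coarse cells;
# PQ56: product quantizer with 56 sub-quantizers of 1 byte each,
# i.e. 56 bytes per encoded vector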
index = faiss.index_factory(vectors.shape[1], "IVF4096,PQ56")
co = faiss.GpuClonerOptions()
# https://github.com/facebookresearch/faiss/tree/master/benchs
# here we are using a 56-byte PQ, so we must set the lookup tables to
# 16 bit float (this is due to the limited temporary memory).
co.useFloat16 = True
index = faiss.index_cpu_to_gpu(res, 0, index, co)
print("Train the index")
index.train(vectors)
print('Add vectors to the index')
index.add(vectors)
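# sanity check: every vector made it into the index
assert index.ntotal == len(bc_vectors)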
del vectors
gc.collect()
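# nprobe is the number of IVF cells visited per query (out of 4096);
# larger values trade query speed for recall (1 << 8 = 256 here)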
nprobe = 1 << 8
index.setNumProbes(nprobe)
batch_size = 16384 // 2
idx = list(range(len(bc_vectors)))
batches = [idx[i:i + batch_size] for i in range(0, len(bc_vectors), batch_size)]
# check that the batches cover every vector index
assert set(item for sublist in batches for item in sublist) == set(range(len(bc_vectors)))
processed_batches = []
with tqdm(total=len(batches)) as pbar:
    for batch in batches:
        processed_batches.append(batch)
        b_array = np.asarray(batch)
        # bcolz fancy indexing materializes the batch; cast to float32 for faiss
        queries = np.ascontiguousarray(bc_vectors[b_array], dtype=np.float32)
        D, I = index.search(queries, top_k)
        knn_bc.append(I)
        knn_bc.flush()
        knn_dist_bc.append(D)
        knn_dist_bc.flush()
        pbar.update(1)
# check that all vectors were processed
assert set(item for sublist in processed_batches for item in sublist) == set(range(len(bc_vectors)))
assert len(knn_bc) == len(bc_vectors)
assert len(knn_dist_bc) == len(bc_vectors)
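# minimal sketch of reading the graph back from disk: row i holds the
# ids / distances of the top_k nearest neighbours of vector i (the first
# neighbour is usually the vector itself, since the queries were also
# added to the index)
knn = bcolz.open(rootdir=knn_bc_path)
knn_dist = bcolz.open(rootdir=knn_dist_bc_path)
print(knn[0], knn_dist[0])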