Skip to content

Instantly share code, notes, and snippets.

@ixaxaar
Created December 19, 2017 17:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ixaxaar/5e4fa3efd70f8c8aabf200f7c34ab195 to your computer and use it in GitHub Desktop.
Save ixaxaar/5e4fa3efd70f8c8aabf200f7c34ab195 to your computer and use it in GitHub Desktop.
import faiss
from faiss import cast_integer_to_float_ptr as cast_float
from faiss import cast_integer_to_int_ptr as cast_int
from faiss import cast_integer_to_long_ptr as cast_long
import torch as T
def ptr(tensor):
    """Return the raw memory address of the first element of ``tensor``.

    The address is what the ``cast_integer_to_*_ptr`` helpers expect as input.

    Args:
        tensor: A torch tensor, an object exposing its tensor as ``.data``
            (e.g. an autograd wrapper), or anything else.

    Returns:
        The integer address for tensors and ``.data`` holders; any other
        value is returned unchanged (it is assumed to already be an address).
    """
    # ``data_ptr()`` accounts for the tensor's storage offset, whereas
    # ``storage().data_ptr()`` always points at the start of the backing
    # storage and is therefore wrong for views such as ``t[3:]``.
    if T.is_tensor(tensor):
        return tensor.data_ptr()
    elif hasattr(tensor, 'data'):
        return tensor.data.data_ptr()
    else:
        return tensor
def ensure_gpu(tensor, gpu_id):
    """Place ``tensor`` on the device selected by ``gpu_id``.

    Args:
        tensor: A torch tensor, a numpy array, or any other object.
        gpu_id: CUDA device ordinal, or ``-1`` for the CPU.

    Returns:
        * torch tensor: the tensor on ``cuda:gpu_id``, or on the CPU when
          ``gpu_id`` is ``-1`` (an already-CPU tensor is returned as is).
        * numpy array: a torch tensor built from the array and placed the
          same way. The dtype is preserved, so callers that hand the result
          to FAISS should pass ``float32`` / ``int64`` arrays.
        * anything else: returned unchanged.
    """
    if T.is_tensor(tensor):
        # Ask the tensor where it lives instead of pattern-matching on its
        # type name: current PyTorch uses the same class for CPU and CUDA
        # tensors, so the name no longer contains "cuda".
        if gpu_id != -1:
            return tensor.cuda(gpu_id)
        return tensor.cpu()

    # Imported locally because this file does not import numpy at the top.
    import numpy as np

    if isinstance(tensor, np.ndarray):
        # ``from_numpy`` shares memory with the source array on the CPU.
        converted = T.from_numpy(tensor)
        return converted.cuda(gpu_id) if gpu_id != -1 else converted

    return tensor
class FAISSIndex(object):
    """Inner-product FAISS ``GpuIndexIVFFlat`` driven through raw tensor pointers.

    Vectors, ids and result buffers are torch tensors whose memory addresses
    are handed to the FAISS ``*_c`` entry points. Ids are stored shifted by
    +1 on insertion and shifted back by -1 on lookup.
    """

    def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=32, probes=32, res=None, train=None, gpu_id=-1):
        """Build the index and train it.

        Args:
            cell_size: Dimensionality of every indexed vector.
            nr_cells: Used to size the default random training set
                (``nr_cells * 100`` rows).
            K: Default number of neighbours returned by ``search``.
            num_lists: Number of inverted lists passed to ``GpuIndexIVFFlat``.
            probes: Value passed to ``setNumProbes``.
            res: Optional FAISS GPU resources object; a
                ``StandardGpuResources`` is created when omitted.
            train: Optional training tensor; random data scaled by 10 is used
                when omitted.
            gpu_id: CUDA device ordinal, or ``-1``.
        """
        super(FAISSIndex, self).__init__()
        self.cell_size = cell_size
        self.nr_cells = nr_cells
        self.probes = probes
        self.K = K
        self.num_lists = num_lists
        self.gpu_id = gpu_id

        res = res if res is not None else faiss.StandardGpuResources()
        res.setTempMemoryFraction(0.01)
        if self.gpu_id != -1:
            res.initializeForDevice(self.gpu_id)
        # The index keeps using the resources object after construction, so
        # hold a reference for as long as this wrapper lives; otherwise a
        # resources object created here could be collected once __init__
        # returns.
        self.res = res

        train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size) * 10

        # NOTE(review): no device config is passed here, so the index
        # presumably lives on FAISS's default device regardless of gpu_id -
        # confirm against the installed FAISS version.
        self.index = faiss.GpuIndexIVFFlat(res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
        self.index.setNumProbes(self.probes)
        self.train(train)

    def cuda(self, gpu_id):
        """Record ``gpu_id`` for later tensor placement.

        Only the stored id changes; the FAISS index itself is not touched.
        """
        self.gpu_id = gpu_id

    def train(self, train):
        """Train the index on every row of ``train``.

        Args:
            train: Training vectors, one per row; moved to the CPU before
                being handed to FAISS.
        """
        # FAISS reads the buffer as dense rows, so force a contiguous layout.
        train = ensure_gpu(train, -1).contiguous()

        T.cuda.synchronize()
        # Report the number of rows actually present in the buffer: a fixed
        # count could exceed a small caller-supplied tensor and would ignore
        # most of a large one.
        self.index.train_c(train.size(0), cast_float(ptr(train)))
        T.cuda.synchronize()

    def reset(self):
        """Call ``reset()`` on the underlying index, bracketed by CUDA syncs."""
        T.cuda.synchronize()
        self.index.reset()
        T.cuda.synchronize()

    def add(self, other, positions=None, last=None):
        """Insert the rows of ``other`` into the index.

        Args:
            other: Vectors to add, one per row.
            positions: Optional ids, one per row of ``other``. They are
                stored as ``positions + 1``; ``search`` undoes the shift.
            last: When ``positions`` is omitted, only the first ``last`` rows
                are added. Ignored when ``positions`` is given.

        Raises:
            AssertionError: If ``positions`` and ``other`` differ in length.
        """
        other = ensure_gpu(other, self.gpu_id)

        T.cuda.synchronize()
        if positions is not None:
            positions = ensure_gpu(positions, self.gpu_id)
            assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
            other = other.contiguous()
            # Keep the shifted ids bound to a local: taking the address of a
            # temporary would let its memory be released before FAISS reads it.
            ids = (positions + 1).long().contiguous()
            self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(ids)))
        else:
            other = other[:last, :] if last is not None else other
            other = other.contiguous()
            self.index.add_c(other.size(0), cast_float(ptr(other)))
        T.cuda.synchronize()

    def search(self, query, k=None):
        """Look up the ``k`` best matches for every row of ``query``.

        Args:
            query: 2-D tensor of query vectors, one per row.
            k: Number of neighbours per query; falls back to ``self.K`` when
                omitted (or 0).

        Returns:
            ``(distances, labels)``: a float tensor and a long tensor, both of
            shape ``(rows, k)``. ``labels`` has the +1 storage shift removed,
            so any raw ``-1`` written by FAISS comes back as ``-2``.
        """
        query = ensure_gpu(query, self.gpu_id).contiguous()

        k = k if k else self.K
        b, _ = query.size()

        # Output buffers that FAISS fills in place.
        distances = T.FloatTensor(b, k)
        labels = T.LongTensor(b, k)
        if self.gpu_id != -1:
            distances = distances.cuda(self.gpu_id)
            labels = labels.cuda(self.gpu_id)

        T.cuda.synchronize()
        self.index.search_c(
            b,
            cast_float(ptr(query)),
            k,
            cast_float(ptr(distances)),
            cast_long(ptr(labels))
        )
        T.cuda.synchronize()
        return (distances, (labels - 1))
def test_indexes():
    """Smoke-test ``FAISSIndex`` on GPU 0: add vectors, then search them back.

    Each round adds three batches without ids and three batches with explicit
    ids, then checks that a query matching the id-tagged vectors returns
    ``K`` results containing ids 37, 19 and 17.
    """
    # Three rows per batch, matching the three ids supplied with each
    # id-tagged add below; FAISSIndex.add asserts the two lengths are equal.
    n = 3
    cell_size = 20
    nr_cells = 100
    K = 10
    probes = 32
    d = T.ones(n, cell_size)
    q = T.ones(1, cell_size)

    for gpu_id in (0, 0):
        i = FAISSIndex(cell_size=cell_size, nr_cells=nr_cells, K=K, probes=probes, gpu_id=gpu_id)
        d = d if gpu_id == -1 else d.cuda(gpu_id)

        for x in range(10):
            # Start every round from an empty index; otherwise earlier rounds
            # leave duplicate, equally-scored vectors behind and the top-K
            # may no longer contain all three expected ids.
            i.reset()

            print("adding")
            i.add(d)
            print("adding")
            i.add(d * 2)
            print("adding")
            i.add(d * 3)
            print("adding")

            # Exercise search before any explicit ids exist.
            dist, labels = i.search(q * 7)

            i.add(d * 7, (T.Tensor([1, 2, 3]) * 37).long().cuda(gpu_id))
            print("adding")
            i.add(d * 7, (T.Tensor([1, 2, 3]) * 19).long().cuda(gpu_id))
            print("adding")
            i.add(d * 7, (T.Tensor([1, 2, 3]) * 17).long().cuda(gpu_id))

            dist, labels = i.search(q * 7)

            assert dist.size() == T.Size([1, K])
            assert labels.size() == T.Size([1, K])
            found = set(labels[0].cpu().tolist())
            assert 37 in found
            assert 19 in found
            assert 17 in found
# Script entry point: run the GPU smoke test when executed directly.
if __name__ == '__main__':
    test_indexes()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment