Created
December 19, 2017 17:42
-
-
Save ixaxaar/5e4fa3efd70f8c8aabf200f7c34ab195 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import faiss
from faiss import cast_integer_to_float_ptr as cast_float
from faiss import cast_integer_to_int_ptr as cast_int
from faiss import cast_integer_to_long_ptr as cast_long
import numpy as np
import torch as T
def ptr(tensor):
    """Return the raw address of a tensor's first element.

    The integer is meant for faiss's ``cast_integer_to_*_ptr`` helpers, which
    let faiss read/write the tensor's buffer directly.  The caller must keep
    the tensor alive (and contiguous) for as long as faiss uses the pointer.

    Args:
        tensor: a torch tensor, an object exposing a tensor through ``.data``
            (e.g. a legacy ``torch.autograd.Variable``), or an integer address,
            which is returned unchanged.

    Returns:
        int: address of element 0.  ``Tensor.data_ptr()`` includes the storage
        offset, so a view such as ``t[3:]`` resolves to its own first element
        rather than to the start of the shared storage.
    """
    if T.is_tensor(tensor):
        return tensor.data_ptr()
    if hasattr(tensor, 'data'):
        return tensor.data.data_ptr()
    return tensor
def ensure_gpu(tensor, gpu_id):
    """Place ``tensor`` on the requested device.

    Args:
        tensor: a torch tensor, a NumPy array, or any other object.
        gpu_id: CUDA device index, or -1 for the CPU.

    Returns:
        Torch tensors are moved to ``cuda:gpu_id``, or to the CPU when
        ``gpu_id == -1`` (a CPU tensor is then returned as-is).  NumPy arrays
        are first converted with ``torch.from_numpy`` (sharing memory when the
        array is already contiguous) and then placed the same way.  Any other
        object is returned unchanged.
    """
    if isinstance(tensor, np.ndarray):
        # torch.from_numpy rejects negatively-strided arrays, so normalise first.
        tensor = T.from_numpy(np.ascontiguousarray(tensor))
    if not T.is_tensor(tensor):
        return tensor
    if gpu_id != -1:
        return tensor.cuda(gpu_id)
    # Check the device explicitly: on modern torch the Python type of a CUDA
    # tensor is plain ``torch.Tensor``, so type-name sniffing cannot detect it.
    return tensor.cpu() if tensor.is_cuda else tensor
class FAISSIndex(object):
    """Wrapper around a faiss ``GpuIndexIVFFlat`` that exchanges torch tensors.

    Vectors are added with optional integer ids ("positions") and searched by
    inner product (``faiss.METRIC_INNER_PRODUCT``).  Ids are stored shifted by
    +1 inside faiss and shifted back by ``search``, so ``search`` reports the
    positions that were passed to ``add``.
    """

    def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=32, probes=32, res=None, train=None, gpu_id=-1):
        """Build and train the index.

        Args:
            cell_size: dimensionality of every stored and query vector.
            nr_cells: sizes the default random training set
                (``nr_cells * 100`` rows).
            K: default number of neighbours returned by ``search``.
            num_lists: number of inverted lists of the IVF index.
            probes: number of inverted lists visited per query.
            res: optional ``faiss.StandardGpuResources`` to reuse; a new one is
                created when omitted.
            train: optional ``(n, cell_size)`` training vectors; defaults to
                Gaussian noise scaled by 10.
            gpu_id: device for tensors exchanged with faiss, or -1 for the CPU.
        """
        super(FAISSIndex, self).__init__()
        self.cell_size = cell_size
        self.nr_cells = nr_cells
        self.probes = probes
        self.K = K
        self.num_lists = num_lists
        self.gpu_id = gpu_id

        if res is None:
            res = faiss.StandardGpuResources()
        res.setTempMemoryFraction(0.01)
        if self.gpu_id != -1:
            res.initializeForDevice(self.gpu_id)
        # Hold a Python reference: faiss GPU indexes use, but do not own, their
        # resources object, so it must not be garbage-collected before the index.
        self.res = res

        if train is None:
            train = T.randn(self.nr_cells * 100, self.cell_size) * 10

        self.index = faiss.GpuIndexIVFFlat(res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
        self.index.setNumProbes(self.probes)
        self.train(train)

    def cuda(self, gpu_id):
        """Set the device used for tensors exchanged with faiss; return ``self``.

        Only affects where ``add``/``search`` place their inputs and outputs;
        the faiss index itself is not moved.
        """
        self.gpu_id = gpu_id
        return self

    def train(self, train):
        """Train the IVF quantizer on every row of ``train``.

        Args:
            train: ``(n, cell_size)`` tensor or NumPy array of training vectors.
        """
        # Hand faiss a dense float32 host buffer (ensure_gpu with -1 targets the CPU).
        train = ensure_gpu(train, -1).float().contiguous()
        T.cuda.synchronize()
        # The buffer holds exactly train.size(0) vectors; passing any other count
        # either ignores data or reads past the end of the buffer.
        self.index.train_c(train.size(0), cast_float(ptr(train)))
        T.cuda.synchronize()

    def reset(self):
        """Remove all vectors from the index (the trained quantizer is kept by faiss)."""
        T.cuda.synchronize()
        self.index.reset()
        T.cuda.synchronize()

    def add(self, other, positions=None, last=None):
        """Add the rows of ``other`` to the index.

        Args:
            other: ``(n, cell_size)`` tensor of vectors.
            positions: optional length-``n`` integer tensor of ids stored with
                the vectors; ``search`` returns these ids.
            last: when ``positions`` is None, only the first ``last`` rows of
                ``other`` are added.  Ignored when ``positions`` is given.

        Raises:
            AssertionError: if ``positions`` and ``other`` differ in length.
        """
        other = ensure_gpu(other, self.gpu_id)
        if positions is None and last is not None:
            other = other[:last, :]
        # faiss reads the raw buffer as packed row-major float32.
        other = other.float().contiguous()

        T.cuda.synchronize()
        if positions is not None:
            positions = ensure_gpu(positions, self.gpu_id)
            assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
            # Bind the shifted ids to a local so the tensor stays alive while faiss
            # reads through its pointer, and force int64 to match cast_long.
            ids = (positions.long() + 1).contiguous()
            self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(ids)))
        else:
            # NOTE(review): without explicit ids faiss assigns sequential ids
            # starting at 0, and search() subtracts 1 from every label, so these
            # vectors come back one lower than their row order (first -> -1).
            # Confirm callers expect this before relying on it.
            self.index.add_c(other.size(0), cast_float(ptr(other)))
        T.cuda.synchronize()

    def search(self, query, k=None):
        """Find the ``k`` stored vectors with the largest inner product per query.

        Args:
            query: ``(b, cell_size)`` tensor of query vectors.
            k: neighbours per query; ``None`` or 0 falls back to ``self.K``.

        Returns:
            ``(distances, labels)``: a ``(b, k)`` float tensor of scores and a
            ``(b, k)`` int64 tensor of ids (the ``positions`` given to ``add``),
            both on ``self.gpu_id``.  Result slots faiss leaves unfilled (faiss
            uses -1 for those) appear here as -2 because of the id shift.
        """
        query = ensure_gpu(query, self.gpu_id).float().contiguous()
        k = k if k else self.K
        b, _ = query.size()  # also enforces a 2-D query
        distances = T.FloatTensor(b, k)
        labels = T.LongTensor(b, k)
        if self.gpu_id != -1:
            distances = distances.cuda(self.gpu_id)
            labels = labels.cuda(self.gpu_id)

        # Presumably orders torch's pending CUDA work against faiss's (they may
        # run on different streams) — both sides touch the same buffers.
        T.cuda.synchronize()
        self.index.search_c(
            b,
            cast_float(ptr(query)),
            k,
            cast_float(ptr(distances)),
            cast_long(ptr(labels))
        )
        T.cuda.synchronize()
        # Undo the +1 applied to ids in add().
        return (distances, (labels - 1))
def test_indexes():
    """GPU smoke test: ids passed to ``FAISSIndex.add`` come back from ``search``.

    Adds all-ones vectors scaled by 1, 2 and 3 ten times each (without ids).
    It then adds three batches of vectors scaled by 7 with explicit ids.
    A query along the same direction must rank those scale-7 vectors first.
    """
    # Rows per batch.  Each positioned add passes n ids, so this must equal
    # len(base_ids); with n = 1 the size assertion in add() always fired.
    n = 3
    cell_size = 20
    nr_cells = 100
    K = 10
    probes = 32
    d = T.ones(n, cell_size)
    q = T.ones(1, cell_size)
    base_ids = T.arange(1, n + 1).long()  # 1..n, scaled per batch below

    # NOTE(review): both iterations use device 0; the CPU path (-1) handled
    # below is never exercised — confirm whether (-1, 0) was intended.
    for gpu_id in (0, 0):
        i = FAISSIndex(cell_size=cell_size, nr_cells=nr_cells, K=K, probes=probes, gpu_id=gpu_id)
        d = d if gpu_id == -1 else d.cuda(gpu_id)

        for _ in range(10):
            print("adding")
            i.add(d)
            print("adding")
            i.add(d * 2)
            print("adding")
            i.add(d * 3)
        print("adding")
        dist, labels = i.search(q * 7)

        i.add(d * 7, base_ids * 37)
        print("adding")
        i.add(d * 7, base_ids * 19)
        print("adding")
        i.add(d * 7, base_ids * 17)

        dist, labels = i.search(q * 7)
        assert dist.size() == T.Size([1, K])
        assert labels.size() == T.Size([1, K])
        found = list(labels[0].cpu().numpy())
        assert 37 in found
        assert 19 in found
        assert 17 in found
if __name__ == '__main__':
    # Executing the module directly runs the GPU smoke test.
    test_indexes()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment