Last active
November 30, 2022 23:05
-
-
Save skojaku/292e433a176f594fa428cc386d758d16 to your computer and use it in GitHub Desktop.
Create faiss index
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import faiss | |
import numpy as np | |
def make_faiss_index(
    X, metric, gpu_id=None, exact=True, nprobe=10, min_cluster_size=10000
):
    """Create a faiss index populated with the rows of ``X``.

    :param X: data to index, one sample per row. Converted to a C-contiguous
        float32 copy, so the caller's array is never modified.
    :type X: numpy.ndarray (2-D) or array-like
    :param metric: metric used to compare vectors, ``"euclidean"`` or
        ``"cosine"``. For ``"cosine"`` the indexed vectors are L2-normalised so
        that inner-product search ranks neighbours by cosine similarity; if
        cosine *scores* (not just rankings) are needed, normalise the query
        vectors too before searching.
    :type metric: str
    :param gpu_id: ID of the GPU device to put the index on. ``None`` (the
        default) or ``"cpu"`` keeps the index on the CPU.
    :type gpu_id: int, "cpu" or None
    :param exact: ``True`` builds a brute-force (flat) index that finds the
        true nearest neighbours; ``False`` builds an approximate IVF index.
        Forced to ``True`` when there are fewer than 1000 samples.
    :type exact: bool
    :param nprobe: number of IVF cells visited per query. Only relevant when
        an approximate index is built. Defaults to 10.
    :type nprobe: int
    :param min_cluster_size: target number of samples per IVF cell; the number
        of cells is ``max(int(n_samples / min_cluster_size), 2)``. Also caps the
        training subset at ``5 * min_cluster_size`` rows. Only relevant when an
        approximate index is built.
    :type min_cluster_size: int
    :raises NotImplementedError: if ``metric`` is neither ``"euclidean"`` nor
        ``"cosine"``.
    :return: trained faiss index containing every row of ``X``
    :rtype: faiss.Index
    """
    if metric == "euclidean":
        flat_index_cls, faiss_metric = faiss.IndexFlatL2, faiss.METRIC_L2
    elif metric == "cosine":
        flat_index_cls, faiss_metric = (
            faiss.IndexFlatIP,
            faiss.METRIC_INNER_PRODUCT,
        )
    else:
        raise NotImplementedError(
            f"metric '{metric}' is not implemented; use 'euclidean' or 'cosine'"
        )

    # np.array copies by default: faiss needs a C-contiguous float32 buffer, and
    # owning the copy lets us normalise in place without touching caller data.
    X = np.array(X, dtype=np.float32, order="C")
    n_samples, n_features = X.shape[0], X.shape[1]
    if metric == "cosine":
        # The inner product of unit vectors is their cosine similarity.
        faiss.normalize_L2(X)

    # Approximate search is not worth the training cost on small data sets.
    if n_samples < 1000:
        exact = True

    index = flat_index_cls(n_features)
    if not exact:
        nlist = max(int(n_samples / min_cluster_size), 2)
        # The flat index becomes the coarse quantizer that assigns vectors
        # to one of `nlist` inverted-list cells.
        index = faiss.IndexIVFFlat(index, n_features, nlist, faiss_metric)

    # None and "cpu" both mean "stay on the CPU".
    if gpu_id is not None and gpu_id != "cpu":
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, gpu_id, index)

    if not index.is_trained:
        # Train the quantizer on a random subset of rows (drawn from NumPy's
        # global RNG, so results vary between runs unless it is seeded).
        n_train = int(min(n_samples, min_cluster_size * 5))
        rows = np.random.choice(n_samples, n_train, replace=False)
        index.train(np.ascontiguousarray(X[rows]))
    index.add(X)

    if not exact:
        # nprobe is an IVF search parameter; flat indexes have no such notion.
        index.nprobe = nprobe
    return index
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment