Skip to content

Instantly share code, notes, and snippets.

@skojaku
Last active November 30, 2022 23:05
Show Gist options
  • Save skojaku/292e433a176f594fa428cc386d758d16 to your computer and use it in GitHub Desktop.
Create faiss index
import faiss
import numpy as np
def make_faiss_index(
    X, metric, gpu_id=None, exact=True, nprobe=10, min_cluster_size=10000
):
    """Create a faiss index over the rows of ``X``.

    :param X: data to index, one sample per row.
    :type X: array-like of shape (n_samples, n_features)
    :param metric: similarity metric, ``"euclidean"`` or ``"cosine"``.
        For ``"cosine"`` the indexed rows are L2-normalized and searched by
        inner product, so neighbors are ranked by cosine similarity.
        Normalize the query vectors too if actual cosine *scores* (not only
        the ranking) are needed.
    :type metric: str
    :param gpu_id: ID of the GPU to place the index on. ``None`` (default)
        or ``"cpu"`` keeps the index on the CPU.
    :type gpu_id: int, str or None
    :param exact: True to find the true nearest neighbors (flat index);
        False to find approximate nearest neighbors (IVF index). Forced to
        True when ``X`` has fewer than 1000 rows.
    :type exact: bool
    :param nprobe: number of IVF cells searched per query. Relevant only
        when exact is False. Defaults to 10.
    :type nprobe: int
    :param min_cluster_size: target number of rows per IVF cell; the index
        uses ``max(int(n_samples / min_cluster_size), 2)`` cells. Only
        relevant when exact is False.
    :type min_cluster_size: int
    :raises NotImplementedError: if ``metric`` is not "euclidean" or "cosine".
    :raises ValueError: if ``X`` is not 2-D, or if an approximate index is
        requested with ``min_cluster_size`` < 1.
    :return: faiss index containing every row of ``X``.
    :rtype: faiss.Index
    """
    if metric not in ("euclidean", "cosine"):
        raise NotImplementedError(
            f"metric must be 'euclidean' or 'cosine', got {metric!r}"
        )

    # np.array always copies, so normalizing in place below never touches the
    # caller's data; it also yields the C-contiguous float32 layout faiss wants.
    X = np.array(X, dtype=np.float32, order="C")
    if X.ndim != 2:
        raise ValueError(f"X must be a 2-D array, got {X.ndim} dimension(s)")
    n_samples, n_features = X.shape

    if metric == "cosine":
        # Inner product on unit-length rows ranks neighbors by cosine
        # similarity. All-zero rows stay zero instead of dividing by zero.
        norms = np.linalg.norm(X, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        X /= norms

    if n_samples < 1000:
        # Too few rows for a useful IVF partition; use exact search instead.
        exact = True
    if not exact and min_cluster_size < 1:
        raise ValueError(
            f"min_cluster_size must be >= 1, got {min_cluster_size!r}"
        )

    index = (
        faiss.IndexFlatL2(n_features)
        if metric == "euclidean"
        else faiss.IndexFlatIP(n_features)
    )

    n_train = 0
    if not exact:
        nlist = max(int(n_samples / min_cluster_size), 2)
        faiss_metric = (
            faiss.METRIC_L2 if metric == "euclidean" else faiss.METRIC_INNER_PRODUCT
        )
        # The flat index built above becomes the IVF coarse quantizer.
        index = faiss.IndexIVFFlat(index, n_features, nlist, faiss_metric)
        # k-means needs at least one training point per cell and faiss warns
        # below 39 points per cell, so never sample fewer than 39 * nlist rows.
        n_train = int(min(n_samples, max(5 * min_cluster_size, 39 * nlist)))

    # Both None (the documented default) and "cpu" keep the index on the CPU.
    if gpu_id is not None and gpu_id != "cpu":
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, int(gpu_id), index)

    if not index.is_trained:
        # Only the IVF index needs training; use a random subset of rows.
        sample = np.random.choice(n_samples, n_train, replace=False)
        index.train(X[sample])

    index.add(X)
    # nprobe only influences search on IVF indexes.
    index.nprobe = nprobe
    return index
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment