Skip to content

Instantly share code, notes, and snippets.

@matsui528
Created December 3, 2018 16:25
Show Gist options
  • Save matsui528/5683d0c9b2dc55c38b97e95879b94821 to your computer and use it in GitHub Desktop.
Save matsui528/5683d0c9b2dc55c38b97e95879b94821 to your computer and use it in GitHub Desktop.
Hyper-parameter tuning for faiss using optuna
# Test for faiss with optuna using siftsmall data
#
# (1) install libs:
# $ pip install optuna
# $ conda install faiss-cpu -c pytorch
#
# (2) Put the following utility script (util.py) in the same directory as this script:
# https://github.com/matsui528/rii/blob/master/examples/benchmark/util.py
#
# (3) download siftsmall data
# $ wget ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz -P data
# $ tar -zxvf data/siftsmall.tar.gz -C data
#
# (4) run the script
# $ python run_optuna.py
import optuna
import faiss
import numpy as np
import time
import util
# Read the siftsmall data: train (learn), base and query vectors, plus the
# groundtruth neighbor ids for each query.
Xt, Xb, Xq = (
    util.fvecs_read(path)
    for path in (
        "./data/siftsmall/siftsmall_learn.fvecs",
        "./data/siftsmall/siftsmall_base.fvecs",
        "./data/siftsmall/siftsmall_query.fvecs",
    )
)
gt = util.ivecs_read("./data/siftsmall/siftsmall_groundtruth.ivecs")
# Vector dimensionality, taken from the training set.
D = Xt.shape[1]
def run_search(index, Xq, gt, r):
    """Search ``index`` with the queries and report speed and accuracy.

    Args:
        index (faiss index): Faiss index to search; expected to be trained
            and populated already.
        Xq (np.ndarray): Query vectors. shape=(Nq, D). dtype=np.float32.
        gt (np.ndarray): Groundtruth. shape=(Nq, ANY). dtype=np.int32.
        r (int): Top R, i.e. the number of neighbors retrieved per query.

    Returns:
        (float, float): Duration [sec/query] and recall@r over the queries
        (recall is computed by ``util.recall_at_r``).

    Raises:
        ValueError: If ``Xq`` is not a 2-D float32 array with at least one
            row, or if ``r`` is smaller than 1.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O``, and an empty query set would divide by zero below.
    if Xq.ndim != 2:
        raise ValueError("Xq must be 2-D, got ndim={}".format(Xq.ndim))
    if Xq.dtype != np.float32:
        raise ValueError("Xq must be float32, got {}".format(Xq.dtype))
    Nq = Xq.shape[0]
    if Nq == 0:
        raise ValueError("Xq must contain at least one query")
    if r < 1:
        raise ValueError("r must be >= 1, got {}".format(r))

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump when the system clock is adjusted.
    t0 = time.perf_counter()
    _, I = index.search(x=Xq, k=r)
    duration = (time.perf_counter() - t0) / Nq  # sec/query

    recall = util.recall_at_r(I, gt, r)
    return duration, recall
def objective(trial):
    """Build one IVFPQ index from the trial's parameters and score it.

    Uses the module-level data ``Xt`` (training vectors), ``Xb`` (base
    vectors), ``Xq`` (queries), ``gt`` (groundtruth) and the
    dimensionality ``D``.

    Args:
        trial (optuna.trial.Trial): Trial proposing the hyper-parameters.

    Returns:
        float: Negated recall@1, so that minimizing the objective maximizes
        recall. The per-query search time is stored on the trial as the
        user attribute ``'duration'`` (seconds/query).
    """
    # Parameters to be optimized.
    # M: number of PQ sub-quantizers (faiss requires D to be divisible by M).
    # Suggested as strings and converted to int.
    M = int(trial.suggest_categorical('M', ['4', '8', '16']))
    # nlist: number of inverted lists (coarse centroids) of the IVF index.
    nlist = trial.suggest_int('nlist', 10, 1000)
    # hnsw_m: graph connectivity parameter of the HNSW coarse quantizer.
    hnsw_m = trial.suggest_int('hnsw_m', 8, 64)

    # IVFPQ with an HNSW index as the coarse quantizer; 8 bits per PQ code.
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, 8)

    # Train on the learn set, then add the base vectors to be searched.
    index.train(Xt)
    index.add(Xb)

    # Search for the top-1 neighbor of each query.
    duration, recall = run_search(index, Xq, gt, 1)
    # Keep the measured speed with the trial instead of discarding it, so it
    # can be inspected alongside recall via trial.user_attrs.
    trial.set_user_attr('duration', duration)
    return -recall  # flip recall, then min is better
# Run Optuna. The default study direction minimizes, and `objective`
# returns the negated recall, so the search maximizes recall.
study = optuna.create_study()
study.optimize(objective, n_trials=100)

# Report the best trial, undoing the sign flip so recall is positive.
trial = study.best_trial
best_recall = -trial.value
print("best recall:", best_recall)
print("params:")
for name, value in trial.params.items():
    print('{}: {}'.format(name, value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment