Skip to content

Instantly share code, notes, and snippets.

@mdouze
Created December 7, 2021 09:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mdouze/39fad49ae4a1d7888d10d21115f4ce73 to your computer and use it in GitHub Desktop.
Save mdouze/39fad49ae4a1d7888d10d21115f4ce73 to your computer and use it in GitHub Desktop.
#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import print_function
import numpy as np
import argparse
import os
import time
import faiss
import datasets
def _parse_args():
    """Define the command-line options and return the parsed namespace."""
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        # `group` is looked up when aa() is called, so each option is added
        # to whichever argument group was created most recently below.
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('dataset options')
    aa('--db', default='deep1M', help='dataset')
    aa('--compute_gt', default=False, action='store_true',
       help='compute and store the groundtruth')
    aa('--override_simdir', default='', help='override the input directory')

    group = parser.add_argument_group('codec training')
    aa('--factorykey', default='Flat', help='index_factory type')
    aa('--maxtrain', default=256 * 256, type=int,
       help='maximum number of training points (0 to set automatically)')
    aa('--codecfile', default='', help='file to read or write codec from')

    group = parser.add_argument_group('evaluation')
    aa('--eval_reconstruction', default=False, action='store_true',
       help='evaluate reconstruction error')
    aa('--eval_search', default=False, action='store_true',
       help='evaluate search accuracy')
    aa('--eval_search_hamming', default=False, action='store_true',
       help='evaluate search accuracy with Hamming distances on codes')

    return parser.parse_args()


def _auto_maxtrain(factorykey, index_ivf, ntrain):
    """Choose the number of training vectors when --maxtrain is 0.

    Args:
        factorykey: the index_factory string of the codec.
        index_ivf: the IVF index extracted from the codec, or None if the
            codec has no IVF component.
        ntrain: number of available training vectors.

    Returns:
        256 * sqrt(nlist) for IMI keys, 50 * nlist for other IVF keys, and
        ``ntrain`` (use all training data) for any other codec.
    """
    if index_ivf is not None:
        if 'IMI' in factorykey:
            # 256 * 2 ** (log2(nlist) / 2) == 256 * sqrt(nlist)
            return int(256 * 2 ** (np.log2(index_ivf.nlist) / 2))
        if 'IVF' in factorykey:
            return 50 * index_ivf.nlist
    # Not an IVF/IMI codec: there is no nlist to scale by, so train on
    # everything rather than on an empty slice.
    return ntrain


def _train_codec(args, xt, d):
    """Build the codec described by ``args.factorykey`` and train it.

    Training uses the first ``args.maxtrain`` rows of ``xt``. When
    ``args.maxtrain`` is 0, a size is chosen by ``_auto_maxtrain`` and
    written back to ``args.maxtrain``.

    Args:
        args: parsed command-line namespace.
        xt: training vectors (converted to contiguous float32 here).
        d: vector dimension.

    Returns:
        The trained Faiss index.

    Raises:
        ValueError: if the training slice contains NaN or infinite values.
    """
    print("build codec, key=", args.factorykey)
    index = faiss.index_factory(d, args.factorykey)
    try:
        index_ivf = faiss.extract_index_ivf(index)
    except RuntimeError:
        # The codec has no IVF component.
        index_ivf = None
    print('code size %d' % index.sa_code_size())
    if index_ivf is not None:
        # Drop the inverted lists, presumably so the serialized codec holds
        # only trained parameters; the database is never added to this index.
        index_ivf.replace_invlists(None)

    maxtrain = args.maxtrain
    if maxtrain == 0:
        maxtrain = _auto_maxtrain(args.factorykey, index_ivf, xt.shape[0])
        print("setting maxtrain to %d" % maxtrain)
        args.maxtrain = maxtrain

    print('load training data')
    xt = np.ascontiguousarray(xt[:maxtrain], dtype='float32')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not np.all(np.isfinite(xt)):
        raise ValueError("training data contains NaN or infinite values")

    print('train...')
    t0 = time.time()
    index.train(xt)
    print(" train in %.3f s" % (time.time() - t0))
    return index


def _evaluate(args, index, xb, xq, gt, d):
    """Encode the database with the codec and run the selected evaluations.

    Args:
        args: parsed command-line namespace (reads the --eval_* flags).
        index: trained codec exposing sa_encode / sa_decode.
        xb: database vectors.
        xq: query vectors.
        gt: ground truth passed through to ``datasets.evaluate``.
        d: vector dimension.
    """
    xb = np.ascontiguousarray(xb, dtype='float32')
    xq = np.ascontiguousarray(xq, dtype='float32')

    print('encoding xb')
    t0 = time.time()
    code = index.sa_encode(xb)
    t1 = time.time()
    print('encode time %.3f' % (t1 - t0))

    use_gpu = faiss.get_num_gpus() > 0

    # xb_decoded is only needed by the reconstruction and L2-search checks.
    if args.eval_reconstruction or args.eval_search:
        print('decoding xb')
        xb_decoded = index.sa_decode(code)
        t2 = time.time()
        print('decode time %.3f' % (t2 - t1))

    if args.eval_reconstruction:
        err = np.linalg.norm(xb - xb_decoded, axis=1).mean()
        mag = np.linalg.norm(xb, axis=1).mean()
        print('avg L2 norm of vectors (not squared): %g' % mag)
        print('avg reconstruction error: %g' % err)

    if args.eval_search:
        # Exact L2 search over the decoded database vectors.
        indexL2 = faiss.IndexFlatL2(d)
        indexL2.add(xb_decoded)
        if use_gpu:
            print('moving to GPU')
            indexL2 = faiss.index_cpu_to_all_gpus(indexL2)
        print('asymmetric search')
        # Raw queries against decoded database vectors.
        datasets.evaluate(xq, gt, indexL2)
        print('symmetric search')
        # Queries are passed through the codec as well.
        qcode = index.sa_encode(xq)
        xq_decoded = index.sa_decode(qcode)
        datasets.evaluate(xq_decoded, gt, indexL2)

    if args.eval_search_hamming:
        print('hamming search')
        # Treat each code as a bit string: code.shape[1] bytes -> bits.
        indexBin = faiss.IndexBinaryFlat(code.shape[1] * 8)
        if use_gpu:
            # Keep the GPU resources in a named local so they stay alive for
            # the whole search, rather than relying on the wrapper to hold a
            # reference to a temporary.
            gpu_res = faiss.StandardGpuResources()
            indexBin = faiss.GpuIndexBinaryFlat(gpu_res, indexBin)
        indexBin.add(code)
        qcode = index.sa_encode(xq)
        datasets.evaluate(qcode, gt, indexBin)


def main():
    """Train (or load) a Faiss codec and benchmark it on a dataset.

    Steps:
    1. Parse the command-line options.
    2. Load the dataset via ``datasets.load_data``.
    3. Build and train the codec, or read it from ``--codecfile``.
    4. Report the serialized size of the codec.
    5. Run the evaluations selected on the command line.
    """
    args = _parse_args()
    print("args:", args)

    # Log the CPU count and model (Linux-only: reads /proc/cpuinfo).
    os.system('echo -n "nb processors "; '
              'cat /proc/cpuinfo | grep ^processor | wc -l; '
              'cat /proc/cpuinfo | grep ^"model name" | tail -1')

    ######################################################
    # Load dataset
    ######################################################

    if args.override_simdir:
        print('overriding simdir to ', args.override_simdir)
        datasets.simdir = args.override_simdir

    xt, xb, xq, gt = datasets.load_data(
        dataset=args.db, compute_gt=args.compute_gt)

    print("dataset sizes: train %s base %s query %s GT %s" % (
        xt.shape, xb.shape, xq.shape, gt.shape))

    d = xb.shape[1]

    ######################################################
    # Train
    ######################################################

    if args.codecfile and os.path.exists(args.codecfile):
        print("reading", args.codecfile)
        index = faiss.read_index(args.codecfile)
        print('code size %d' % index.sa_code_size())
    else:
        index = _train_codec(args, xt, d)
        if args.codecfile:
            # --codecfile is documented as "read or write": store the freshly
            # trained codec so later runs can reuse it.
            print("storing", args.codecfile)
            faiss.write_index(index, args.codecfile)

    # Serialize the codec in memory and report the serialized buffer length.
    writer = faiss.VectorIOWriter()
    faiss.write_index(index, writer)
    ar_data = faiss.vector_to_array(writer.data)
    print('size of codec %d' % ar_data.size)

    ######################################################
    # Evaluate
    ######################################################

    _evaluate(args, index, xb, xq, gt, d)
if __name__ == "__main__":
    # Script entry point: run the benchmark only when executed directly,
    # not when this module is imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment