-
-
Save mdouze/39fad49ae4a1d7888d10d21115f4ce73 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import print_function | |
import numpy as np | |
import argparse | |
import os | |
import time | |
import faiss | |
import datasets | |
def main():
    """Benchmark a Faiss index used as a standalone vector codec.

    Command-line driven: parses the options declared below, loads the
    train / base / query / ground-truth arrays through ``datasets.load_data``,
    reads a codec from ``--codecfile`` if that file exists (otherwise builds
    one with ``faiss.index_factory`` and trains it), encodes the base vectors
    with ``sa_encode`` and then runs whichever evaluations were requested:

    * ``--eval_reconstruction``: mean L2 distance between the base vectors
      and their decoded codes.
    * ``--eval_search``: search in a flat L2 index built from the decoded
      base vectors, with raw queries and with encoded+decoded queries.
    * ``--eval_search_hamming``: search with the raw codes in a binary flat
      index.

    Raises:
        ValueError: if the selected training vectors contain NaN or inf.
    """
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        # `group` is looked up at call time, so each aa() call registers the
        # option in whichever argument group was created most recently.
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('dataset options')
    aa('--db', default='deep1M', help='dataset')
    aa('--compute_gt', default=False, action='store_true',
       help='compute and store the groundtruth')
    aa('--override_simdir', default='', help='override the input directory')

    group = parser.add_argument_group('codec training')
    aa('--factorykey', default='Flat', help='index_factory type')
    aa('--maxtrain', default=256 * 256, type=int,
       help='maximum number of training points (0 to set automatically)')
    aa('--codecfile', default='', help='file to read or write codec from')

    group = parser.add_argument_group('evalutation')
    aa('--eval_reconstruction', default=False, action='store_true',
       help='evaluate reconstruction error')
    aa('--eval_search', default=False, action='store_true',
       help='evaluate search accuracy')
    aa('--eval_search_hamming', default=False, action='store_true',
       help='evaluate search accuracy with Hamming distances on codes')

    args = parser.parse_args()
    print("args:", args)

    # Informational only; reads /proc/cpuinfo through the shell.
    os.system('echo -n "nb processors "; '
              'cat /proc/cpuinfo | grep ^processor | wc -l; '
              'cat /proc/cpuinfo | grep ^"model name" | tail -1')

    ######################################################
    # Load dataset
    ######################################################

    if args.override_simdir:
        print('overriding simdir to ', args.override_simdir)
        datasets.simdir = args.override_simdir

    xt, xb, xq, gt = datasets.load_data(
        dataset=args.db, compute_gt=args.compute_gt)

    print("dataset sizes: train %s base %s query %s GT %s" % (
        xt.shape, xb.shape, xq.shape, gt.shape))

    _, d = xb.shape

    ######################################################
    # Train
    ######################################################

    if args.codecfile and os.path.exists(args.codecfile):
        print("reading", args.codecfile)
        index = faiss.read_index(args.codecfile)
    else:
        print("build codec, key=", args.factorykey)
        index = faiss.index_factory(d, args.factorykey)

        try:
            index_ivf = faiss.extract_index_ivf(index)
        except RuntimeError:
            # extract_index_ivf failed: treat the index as having no IVF part
            index_ivf = None

        print('code size %d' % index.sa_code_size())

        if index_ivf is not None:
            index_ivf.replace_invlists(None)

        maxtrain = args.maxtrain
        if maxtrain == 0:
            # Automatic setting: scale with the number of inverted lists when
            # there is an IVF part, otherwise use the whole training set
            # (leaving it at 0 would train on an empty array).
            if index_ivf is not None and 'IMI' in args.factorykey:
                maxtrain = int(256 * 2 ** (np.log2(index_ivf.nlist) / 2))
            elif index_ivf is not None and 'IVF' in args.factorykey:
                maxtrain = 50 * index_ivf.nlist
            else:
                maxtrain = xt.shape[0]
            print("setting maxtrain to %d" % maxtrain)
            args.maxtrain = maxtrain

        print('load training data')
        xt = np.ascontiguousarray(xt[:maxtrain], dtype='float32')
        if not np.all(np.isfinite(xt)):
            raise ValueError('training vectors contain NaN or inf values')

        print('train...')
        t0 = time.time()
        index.train(xt)
        print(" train in %.3f s" % (time.time() - t0))

        if args.codecfile:
            # --codecfile is documented as "read or write": store the freshly
            # trained codec so that the next run takes the read branch above.
            print("storing", args.codecfile)
            faiss.write_index(index, args.codecfile)

    # Serialize to memory only to report the size of the codec in bytes.
    writer = faiss.VectorIOWriter()
    faiss.write_index(index, writer)
    ar_data = faiss.vector_to_array(writer.data)
    print('size of codec %d' % ar_data.size)

    ######################################################
    # Evaluate
    ######################################################

    # encode and decode xb
    xb = np.ascontiguousarray(xb, dtype='float32')
    print('encoding xb')
    t0 = time.time()
    code = index.sa_encode(xb)
    print('encode time %.3f' % (time.time() - t0))

    use_gpu = faiss.get_num_gpus() > 0

    if args.eval_search or args.eval_search_hamming:
        xq = np.ascontiguousarray(xq, dtype='float32')

    if args.eval_reconstruction or args.eval_search:
        print('decoding xb')
        t1 = time.time()
        xb_decoded = index.sa_decode(code)
        print('decode time %.3f' % (time.time() - t1))

    if args.eval_reconstruction:
        err = np.linalg.norm(xb - xb_decoded, axis=1).mean()
        mag = np.linalg.norm(xb, axis=1).mean()
        print('avg L2 norm of vectors (not squared): %g' % mag)
        print('avg reconstruction error: %g' % err)

    if args.eval_search:
        indexL2 = faiss.IndexFlatL2(d)
        indexL2.add(xb_decoded)
        if use_gpu:
            print('moving to GPU')
            indexL2 = faiss.index_cpu_to_all_gpus(indexL2)

        # asymmetric: raw queries against decoded database vectors
        print('asymmetric search')
        datasets.evaluate(xq, gt, indexL2)

        # symmetric: queries go through the codec as well
        print('symmetric search')
        qcode = index.sa_encode(xq)
        xq_decoded = index.sa_decode(qcode)
        datasets.evaluate(xq_decoded, gt, indexL2)

    if args.eval_search_hamming:
        print('hamming search')
        # codes are byte arrays; the binary index dimension is given in bits
        indexBin = faiss.IndexBinaryFlat(code.shape[1] * 8)
        if use_gpu:
            # hold the resources in a local for as long as the index is used
            gpu_resources = faiss.StandardGpuResources()
            indexBin = faiss.GpuIndexBinaryFlat(gpu_resources, indexBin)
        indexBin.add(code)
        qcode = index.sa_encode(xq)
        datasets.evaluate(qcode, gt, indexBin)
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment