Skip to content

Instantly share code, notes, and snippets.

@tteofili
Created August 3, 2022 15:43
Show Gist options
  • Save tteofili/1ce98830854146a07796b93b790b8fca to your computer and use it in GitHub Desktop.
Save tteofili/1ce98830854146a07796b93b790b8fca to your computer and use it in GitHub Desktop.
/*
* Anserini: A Lucene toolkit for reproducible information retrieval research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.anserini.ann;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.OptionHandlerFilter;
import org.kohsuke.args4j.ParserProperties;
import static io.anserini.ann.IndexVectorsHNSW.FIELD_VECTOR;
/**
 * Command-line tool that runs every query vector read from an input file against an
 * HNSW vector index and writes the ranked hits to a tab-separated run file.
 *
 * <p>For each query vector a {@link KnnVectorQuery} on {@code FIELD_VECTOR} is executed and,
 * for every hit, one line {@code <query key> TAB <FIELD_ID value> TAB <0-based rank>} is
 * appended to {@code msmarco.hnsw.tsv} in the current working directory. The average search
 * time per executed query is printed at the end.
 */
public class ApproximateNearestNeighborEvalHNSW {

  /** Name of the run file written to the current working directory. */
  private static final String OUTPUT_FILE = "msmarco.hnsw.tsv";

  /** Command-line options, populated by args4j. */
  public static final class Args {
    /** File from which the query vectors are read via {@code IndexVectorsHNSW.readVectors}. */
    @Option(name = "-input", metaVar = "[file]", required = true, usage = "vectors model")
    public File input;

    /** Location of the index directory to search. */
    @Option(name = "-path", metaVar = "[path]", required = true, usage = "index path")
    public Path path;

    /**
     * Path to a topics file.
     * NOTE(review): this option is required but is never read by {@code main} in this class.
     */
    @Option(name = "-topics", metaVar = "[file]", required = true, usage = "path to TREC topics file")
    public Path topicsPath;

    /** Passed as {@code k} to each {@link KnnVectorQuery}. */
    @Option(name = "-topN", metaVar = "[int]", usage = "topN recall")
    public int topN = 10;

    /** Number of hits requested from the searcher for each query. */
    @Option(name = "-depth", metaVar = "[int]", usage = "retrieval depth")
    public int depth = 10;

    /** Maximum number of queries to execute; defaults to no practical limit. */
    @Option(name = "-samples", metaVar = "[int]", usage = "no. of samples")
    public int samples = Integer.MAX_VALUE;
  }

  /**
   * Parses the command line, runs the query vectors against the index and writes the run file.
   *
   * @param args command-line arguments, see {@link Args}
   * @throws Exception if the vectors cannot be read, or the index or output file cannot be
   *     opened or closed; failures of individual queries are reported on stderr and skipped
   */
  public static void main(String[] args) throws Exception {
    Args indexArgs = new Args();
    CmdLineParser parser = new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90));

    try {
      parser.parseArgument(args);
    } catch (CmdLineException e) {
      System.err.println(e.getMessage());
      parser.printUsage(System.err);
      System.err.println("Example: " + ApproximateNearestNeighborEvalHNSW.class.getSimpleName() +
          parser.printExample(OptionHandlerFilter.REQUIRED));
      return;
    }

    System.out.println(String.format("Loading model %s", indexArgs.input));
    Map<String, List<float[]>> queryVectors = IndexVectorsHNSW.readVectors(indexArgs.input);

    Path indexDir = indexArgs.path;
    // NOTE(review): creating a missing index directory only yields an empty directory;
    // opening a reader on it will presumably still fail. Kept to preserve existing behavior.
    if (!Files.exists(indexDir)) {
      Files.createDirectories(indexDir);
    }

    System.out.println(String.format("Reading index at %s", indexArgs.path));

    double totalMillis = 0d;
    int queryCount = 0;

    // try-with-resources guarantees the writer, reader and directory are closed
    // (in that order) even when a query or the output file throws.
    try (Directory d = FSDirectory.open(indexDir);
         DirectoryReader reader = DirectoryReader.open(d);
         FileWriter writer = new FileWriter(OUTPUT_FILE)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      System.out.println("Evaluating at retrieval depth: " + indexArgs.depth);

      queries:
      for (Map.Entry<String, List<float[]>> queryVectorEntry : queryVectors.entrySet()) {
        try {
          for (float[] queryVector : queryVectorEntry.getValue()) {
            // Enforce the sample cap per query, not per map entry, so an entry holding
            // several vectors cannot overshoot the requested number of samples.
            if (queryCount >= indexArgs.samples) {
              break queries;
            }

            // NOTE(review): the query is built with k = topN while the searcher is asked for
            // `depth` hits; confirm which of the two is meant to bound the result size.
            KnnVectorQuery simQuery = new KnnVectorQuery(FIELD_VECTOR, queryVector, indexArgs.topN);

            // nanoTime is monotonic, unlike currentTimeMillis, so it is safe for elapsed time.
            long start = System.nanoTime();
            TopDocs topDocs = searcher.search(simQuery, indexArgs.depth);
            totalMillis += (System.nanoTime() - start) / 1_000_000d;

            int rank = 0;
            for (ScoreDoc sd : topDocs.scoreDocs) {
              Document document = reader.document(sd.doc);
              String wordValue = document.get(IndexVectorsHNSW.FIELD_ID);
              writer.append(queryVectorEntry.getKey()).append("\t").append(wordValue).append("\t")
                  .append(String.valueOf(rank)).append("\n");
              rank++;
            }
            // Flush after every query so results already computed survive a later failure.
            writer.flush();
            queryCount++;
          }
        } catch (IOException e) {
          // Best effort: report the failing key and continue with the next entry.
          System.err.println("search for '" + queryVectorEntry.getKey() + "' failed " + e.getLocalizedMessage());
        }
      }
      writer.flush();
    }

    if (queryCount > 0) {
      System.out.println(String.format("avg query time: %s ms", totalMillis / queryCount));
    } else {
      // Avoid printing NaN from a 0/0 division when nothing was executed.
      System.out.println("avg query time: n/a (no queries executed)");
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment