Skip to content

Instantly share code, notes, and snippets.

@kaivalnp
Created October 18, 2023 20:42
Show Gist options
  • Save kaivalnp/79808017ed7666214540213d1e2a21cf to your computer and use it in GitHub Desktop.
Save kaivalnp/79808017ed7666214540213d1e2a21cf to your computer and use it in GitHub Desktop.
Threshold-based search benchmark
// Brute-force baseline: for every (threshold, query) pair, count the doc vectors whose
// DOT_PRODUCT score against the query is at or above the converted threshold.
// All thresholds are evaluated in the same pass over the docs to save time.
// Returns a [thresholds.length][queries.length] matrix of counts.
private static int[][] trueRnnCounts(float[][] queries, float[][] docs, float[] thresholds) {
  // Corresponding conversion for DOT_PRODUCT, applied once per threshold up front
  float[] minScores = new float[thresholds.length];
  for (int t = 0; t < thresholds.length; t++) {
    minScores[t] = (1 + thresholds[t]) / 2;
  }

  int[][] counts = new int[thresholds.length][queries.length];
  // Queries are processed in parallel; each one only ever writes to its own column of `counts`
  IntStream.range(0, queries.length).parallel().forEach(queryOrd -> {
    float[] query = queries[queryOrd];
    for (float[] doc : docs) {
      float score = DOT_PRODUCT.compare(query, doc);
      for (int t = 0; t < minScores.length; t++) {
        if (score >= minScores[t]) {
          counts[t][queryOrd]++;
        }
      }
    }
  });
  return counts;
}
// Outcome of a single benchmark run, as produced by the search methods in this file.
private record RunResult(
    // "Count of vectors above the score threshold", one entry per query
    int[] results,
    // Elapsed System.nanoTime() delta covering the whole batch of queries
    long nanos,
    // Visited-node total divided by the number of queries (integer division)
    long numVisited) {
}
// Run a topK-based search for every query against the index at `indexPath`, then post-filter
// the merged hits so that only those scoring at or above the converted `threshold` are counted.
//
// Returns a RunResult holding the per-query counts, the elapsed nanoseconds for the whole
// batch, and the visited-node total averaged over the queries (0 when there are no queries).
// Any IOException from opening or searching the index is rethrown as UncheckedIOException.
private static RunResult knnCounts(float[][] queries, Path indexPath, String knnField, int topK, float threshold) {
  try (Directory directory = FSDirectory.open(indexPath);
       DirectoryReader reader = DirectoryReader.open(directory)) {
    IndexSearcher searcher = new IndexSearcher(reader);
    // Corresponding conversion for DOT_PRODUCT
    float similarity = (1 + threshold) / 2;

    long startTime = System.nanoTime();
    LongAdder numVisited = new LongAdder();
    int[] results = Arrays.stream(queries).mapToInt(query -> {
      try {
        KnnFloatVectorQuery vectorQuery = new KnnFloatVectorQuery(knnField, query, topK) {
          @Override
          protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
            TopDocs merged = super.mergeLeafResults(perLeafResults);
            // Post-filter results below the threshold.
            // The probe (doc = Integer.MAX_VALUE, score = similarity) compares after every hit
            // whose score is >= similarity, so its insertion point is the number of hits to keep.
            // NOTE(review): relies on scoreDocs being ordered by descending score, and on no
            // real hit having doc == Integer.MAX_VALUE (so binarySearch always returns the
            // negative "not found" form) - confirm against Lucene.
            int index = -1 - Arrays.binarySearch(
                merged.scoreDocs,
                new ScoreDoc(Integer.MAX_VALUE, similarity),
                Comparator.<ScoreDoc, Float>comparing(scoreDoc -> -scoreDoc.score).thenComparing(scoreDoc -> scoreDoc.doc));
            if (index < merged.scoreDocs.length) {
              merged.scoreDocs = Arrays.copyOf(merged.scoreDocs, index);
            }
            // Increment total count of HNSW nodes visited
            numVisited.add(merged.totalHits.value);
            return merged;
          }
        };
        return searcher.count(vectorQuery);
      } catch (IOException e) {
        throw new UncheckedIOException(e);
      }
    }).toArray();
    long elapsedNanos = System.nanoTime() - startTime;

    // An empty query set would otherwise fail with an integer division by zero
    long avgVisited = queries.length == 0 ? 0 : numVisited.longValue() / queries.length;
    return new RunResult(results, elapsedNanos, avgVisited);
  } catch (IOException e) {
    throw new UncheckedIOException(e);
  }
}
// Run a threshold-based search for every query against the index at `indexPath` and count
// the hits each query produces.
//
// Both thresholds go through the same DOT_PRODUCT conversion before being handed to
// RnnFloatVectorQuery. NOTE(review): presumably `traversalThreshold` bounds graph exploration
// and `resultThreshold` bounds the hits that are returned - confirm against RnnFloatVectorQuery.
//
// Returns a RunResult holding the per-query counts, the elapsed nanoseconds for the whole
// batch, and the visited-node total averaged over the queries (0 when there are no queries).
// Any IOException from opening or searching the index is rethrown as UncheckedIOException.
private static RunResult rnnCounts(float[][] queries, Path indexPath, String knnField, float traversalThreshold, float resultThreshold) {
  try (Directory directory = FSDirectory.open(indexPath);
       DirectoryReader reader = DirectoryReader.open(directory)) {
    IndexSearcher searcher = new IndexSearcher(reader);
    // Corresponding conversion for DOT_PRODUCT
    float traversalSimilarity = (1 + traversalThreshold) / 2;
    float resultSimilarity = (1 + resultThreshold) / 2;

    long startTime = System.nanoTime();
    LongAdder numVisited = new LongAdder();
    int[] results = Arrays.stream(queries).mapToInt(query -> {
      try {
        // NOTE(review): the trailing null is presumably "no filter query" - confirm
        RnnFloatVectorQuery vectorQuery = new RnnFloatVectorQuery(knnField, query, traversalSimilarity, resultSimilarity, null) {
          @Override
          protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
            TopDocs merged = super.mergeLeafResults(perLeafResults);
            // Increment total count of HNSW nodes visited
            numVisited.add(merged.totalHits.value);
            return merged;
          }
        };
        return searcher.count(vectorQuery);
      } catch (IOException e) {
        throw new UncheckedIOException(e);
      }
    }).toArray();
    long elapsedNanos = System.nanoTime() - startTime;

    // An empty query set would otherwise fail with an integer division by zero
    long avgVisited = queries.length == 0 ? 0 : numVisited.longValue() / queries.length;
    return new RunResult(results, elapsedNanos, avgVisited);
  } catch (IOException e) {
    throw new UncheckedIOException(e);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment