-
-
Save kaivalnp/79808017ed7666214540213d1e2a21cf to your computer and use it in GitHub Desktop.
Threshold-based search benchmark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Given some query and doc vectors, calculate the true baseline: "Count of vectors above a score threshold" | |
// The baseline for multiple thresholds is computed in a single run for saving time | |
private static int[][] trueRnnCounts(float[][] queries, float[][] docs, float[] thresholds) { | |
int[][] results = new int[thresholds.length][queries.length]; | |
// Do this in parallel for calculating quickly | |
IntStream.range(0, queries.length).parallel().forEach(index -> { | |
for (float[] doc : docs) { | |
float score = DOT_PRODUCT.compare(queries[index], doc); | |
for (int thresholdIndex = 0; thresholdIndex < thresholds.length; thresholdIndex++) { | |
// Corresponding conversion for DOT_PRODUCT | |
float similarity = (1 + thresholds[thresholdIndex]) / 2; | |
if (score >= similarity) { | |
results[thresholdIndex][index]++; | |
} | |
} | |
} | |
}); | |
return results; | |
} | |
// Define a run result. Note that `int[] results` is the "Count of vectors above the score threshold" for each query | |
private record RunResult(int[] results, long nanos, long numVisited) {} | |
// Run a topK-based search | |
private static RunResult knnCounts(float[][] queries, Path indexPath, String knnField, int topK, float threshold) { | |
try (Directory directory = FSDirectory.open(indexPath); | |
DirectoryReader reader = DirectoryReader.open(directory)) { | |
IndexSearcher searcher = new IndexSearcher(reader); | |
// Corresponding conversion for DOT_PRODUCT | |
float similarity = (1 + threshold) / 2; | |
long startTime = System.nanoTime(); | |
LongAdder numVisited = new LongAdder(); | |
int[] results = Arrays.stream(queries).mapToInt(query -> { | |
try { | |
KnnFloatVectorQuery vectorQuery = new KnnFloatVectorQuery(knnField, query, topK) { | |
@Override | |
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { | |
TopDocs merged = super.mergeLeafResults(perLeafResults); | |
// Post-filter results below the threshold | |
int index = -1 - Arrays.binarySearch( | |
merged.scoreDocs, | |
new ScoreDoc(Integer.MAX_VALUE, similarity), | |
Comparator.<ScoreDoc, Float>comparing(scoreDoc -> -scoreDoc.score).thenComparing(scoreDoc -> scoreDoc.doc)); | |
if (index < merged.scoreDocs.length) { | |
merged.scoreDocs = Arrays.copyOf(merged.scoreDocs, index); | |
} | |
// Increment total count of HNSW nodes visited | |
numVisited.add(merged.totalHits.value); | |
return merged; | |
} | |
}; | |
return searcher.count(vectorQuery); | |
} catch (IOException e) { | |
throw new UncheckedIOException(e); | |
} | |
}).toArray(); | |
return new RunResult(results, System.nanoTime() - startTime, numVisited.longValue() / queries.length); | |
} catch (IOException e) { | |
throw new UncheckedIOException(e); | |
} | |
} | |
// Run a threshold-based search | |
private static RunResult rnnCounts(float[][] queries, Path indexPath, String knnField, float traversalThreshold, float resultThreshold) { | |
try (Directory directory = FSDirectory.open(indexPath); | |
DirectoryReader reader = DirectoryReader.open(directory)) { | |
IndexSearcher searcher = new IndexSearcher(reader); | |
// Corresponding conversion for DOT_PRODUCT | |
float traversalSimilarity = (1 + traversalThreshold) / 2; | |
float resultSimilarity = (1 + resultThreshold) / 2; | |
long startTime = System.nanoTime(); | |
LongAdder numVisited = new LongAdder(); | |
int[] results = Arrays.stream(queries).mapToInt(query -> { | |
try { | |
RnnFloatVectorQuery vectorQuery = new RnnFloatVectorQuery(knnField, query, traversalSimilarity, resultSimilarity, null) { | |
@Override | |
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { | |
TopDocs merged = super.mergeLeafResults(perLeafResults); | |
// Increment total count of HNSW nodes visited | |
numVisited.add(merged.totalHits.value); | |
return merged; | |
} | |
}; | |
return searcher.count(vectorQuery); | |
} catch (IOException e) { | |
throw new UncheckedIOException(e); | |
} | |
}).toArray(); | |
return new RunResult(results, System.nanoTime() - startTime, numVisited.longValue() / queries.length); | |
} catch (IOException e) { | |
throw new UncheckedIOException(e); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment