Skip to content

Instantly share code, notes, and snippets.

@benwtrent
Created October 18, 2022 19:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save benwtrent/b73fb2245a65bdd1a8f1eb1775dd01bf to your computer and use it in GitHub Desktop.
Save benwtrent/b73fb2245a65bdd1a8f1eb1775dd01bf to your computer and use it in GitHub Desktop.
ann-benchmarks integration for Apache Lucene KNN search (via PyLucene).
"""
ann-benchmarks interface for Apache Lucene.
"""
import sklearn.preprocessing
import numpy as np
from struct import Struct
import lucene
from lucene import JArray
from java.nio.file import Paths
from org.apache.lucene.store import FSDirectory
from org.apache.lucene.search import KnnVectorQuery, IndexSearcher
from org.apache.lucene.index import IndexWriter, IndexWriterConfig, VectorSimilarityFunction, DirectoryReader
from org.apache.lucene.codecs.lucene91 import Lucene91Codec, Lucene91HnswVectorsFormat
from org.apache.lucene.document import Document, FieldType, KnnVectorField, StoredField
from ann_benchmarks.algorithms.base import BaseANN
class Codec(Lucene91Codec):
    """Lucene 9.1 codec whose KNN vector fields use custom HNSW parameters.

    Args:
        M: first argument passed to ``Lucene91HnswVectorsFormat``
            (presumably the HNSW max-connections setting; see the Lucene
            Javadoc).
        efConstruction: second argument passed to
            ``Lucene91HnswVectorsFormat`` (presumably the construction-time
            beam width).
    """

    def __init__(self, M, efConstruction):
        super(Codec, self).__init__()
        self.M = M
        self.efConstruction = efConstruction

    def getKnnVectorsFormatForField(self, field=None):
        """Return the HNSW vectors format configured for this codec.

        Args:
            field: name of the vector field. It is ignored because every
                field gets the same format; it defaults to ``None`` so
                existing zero-argument calls keep working.

        Returns:
            A ``Lucene91HnswVectorsFormat`` built from ``M`` and
            ``efConstruction``.
        """
        # The format must be returned: without the return the override
        # yields None and the configured parameters never take effect.
        return Lucene91HnswVectorsFormat(self.M, self.efConstruction)
class LucenePyLucene(BaseANN):
    """
    KNN using the Lucene Vector datatype.

    Vectors are indexed into an on-disk Lucene index named
    ``<short_name>.index`` as ``KnnVectorField`` values in the field
    ``"knn"``; the dataset row number is kept in the stored field ``"id"``.
    For the ``"angular"`` metric, vectors are L2-normalised and compared
    with ``DOT_PRODUCT``; for ``"euclidean"`` they are compared with
    ``EUCLIDEAN``.
    """

    def __init__(self, metric: str, dimension: int, param):
        """Configure the algorithm and make sure the JVM is running.

        Args:
            metric: ``"euclidean"`` or ``"angular"``.
            dimension: expected dimensionality of every vector.
            param: mapping providing the ``'M'`` and ``'efConstruction'``
                index-build parameters.

        Raises:
            NotImplementedError: if ``metric`` is not supported.
        """
        # Validate before doing anything with side effects (JVM start-up).
        if metric not in ("euclidean", "angular"):
            raise NotImplementedError(f"Not implemented for metric {metric}")
        # Each JVM option must be its own list element; a single string with
        # spaces is handed to the JVM as one "-D..." option, so the heap
        # flags would never be applied. Options can only be supplied when
        # the JVM is first started, hence the guard for repeated
        # instantiation.
        if lucene.getVMEnv() is None:
            lucene.initVM(
                vmargs=['-Djava.awt.headless=true', '-Xmx6g', '-Xms6g']
            )
        self.metric = metric
        self.dimension = dimension
        self.param = param
        self.short_name = f"lucenepyknn-{dimension}-{param['M']}-{param['efConstruction']}"
        self.n_iters = -1
        self.train_size = -1
        self.simFunc = (
            VectorSimilarityFunction.DOT_PRODUCT
            if self.metric == "angular"
            else VectorSimilarityFunction.EUCLIDEAN
        )
        # Defined up front so done() and query() behave sensibly even when
        # called before fit() / set_query_arguments().
        self.dir = None
        self.reader = None
        self.searcher = None
        self.fanout = 0

    def done(self):
        """Release the index reader and directory, if they were opened."""
        if self.reader is not None:
            self.reader.close()
            self.reader = None
        self.searcher = None
        if self.dir is not None:
            self.dir.close()
            self.dir = None

    def fit(self, X):
        """Build the Lucene index from ``X`` and open a searcher on it.

        Args:
            X: 2-D numpy array with one vector per row.

        Raises:
            Exception: if the second dimension of ``X`` differs from the
                configured ``dimension``.
        """
        if self.dimension != X.shape[1]:
            raise Exception(f"Configured dimension {self.dimension} but data has shape {X.shape}")
        if self.metric == 'angular':
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        # Drop any reader/directory left over from a previous fit().
        self.done()
        self.train_size = X.shape[0]
        iwc = IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE)
        iwc.setCodec(Codec(self.param['M'], self.param['efConstruction']))
        iwc.setRAMBufferSizeMB(1994.0)
        self.dir = FSDirectory.open(Paths.get(self.short_name + ".index"))
        writer = IndexWriter(self.dir, iwc)
        try:
            field_type = KnnVectorField.createFieldType(self.dimension, self.simFunc)
            # Convert one row at a time instead of materialising the whole
            # matrix as nested Python lists.
            for doc_id, row in enumerate(X):
                doc = Document()
                doc.add(KnnVectorField("knn", JArray('float')(row.tolist()), field_type))
                doc.add(StoredField("id", doc_id))
                writer.addDocument(doc)
            writer.forceMerge(1)
        finally:
            # Always release the writer, even if indexing failed.
            writer.close()
        self.reader = DirectoryReader.open(self.dir)
        self.searcher = IndexSearcher(self.reader)

    def set_query_arguments(self, fanout):
        """Set the number of extra candidates requested on top of ``n``."""
        self.name = f"lucenepyknn dim={self.dimension} {self.param}"
        self.fanout = fanout

    def query(self, q, n):
        """Return the dataset ids of up to ``n`` nearest neighbours of ``q``.

        Args:
            q: 1-D numpy query vector.
            n: number of neighbours requested.

        Returns:
            List of integer ids read from the stored ``"id"`` field, best
            match first. It may hold fewer than ``n`` entries when the
            search finds fewer hits.

        Raises:
            RuntimeError: if called before ``fit()``.
        """
        if self.searcher is None:
            raise RuntimeError("fit() must be called before query()")
        if self.metric == 'angular':
            q = q / np.linalg.norm(q)
        query = KnnVectorQuery("knn", JArray('float')(q.tolist()), n + self.fanout)
        topdocs = self.searcher.search(query, n)
        # Map Lucene-internal doc ids back to the ids stored at index time;
        # internal ids are not guaranteed to follow insertion order once
        # segments have been merged.
        return [
            int(self.searcher.doc(hit.doc).get("id"))
            for hit in topdocs.scoreDocs
        ]

    # TODO: batch query support
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment