Skip to content

Instantly share code, notes, and snippets.

@mocobeta
Last active June 21, 2021 04:12
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mocobeta/532ea1508bfcb619d654cfa1743bb6a9 to your computer and use it in GitHub Desktop.
Save mocobeta/532ea1508bfcb619d654cfa1743bb6a9 to your computer and use it in GitHub Desktop.
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.*;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnGraphQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.hnsw.HNSWGraphReader;
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.nio.file.attribute.BasicFileAttributes;
/**
 * Indexes entity vectors into a Lucene index and runs a {@code KnnGraphQuery} for every vector
 * listed in a query file.
 *
 * <p>Both the entity vector file and the query file contain one entry per line: a name followed by
 * {@value #NUM_DIMENSIONS} space-separated float components. The first line of the entity vector
 * file is skipped; the query file has no skipped line. Blank lines are ignored.
 */
public class EntityVectorSearch {
    private static final String ENTITY_VECTORS_FILE = "/home/moco/tmp/jawiki-entitiy-vector/entity_vector.model.txt";
    private static final String INDEX_DIR = "/home/moco/tmp/entity-vectors-index";
    private static final String QUERY_VECTORS_FILE = "/home/moco/tmp/entity_vector_queries.txt";
    /** Number of float components in every vector of the input files. */
    private static final int NUM_DIMENSIONS = 200;
    /** Number of hits printed per query. */
    private static final int TOP_K = 10;

    public static void main(String[] args) {
        try {
            EntityVectorSearch evs = new EntityVectorSearch();
            evs.indexVectors();
            evs.searchVectors();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Rebuilds the index at {@link #INDEX_DIR} from {@link #ENTITY_VECTORS_FILE}.
     *
     * <p>Any existing index directory is deleted first. Merging and compound files are disabled.
     *
     * @throws IOException if reading the input or writing the index fails
     * @throws IllegalArgumentException if a non-blank line is not a name plus NUM_DIMENSIONS floats
     */
    void indexVectors() throws IOException {
        cleanUp(INDEX_DIR);
        IndexWriterConfig config = new IndexWriterConfig();
        config.setMergePolicy(NoMergePolicy.INSTANCE);
        config.setUseCompoundFile(false);
        // Resources close in reverse order (reader, writer, directory), and the directory is still
        // closed if opening the writer fails or if closing the writer throws.
        try (Directory dir = FSDirectory.open(Paths.get(INDEX_DIR));
             IndexWriter writer = new IndexWriter(dir, config);
             BufferedReader br = Files.newBufferedReader(Paths.get(ENTITY_VECTORS_FILE), StandardCharsets.UTF_8)) {
            br.readLine(); // the first line is skipped, not parsed as a vector entry
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                Entry entry = lineToEntry(line);
                Document doc = new Document();
                doc.add(new StringField("name", entry.name, Field.Store.YES));
                doc.add(new VectorField("vector", entry.vec, VectorValues.DistanceFunction.COSINE));
                writer.addDocument(doc);
                System.out.println(entry.name + " indexed.");
            }
            writer.commit();
        }
    }

    /**
     * Runs one {@code KnnGraphQuery} per line of {@link #QUERY_VECTORS_FILE} against the index at
     * {@link #INDEX_DIR} and prints the top {@value #TOP_K} hits with their stored names and scores.
     *
     * @throws IOException if the index or the query file cannot be read
     * @throws IllegalArgumentException if a non-blank query line is not a name plus NUM_DIMENSIONS floats
     */
    void searchVectors() throws IOException {
        try (Directory dir = FSDirectory.open(Paths.get(INDEX_DIR));
             IndexReader reader = DirectoryReader.open(dir)) {
            IndexSearcher searcher = new IndexSearcher(reader);
            // NOTE(review): presumably preloads the HNSW graphs of the "vector" field; confirm the
            // meaning of the boolean flag in HNSWGraphReader.
            HNSWGraphReader.loadGraphs("vector", reader, true);
            try (BufferedReader br = Files.newBufferedReader(Paths.get(QUERY_VECTORS_FILE), StandardCharsets.UTF_8)) {
                String line;
                while ((line = br.readLine()) != null) {
                    if (line.trim().isEmpty()) {
                        continue;
                    }
                    Entry queryEntry = lineToEntry(line);
                    KnnGraphQuery query = new KnnGraphQuery("vector", queryEntry.vec, KnnGraphQuery.DEFAULT_EF);
                    // nanoTime is monotonic, unlike currentTimeMillis, so it is the right clock for durations.
                    long startNanos = System.nanoTime();
                    TopDocs hits = searcher.search(query, TOP_K);
                    long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000;
                    System.out.println("== Search similar entity to " + queryEntry.name + " (elapsed=" + elapsedMillis + " msec) ==");
                    int rank = 0;
                    for (ScoreDoc sd : hits.scoreDocs) {
                        Document doc = reader.document(sd.doc);
                        System.out.println("\tRank " + ++rank + ": doc=" + sd.doc + " name=" + doc.get("name") + " score=" + sd.score);
                    }
                }
            }
        }
    }

    /**
     * Recursively deletes {@code dir} if it exists.
     *
     * @param dir path of the directory tree to remove
     * @throws IOException if any file or directory cannot be listed or deleted
     */
    private static void cleanUp(String dir) throws IOException {
        Path root = Paths.get(dir);
        if (!Files.exists(root)) {
            return;
        }
        Files.walkFileTree(root, new SimpleFileVisitor<Path>() {
            @Override
            public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
                Files.delete(file);
                return FileVisitResult.CONTINUE;
            }

            @Override
            public FileVisitResult postVisitDirectory(Path directory, IOException exc) throws IOException {
                // Report the failure that happened while iterating this directory instead of masking
                // it with a DirectoryNotEmptyException from the delete below.
                if (exc != null) {
                    throw exc;
                }
                Files.delete(directory);
                return FileVisitResult.CONTINUE;
            }
        });
    }

    /**
     * Parses one input line into a name and its vector.
     *
     * @param line a name followed by exactly NUM_DIMENSIONS space-separated floats
     * @return the parsed entry
     * @throws IllegalArgumentException if the line does not have exactly NUM_DIMENSIONS + 1 columns
     * @throws NumberFormatException if a vector component is not a float
     */
    private static Entry lineToEntry(String line) {
        String[] cols = line.split(" ");
        if (cols.length != NUM_DIMENSIONS + 1) {
            throw new IllegalArgumentException("invalid input: " + line);
        }
        float[] vec = new float[NUM_DIMENSIONS];
        for (int i = 0; i < NUM_DIMENSIONS; i++) {
            vec[i] = Float.parseFloat(cols[i + 1]);
        }
        return new Entry(cols[0], vec);
    }

    /** A named vector read from one input line. */
    private static final class Entry {
        final String name;
        final float[] vec;

        Entry(String name, float[] vec) {
            this.name = name;
            this.vec = vec;
        }
    }
}
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.hnsw.HNSWGraphReader;
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.nio.file.attribute.BasicFileAttributes;
/**
 * Indexes paragraph vectors into a Lucene index and prints the top hits of a
 * {@code KnnGraphQuery.like} query seeded with the vector of a given entry.
 *
 * <p>Each input line holds an entry name, which may span several space-separated tokens that are
 * joined with {@code '_'}, followed by {@value #NUM_DIMENSIONS} space-separated float components.
 * Blank lines are ignored.
 */
public class ParagraphVectorSearch {
    private static final String PARAGRAPH_VECTORS_FILE = "/home/moco/tmp/jawiki-doc2vec/model/jawiki.doc2vec.dmpv300d.model.entries.txt";
    private static final String INDEX_DIR = "/home/moco/tmp/paragraph-vectors-index";
    private static final int NUM_DIMENSIONS = 300;
    /** Number of hits printed by {@link #moreLike(String)}. */
    private static final int TOP_K = 10;

    public static void main(String[] args) {
        try {
            ParagraphVectorSearch pvs = new ParagraphVectorSearch();
            pvs.indexVectors();
            pvs.moreLike("PayPay");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Rebuilds the index at {@link #INDEX_DIR} from {@link #PARAGRAPH_VECTORS_FILE}.
     *
     * <p>Any existing index directory is deleted first. Merging and compound files are disabled.
     *
     * @throws IOException if reading the input or writing the index fails
     * @throws IllegalArgumentException if a non-blank line has no name token before the vector
     */
    void indexVectors() throws IOException {
        cleanUp(INDEX_DIR);
        IndexWriterConfig config = new IndexWriterConfig();
        config.setMergePolicy(NoMergePolicy.INSTANCE);
        config.setUseCompoundFile(false);
        // Resources close in reverse order (reader, writer, directory), and the directory is still
        // closed if opening the writer fails or if closing the writer throws.
        try (Directory dir = FSDirectory.open(Paths.get(INDEX_DIR));
             IndexWriter writer = new IndexWriter(dir, config);
             BufferedReader br = Files.newBufferedReader(Paths.get(PARAGRAPH_VECTORS_FILE), StandardCharsets.UTF_8)) {
            int count = 0;
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                Entry entry = lineToEntry(line);
                Document doc = new Document();
                doc.add(new StringField("entry", entry.name, Field.Store.YES));
                doc.add(new VectorField("vector", entry.vec, VectorValues.DistanceFunction.COSINE));
                writer.addDocument(doc);
                System.out.println(entry.name + " indexed. current indexed docs: " + ++count);
            }
            writer.commit();
        }
    }

    /**
     * Looks up the single document whose {@code entry} field equals {@code entry} and prints the top
     * {@value #TOP_K} hits of a {@code KnnGraphQuery.like} query built from that document.
     *
     * @param entry exact value of the {@code entry} field to look up
     * @throws IOException if the index cannot be read
     * @throws IllegalArgumentException if the lookup does not match exactly one document
     */
    void moreLike(String entry) throws IOException {
        try (Directory dir = FSDirectory.open(Paths.get(INDEX_DIR));
             IndexReader reader = DirectoryReader.open(dir)) {
            IndexSearcher searcher = new IndexSearcher(reader);
            // NOTE(review): presumably preloads the HNSW graphs of the "vector" field; confirm the
            // meaning of the boolean flag in HNSWGraphReader.
            HNSWGraphReader.loadGraphs("vector", reader, true);
            TermQuery q = new TermQuery(new Term("entry", entry));
            TopDocs hitsOne = searcher.search(q, 1);
            if (hitsOne.totalHits.value != 1) {
                throw new IllegalArgumentException("there should be only one doc; got " + hitsOne.totalHits.value);
            }
            int docId = hitsOne.scoreDocs[0].doc;
            KnnGraphQuery query = KnnGraphQuery.like("vector", docId, KnnGraphQuery.DEFAULT_EF, reader, false);
            // nanoTime is monotonic, unlike currentTimeMillis, so it is the right clock for durations.
            long startNanos = System.nanoTime();
            TopDocs hits = searcher.search(query, TOP_K);
            long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000;
            System.out.println("== Search similar entity to " + entry + " (elapsed=" + elapsedMillis + " msec) ==");
            int rank = 0;
            for (ScoreDoc sd : hits.scoreDocs) {
                Document doc = reader.document(sd.doc);
                System.out.println("\tRank " + ++rank + ": doc=" + sd.doc + " name=" + doc.get("entry") + " score=" + sd.score);
            }
        }
    }

    /**
     * Recursively deletes {@code dir} if it exists.
     *
     * @param dir path of the directory tree to remove
     * @throws IOException if any file or directory cannot be listed or deleted
     */
    private static void cleanUp(String dir) throws IOException {
        Path root = Paths.get(dir);
        if (!Files.exists(root)) {
            return;
        }
        Files.walkFileTree(root, new SimpleFileVisitor<Path>() {
            @Override
            public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
                Files.delete(file);
                return FileVisitResult.CONTINUE;
            }

            @Override
            public FileVisitResult postVisitDirectory(Path directory, IOException exc) throws IOException {
                // Report the failure that happened while iterating this directory instead of masking
                // it with a DirectoryNotEmptyException from the delete below.
                if (exc != null) {
                    throw exc;
                }
                Files.delete(directory);
                return FileVisitResult.CONTINUE;
            }
        });
    }

    /**
     * Parses one input line: every column before the last NUM_DIMENSIONS columns is a name token
     * (joined with {@code '_'}), and the last NUM_DIMENSIONS columns are the vector.
     *
     * @param line the raw input line
     * @return the parsed entry
     * @throws IllegalArgumentException if the line has fewer than NUM_DIMENSIONS + 1 columns
     * @throws NumberFormatException if a vector component is not a float
     */
    private static Entry lineToEntry(String line) {
        String[] cols = line.split(" ");
        int vecStart = cols.length - NUM_DIMENSIONS;
        // Without this check a short line yields a negative vecStart (ArrayIndexOutOfBounds) or,
        // with no name token at all, an entry with an empty name that cannot be looked up.
        if (vecStart < 1) {
            throw new IllegalArgumentException("invalid input: " + line);
        }
        StringBuilder name = new StringBuilder(cols[0]);
        for (int i = 1; i < vecStart; i++) {
            name.append('_').append(cols[i]);
        }
        float[] vec = new float[NUM_DIMENSIONS];
        for (int i = 0; i < NUM_DIMENSIONS; i++) {
            vec[i] = Float.parseFloat(cols[vecStart + i]);
        }
        return new Entry(name.toString(), vec);
    }

    /** A named vector read from one input line. */
    private static final class Entry {
        final String name;
        final float[] vec;

        Entry(String name, float[] vec) {
            this.name = name;
            this.vec = vec;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment