Last active
June 21, 2021 04:12
-
-
Save mocobeta/532ea1508bfcb619d654cfa1743bb6a9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.lucene.document.Document; | |
import org.apache.lucene.document.Field; | |
import org.apache.lucene.document.StringField; | |
import org.apache.lucene.document.VectorField; | |
import org.apache.lucene.index.*; | |
import org.apache.lucene.search.IndexSearcher; | |
import org.apache.lucene.search.KnnGraphQuery; | |
import org.apache.lucene.search.ScoreDoc; | |
import org.apache.lucene.search.TopDocs; | |
import org.apache.lucene.store.Directory; | |
import org.apache.lucene.store.FSDirectory; | |
import org.apache.lucene.util.hnsw.HNSWGraphReader; | |
import java.io.BufferedReader; | |
import java.io.IOException; | |
import java.nio.charset.StandardCharsets; | |
import java.nio.file.*; | |
import java.nio.file.attribute.BasicFileAttributes; | |
public class EntityVectorSearch { | |
private static final String ENTITY_VECTORS_FILE = "/home/moco/tmp/jawiki-entitiy-vector/entity_vector.model.txt"; | |
private static final String INDEX_DIR = "/home/moco/tmp/entity-vectors-index"; | |
private static final String QUERY_VECTORS_FILE = "/home/moco/tmp/entity_vector_queries.txt"; | |
public static void main(String[] args) { | |
try { | |
EntityVectorSearch evs = new EntityVectorSearch(); | |
evs.indexVectors(); | |
evs.searchVectors(); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
void indexVectors() throws IOException { | |
cleanUp(INDEX_DIR); | |
Directory dir = FSDirectory.open(Paths.get(INDEX_DIR)); | |
IndexWriterConfig config = new IndexWriterConfig(); | |
config.setMergePolicy(NoMergePolicy.INSTANCE); | |
config.setUseCompoundFile(false); | |
IndexWriter writer = new IndexWriter(dir, config); | |
try { | |
try (BufferedReader br = Files.newBufferedReader(Paths.get(ENTITY_VECTORS_FILE), StandardCharsets.UTF_8)) { | |
String line = br.readLine(); // skip the first line | |
int count = 0; | |
while ((line = br.readLine()) != null) { | |
String[] cols = line.split(" "); | |
if (cols.length != 201) { | |
throw new IllegalArgumentException("invalid input: " + line); | |
} | |
String name = cols[0]; | |
float[] vec = new float[200]; | |
for (int i = 0; i < vec.length; i++) { | |
vec[i] = Float.parseFloat(cols[i+1]); | |
} | |
Document doc = new Document(); | |
doc.add(new StringField("name", name, Field.Store.YES)); | |
doc.add(new VectorField("vector", vec, VectorValues.DistanceFunction.COSINE)); | |
writer.addDocument(doc); | |
System.out.println(name + " indexed."); | |
} | |
writer.commit(); | |
} | |
} finally { | |
writer.close(); | |
dir.close(); | |
} | |
} | |
void searchVectors() throws IOException { | |
Directory dir = FSDirectory.open(Paths.get(INDEX_DIR)); | |
IndexReader reader = DirectoryReader.open(dir); | |
IndexSearcher searcher = new IndexSearcher(reader); | |
HNSWGraphReader.loadGraphs("vector", reader, true); | |
try { | |
try(BufferedReader br = Files.newBufferedReader(Paths.get(QUERY_VECTORS_FILE), StandardCharsets.UTF_8)) { | |
String line; | |
while ((line = br.readLine()) != null) { | |
String cols[] = line.split(" "); | |
if (cols.length != 201) { | |
throw new IllegalArgumentException("invalid input: " + line); | |
} | |
String name = cols[0]; | |
float[] queryVector = new float[200]; | |
for (int i = 0; i < queryVector.length; i++) { | |
queryVector[i] = Float.parseFloat(cols[i+1]); | |
} | |
KnnGraphQuery query = new KnnGraphQuery("vector", queryVector, KnnGraphQuery.DEFAULT_EF); | |
long _start = System.currentTimeMillis(); | |
TopDocs hits = searcher.search(query, 10); | |
long _end = System.currentTimeMillis(); | |
System.out.println("== Search similar entity to " + name + " (elapsed=" + (_end - _start) + " msec) =="); | |
int rank = 0; | |
for (ScoreDoc sd : hits.scoreDocs) { | |
Document doc = reader.document(sd.doc); | |
System.out.println("\tRank " + ++rank + ": doc=" + sd.doc + " name=" + doc.get("name") + " score=" + sd.score); | |
} | |
} | |
} | |
} finally { | |
reader.close(); | |
dir.close(); | |
} | |
} | |
private static void cleanUp(String dir) throws IOException { | |
Path path = Paths.get(dir); | |
if (Files.exists(path)) { | |
Files.walkFileTree(path, new SimpleFileVisitor<Path>() { | |
@Override | |
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { | |
Files.delete(file); | |
return FileVisitResult.CONTINUE; | |
} | |
@Override | |
public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException { | |
Files.delete(dir); | |
return FileVisitResult.CONTINUE; | |
} | |
}); | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.lucene.document.Document; | |
import org.apache.lucene.document.Field; | |
import org.apache.lucene.document.StringField; | |
import org.apache.lucene.document.VectorField; | |
import org.apache.lucene.index.*; | |
import org.apache.lucene.search.*; | |
import org.apache.lucene.store.Directory; | |
import org.apache.lucene.store.FSDirectory; | |
import org.apache.lucene.util.hnsw.HNSWGraphReader; | |
import java.io.BufferedReader; | |
import java.io.IOException; | |
import java.nio.charset.StandardCharsets; | |
import java.nio.file.*; | |
import java.nio.file.attribute.BasicFileAttributes; | |
public class ParagraphVectorSearch { | |
private static final String PARAGRAPH_VECTORS_FILE = "/home/moco/tmp/jawiki-doc2vec/model/jawiki.doc2vec.dmpv300d.model.entries.txt"; | |
private static final String INDEX_DIR = "/home/moco/tmp/paragraph-vectors-index"; | |
private static final int NUM_DIMENSIONS = 300; | |
public static void main(String[] args) { | |
try { | |
ParagraphVectorSearch pvs = new ParagraphVectorSearch(); | |
pvs.indexVectors(); | |
pvs.moreLike("PayPay"); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
void indexVectors() throws IOException { | |
cleanUp(INDEX_DIR); | |
Directory dir = FSDirectory.open(Paths.get(INDEX_DIR)); | |
IndexWriterConfig config = new IndexWriterConfig(); | |
config.setMergePolicy(NoMergePolicy.INSTANCE); | |
config.setUseCompoundFile(false); | |
IndexWriter writer = new IndexWriter(dir, config); | |
try { | |
try (BufferedReader br = Files.newBufferedReader(Paths.get(PARAGRAPH_VECTORS_FILE), StandardCharsets.UTF_8)) { | |
String line; | |
int count = 0; | |
while ((line = br.readLine()) != null) { | |
Entry entry = lineToEntry(line); | |
Document doc = new Document(); | |
doc.add(new StringField("entry", entry.name, Field.Store.YES)); | |
doc.add(new VectorField("vector", entry.vec, VectorValues.DistanceFunction.COSINE)); | |
writer.addDocument(doc); | |
System.out.println(entry.name + " indexed. current indexed docs: " + ++count); | |
} | |
writer.commit(); | |
} | |
} finally { | |
writer.close(); | |
dir.close(); | |
} | |
} | |
void moreLike(String entry) throws IOException { | |
Directory dir = FSDirectory.open(Paths.get(INDEX_DIR)); | |
IndexReader reader = DirectoryReader.open(dir); | |
IndexSearcher searcher = new IndexSearcher(reader); | |
HNSWGraphReader.loadGraphs("vector", reader, true); | |
try { | |
TermQuery q = new TermQuery(new Term("entry", entry)); | |
TopDocs hitsOne = searcher.search(q, 1); | |
if (hitsOne.totalHits.value != 1) { | |
throw new IllegalArgumentException("there should be only one doc; got " + hitsOne.totalHits.value); | |
} | |
int docId = hitsOne.scoreDocs[0].doc; | |
KnnGraphQuery query = KnnGraphQuery.like("vector", docId, KnnGraphQuery.DEFAULT_EF, reader, false); | |
long _start = System.currentTimeMillis(); | |
TopDocs hits = searcher.search(query, 10); | |
long _end = System.currentTimeMillis(); | |
System.out.println("== Search similar entity to " + entry + " (elapsed=" + (_end - _start) + " msec) =="); | |
int rank = 0; | |
for (ScoreDoc sd : hits.scoreDocs) { | |
Document doc = reader.document(sd.doc); | |
System.out.println("\tRank " + ++rank + ": doc=" + sd.doc + " name=" + doc.get("entry") + " score=" + sd.score); | |
} | |
} finally { | |
reader.close(); | |
dir.close(); | |
} | |
} | |
private static void cleanUp(String dir) throws IOException { | |
Path path = Paths.get(dir); | |
if (Files.exists(path)) { | |
Files.walkFileTree(path, new SimpleFileVisitor<Path>() { | |
@Override | |
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { | |
Files.delete(file); | |
return FileVisitResult.CONTINUE; | |
} | |
@Override | |
public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException { | |
Files.delete(dir); | |
return FileVisitResult.CONTINUE; | |
} | |
}); | |
} | |
} | |
private static Entry lineToEntry(String line) { | |
String[] cols = line.split(" "); | |
StringBuilder sb = new StringBuilder(); | |
int vecStart = cols.length - NUM_DIMENSIONS; | |
for (int i = 0; i < vecStart; i++) { | |
if (i > 0) { | |
sb.append("_"); | |
} | |
sb.append(cols[i]); | |
} | |
String name = sb.toString(); | |
float[] vec = new float[NUM_DIMENSIONS]; | |
for (int i = 0; i < vec.length; i++) { | |
vec[i] = Float.parseFloat(cols[vecStart+i]); | |
} | |
Entry entry = new Entry(); | |
entry.name = name; | |
entry.vec = vec; | |
return entry; | |
} | |
private static class Entry { | |
String name; | |
float[] vec; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment