Skip to content

Instantly share code, notes, and snippets.

@tteofili
Created August 3, 2022 14:37
Show Gist options
  • Save tteofili/52a563fc67a7fc26fe27d4a69d6ec61e to your computer and use it in GitHub Desktop.
Save tteofili/52a563fc67a7fc26fe27d4a69d6ec61e to your computer and use it in GitHub Desktop.
/*
* Anserini: A Lucene toolkit for reproducible information retrieval research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.anserini.ann;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.time.DurationFormatUtils;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.OptionHandlerFilter;
import org.kohsuke.args4j.ParserProperties;
/**
 * Command-line tool that reads a whitespace-separated text file of vectors and writes them to a
 * Lucene index, one document per vector.
 *
 * <p>Each document carries the vector's key in the stored string field {@link #FIELD_ID} and the
 * vector values in a {@link KnnVectorField} named {@link #FIELD_VECTOR}, configured with
 * {@link VectorSimilarityFunction#EUCLIDEAN}.
 */
public class IndexVectorsHNSW {
  /** Name of the stored string field holding a vector's key. */
  public static final String FIELD_ID = "id";

  /** Name of the KNN vector field holding a vector's values. */
  public static final String FIELD_VECTOR = "vector";

  /** Number of added documents between two progress messages. */
  private static final int PROGRESS_INTERVAL = 100000;

  /** Command-line options, populated by args4j. */
  public static final class Args {
    /** Text file containing the vectors to index. */
    @Option(name = "-input", metaVar = "[file]", required = true, usage = "vectors model")
    public File input;

    /** Directory in which the index is created. */
    @Option(name = "-path", metaVar = "[path]", required = true, usage = "index path")
    public Path path;

    // NOTE(review): this flag is parsed but never read anywhere in this class.
    @Option(name = "-stored", metaVar = "[boolean]", usage = "store vectors")
    public boolean stored;
  }

  /**
   * Parses the command line, loads all vectors into memory, indexes them and prints the number
   * of indexed documents, the on-disk index size and the elapsed time.
   *
   * <p>On invalid arguments the usage text is printed to standard error and the method returns
   * without indexing anything.
   *
   * @param args command-line arguments, see {@link Args}
   * @throws Exception if the input cannot be read or the index cannot be opened, committed or
   *     closed
   */
  public static void main(String[] args) throws Exception {
    IndexVectorsHNSW.Args indexArgs = new IndexVectorsHNSW.Args();
    CmdLineParser parser =
        new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90));
    try {
      parser.parseArgument(args);
    } catch (CmdLineException e) {
      System.err.println(e.getMessage());
      parser.printUsage(System.err);
      System.err.println("Example: " + IndexVectorsHNSW.class.getSimpleName() +
          parser.printExample(OptionHandlerFilter.REQUIRED));
      return;
    }

    final long start = System.nanoTime();
    System.out.println(String.format("Loading model %s", indexArgs.input));
    Map<String, List<float[]>> vectors = readVectors(indexArgs.input);

    Path indexDir = indexArgs.path;
    if (!Files.exists(indexDir)) {
      Files.createDirectories(indexDir);
    }

    System.out.println(String.format("Creating index at %s...", indexArgs.path));
    int indexed = 0;
    // try-with-resources releases the writer and the directory even if indexing fails part-way.
    try (Directory directory = FSDirectory.open(indexDir);
         IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig())) {
      for (Map.Entry<String, List<float[]>> entry : vectors.entrySet()) {
        for (float[] vector : entry.getValue()) {
          Document doc = new Document();
          doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES));
          doc.add(new KnnVectorField(FIELD_VECTOR, vector, VectorSimilarityFunction.EUCLIDEAN));
          try {
            indexWriter.addDocument(doc);
            indexed++;
            if (indexed % PROGRESS_INTERVAL == 0) {
              System.out.println(String.format("%s docs added", indexed));
            }
          } catch (IOException e) {
            // Best effort: report the failed document and carry on with the remaining ones.
            System.err.println("Error while indexing: " + e.getLocalizedMessage());
          }
        }
      }
      indexWriter.commit();
    }

    System.out.println(String.format("%s docs indexed", indexed));
    // Measured after the writer is closed, so the figure covers the files it left on disk.
    long space = FileUtils.sizeOfDirectory(indexDir.toFile()) / (1024L * 1024L);
    System.out.println(String.format("Index size: %dMB", space));

    final long durationMillis =
        TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS);
    System.out.println(String.format("Total time: %s",
        DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss")));
  }

  /**
   * Reads vectors from a text file in which every line has the form
   * {@code key v1 v2 ... vn}, with tokens separated by whitespace.
   *
   * <p>The file is decoded as UTF-8 and read line by line. Leading and trailing whitespace on a
   * line is ignored. Lines with fewer than three tokens are skipped, so blank lines and
   * two-token lines never produce a vector (presumably this is meant to drop a
   * "count dimension" header line; verify against the input format). Vectors that share a key
   * are collected, in file order, in the same list.
   *
   * @param input the vectors file
   * @return a map from key to the vectors read for that key
   * @throws IOException if the file cannot be opened or read
   * @throws NumberFormatException if a value token is not a parsable float
   */
  static Map<String, List<float[]>> readVectors(File input) throws IOException {
    Map<String, List<float[]>> vectors = new HashMap<>();
    try (BufferedReader reader = new BufferedReader(
        new InputStreamReader(Files.newInputStream(input.toPath()), StandardCharsets.UTF_8))) {
      String line;
      while ((line = reader.readLine()) != null) {
        // Trim first: a leading blank would otherwise yield an empty first token as the key.
        String[] tokens = line.trim().split("\\s+");
        if (tokens.length <= 2) {
          continue;
        }
        float[] vector = new float[tokens.length - 1];
        for (int i = 1; i < tokens.length; i++) {
          vector[i - 1] = Float.parseFloat(tokens[i]);
        }
        vectors.computeIfAbsent(tokens[0], key -> new ArrayList<>()).add(vector);
      }
    }
    return vectors;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment