Skip to content

Instantly share code, notes, and snippets.

@mooreniemi
Created February 14, 2022 03:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mooreniemi/c4ed569f94ebca4a21dfac6c3dc7dd7a to your computer and use it in GitHub Desktop.
Save mooreniemi/c4ed569f94ebca4a21dfac6c3dc7dd7a to your computer and use it in GitHub Desktop.
use std::{
fs::File,
io::{BufRead, BufReader},
path::PathBuf,
time::Instant,
};
use faiss::{error::Error, index_factory, read_index, write_index, Index, MetricType};
use itertools_num::linspace;
use structopt::StructOpt;
// Command-line options for the demo. structopt uses each `///` doc comment
// below as the help text of the matching argument, so those lines are
// user-facing; the plain `//` notes are for readers of the source only.
#[derive(Debug, StructOpt)]
#[structopt(name = "dann", about = "Demo ANN")]
struct DannOpt {
    // Composite factory strings are explained at
    // https://www.pinecone.io/learn/composite-indexes/
    // For example "IVF8,PQ32x8" took ~13 minutes to train on my XPS13.
    /// FAISS Index factory string
    #[structopt(short, long, default_value = "Flat")]
    factory: String,

    /// Dimension of vectors
    #[structopt(short, long, default_value = "64")]
    dimension: u32,

    // First positional argument.
    /// Index file
    #[structopt(parse(from_os_str), default_value = "/tmp/faiss.idx")]
    index_file_path: PathBuf,

    // Second, optional positional argument.
    /// Data file, if none will generate 10_000 vectors
    #[structopt(parse(from_os_str))]
    data_file_path: Option<PathBuf>,
}
fn main() -> Result<(), Error> {
let opt = DannOpt::from_args();
let d: u32 = opt.dimension;
let data = if let Some(data_file_path) = opt.data_file_path {
// FIXME: could map_err into some Faiss error type or change to anyhow in main
let file = File::open(data_file_path).expect("data file exists");
BufReader::new(file)
.lines()
.flatten()
// FIXME: how should this file actually be laid out?
.flat_map(|line| line.parse::<f32>())
.collect()
} else {
// we multiply by d because of how train and add take data as a single contiguous array
let num_vecs = 10_000;
// NOTE: this is just a sequence, not random vectors or real embeddings
let seq: Vec<f32> = linspace::<f32>(0., 1., (d * num_vecs) as usize).collect();
seq
};
let mut index = index_factory(d, opt.factory, MetricType::L2)?;
if opt.index_file_path.exists() {
println!("Found a faiss.idx file so will load that as index");
let start = Instant::now();
index = read_index(opt.index_file_path.to_str().expect("valid path"))?;
println!("Reading index from file took: {:?}", start.elapsed());
} else {
println!("Training...");
// allows you to see the training iterations
index.set_verbose(true);
let start = Instant::now();
// although not given same documentation, train also makes same assumption as add (see below)
index.train(&data)?;
println!("Training took: {:?}", start.elapsed());
println!("Adding data...");
let start = Instant::now();
// "This assumes a C-contiguous memory slice of vectors, where the total number of vectors is data.len() / d."
index.add(&data)?;
println!("Adding data took: {:?}", start.elapsed());
println!("Saving index to file...");
let start = Instant::now();
write_index(&index, opt.index_file_path.to_str().expect("valid path"))?;
println!("Saving index to file took: {:?}", start.elapsed());
};
let k = 5;
let q: Vec<f32> = linspace::<f32>(0., 1., d as usize).collect();
let start = Instant::now();
let result = index.search(&q, k)?;
println!("Searching data took: {:?}", start.elapsed());
for (i, (l, d)) in result
.labels
.iter()
.take(k)
.zip(result.distances.iter())
.enumerate()
{
println!("#{}: {} (D={})", i + 1, *l, *d);
}
Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment