Created
February 14, 2022 03:12
-
-
Save mooreniemi/43cfea4ddf32a0f6a42852e7e609c39f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{path::PathBuf, time::Instant}; | |
use faiss::{error::Error, index_factory, read_index, write_index, Index, MetricType}; | |
use itertools_num::linspace; | |
use structopt::StructOpt; | |
#[derive(Debug, StructOpt)] | |
#[structopt(name = "dann", about = "Demo ANN")] | |
struct DannOpt { | |
/// FAISS Index factory string | |
// https://www.pinecone.io/learn/composite-indexes/ | |
// below took ~13 minutes to train on my XPS13 | |
// "IVF8,PQ32x8" | |
#[structopt(default_value = "Flat", short, long)] | |
factory: String, | |
/// Index file | |
#[structopt(default_value = "/tmp/faiss.idx", parse(from_os_str))] | |
index_file_path: PathBuf, | |
} | |
fn main() -> Result<(), Error> { | |
let opt = DannOpt::from_args(); | |
let d: u32 = 64; | |
let num_vecs = 10_000; | |
// we multiply by d because of how train and add take data as a single contiguous array | |
// note this is just a sequence, not random vectors or real embeddings | |
let my_data: Vec<f32> = linspace::<f32>(0., 1., (d * num_vecs) as usize).collect(); | |
let mut index = index_factory(d, opt.factory, MetricType::L2)?; | |
if opt.index_file_path.exists() { | |
println!("Found a faiss.idx file so will load that as index"); | |
index = read_index(opt.index_file_path.to_str().expect("valid path"))?; | |
} else { | |
println!("Training..."); | |
// allows you to see the training iterations | |
index.set_verbose(true); | |
let start = Instant::now(); | |
// although not given same documentation, train also makes same assumption as add (see below) | |
index.train(&my_data)?; | |
write_index(&index, opt.index_file_path.to_str().expect("valid path"))?; | |
println!("Training took: {:?}", start.elapsed()); | |
println!("Adding data..."); | |
let start = Instant::now(); | |
// "This assumes a C-contiguous memory slice of vectors, where the total number of vectors is my_data.len() / d." | |
index.add(&my_data)?; | |
println!("Adding data took: {:?}", start.elapsed()); | |
}; | |
let k = 5; | |
let start = Instant::now(); | |
// just grab the "first" vector off the contiguous array for query vector | |
let result = index.search(&my_data[..64], k)?; | |
println!("Searching data took: {:?}", start.elapsed()); | |
for (i, (l, d)) in result | |
.labels | |
.iter() | |
.take(k) | |
.zip(result.distances.iter()) | |
.enumerate() | |
{ | |
println!("#{}: {} (D={})", i + 1, *l, *d); | |
} | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment