Skip to content

Instantly share code, notes, and snippets.

@akorotkov
Created June 21, 2023 21:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save akorotkov/f9e85a4e9846b609851fd55451874c6d to your computer and use it in GitHub Desktop.
Save akorotkov/f9e85a4e9846b609851fd55451874c6d to your computer and use it in GitHub Desktop.
import numpy as np
import faiss
import seaborn as sns
import pandas as pd
# Benchmark script: measures precision@k of a Faiss IVFFlat index as a
# function of the number of probed lists (nprobe), for several vector
# dimensionalities, and plots one precision curve per dimensionality.
#
# For every dimensionality, random unit-length vectors are indexed, a
# reference result set is obtained by probing every list, and then the
# results for each smaller nprobe are compared against that reference.

# seed RNG for reproducible result
np.random.seed(1234)
# value passed as the `nlist` argument of faiss.IndexIVFFlat
nlist = 30
# number of results to return from each query
k = 10
# dataset size
nb = 100000
# number of queries used to calculate mean precision
nq = 10000
# Collected output rows: (vector_dim, probes, mean precision)
precision = []
# Number of dimensions
dimensions = [5, 10, 32, 64, 128, 256, 512, 1024, 1536]

for d in dimensions:
    # Vector dataset: Gaussian samples scaled to unit L2 norm
    xb = np.random.normal(size=(nb, d)).astype('float32')
    xb = xb / np.linalg.norm(xb, axis=1)[:, np.newaxis]

    # Query dataset, generated the same way
    xq = np.random.normal(size=(nq, d)).astype('float32')
    xq = xq / np.linalg.norm(xq, axis=1)[:, np.newaxis]

    # Define the index: a flat L2 index serves as the coarse quantizer
    # handed to the IVF index
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)

    # Train the index
    assert not index.is_trained
    index.train(xb)
    assert index.is_trained

    # Load dataset into the index
    index.add(xb)

    # Collect the accurate results by setting nprobe == nlist
    index.nprobe = nlist
    # The accurate results for each query vector. Use the same k as the
    # measured searches below so both result sets have the same width.
    _, correct_results = index.search(xq, k=k)
    # The reference rows do not depend on nprobe, so convert them to sets
    # once instead of once per nprobe value.
    correct_sets = [set(row) for row in correct_results]

    # For each number of probes, calculate precision@k.
    # nprobe == nlist is not swept: it is the reference configuration itself.
    for nprobes in range(1, nlist):
        index.nprobe = nprobes
        # Results for the given number of nprobes
        _, q_results = index.search(xq, k=k)

        # Compare each row of the correct results with the query results:
        # per-query precision is the fraction of the k reference ids that
        # also appear in the query's result row.
        agg = [
            len(set(q_row) & correct_set) / k
            for correct_set, q_row in zip(correct_sets, q_results)
        ]

        # Add result to the output (dimensions, nprobes, precision)
        rec = (d, nprobes, sum(agg) / len(agg))
        print(rec)
        precision.append(rec)

# Load dataset into a dataframe for plotting
df = pd.DataFrame(precision, columns=['vector_dim', 'probes', 'precision'])

# Plot the results: one line per vector dimensionality
sns.lineplot(
    data=df,
    x='probes',
    y='precision',
    hue='vector_dim',
).set(title=f'precision@k={k} by probes')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment