Created
June 21, 2023 21:27
-
-
Save akorotkov/f9e85a4e9846b609851fd55451874c6d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import faiss | |
import seaborn as sns | |
import pandas as pd | |
# Benchmark: how precision@k of a FAISS IVF-Flat index depends on the number
# of probed cells (nprobe) for random unit vectors of various dimensionality.

# number of inverted lists (cells) in the IVF index
nlist = 30
# number of results to return from each query
k = 10
# dataset size
nb = 100000
# number of queries used to calculate mean precision
nq = 10000
# vector dimensionalities to benchmark
dimensions = [5, 10, 32, 64, 128, 256, 512, 1024, 1536]


def random_unit_vectors(count, dim):
    """Return ``count`` random float32 vectors of length ``dim`` scaled to unit L2 norm.

    Samples are drawn from NumPy's global RNG (standard normal), so the output
    is reproducible after a call to ``np.random.seed``.
    """
    vectors = np.random.normal(size=(count, dim)).astype('float32')
    # keepdims=True keeps the norms as a column so the division broadcasts per row
    return vectors / np.linalg.norm(vectors, axis=1, keepdims=True)


def mean_precision_at_k(correct_results, q_results, k):
    """Return the mean precision@k of ``q_results`` against ``correct_results``.

    Both arguments are row-aligned sequences of neighbour ids (one row per
    query).  A row's precision is the number of distinct ids it shares with
    the matching reference row, divided by ``k``.

    Raises ZeroDivisionError if there are no rows or ``k`` is zero.
    """
    per_query = [
        float(len(set(q_row).intersection(correct_row))) / k
        for correct_row, q_row in zip(correct_results, q_results)
    ]
    return sum(per_query) / len(per_query)


if __name__ == "__main__":
    # seed RNG for reproducible result
    np.random.seed(1234)

    # collected (vector_dim, probes, precision) records
    precision = []

    for d in dimensions:
        # Vector dataset, then query dataset (order matters for the seeded RNG)
        xb = random_unit_vectors(nb, d)
        xq = random_unit_vectors(nq, d)

        # Define the index: a flat L2 index serves as the quantizer of the IVF index
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)

        # Train the index
        assert not index.is_trained
        index.train(xb)
        assert index.is_trained

        # Load dataset into the index
        index.add(xb)

        # Collect the accurate results by probing every list (nprobe == nlist).
        # Use the same k as the approximate searches so the rows are comparable.
        index.nprobe = nlist
        _, correct_results = index.search(xq, k=k)

        # For each number of probes below nlist, calculate precision@k
        for nprobes in range(1, nlist):
            index.nprobe = nprobes
            # Results for the given number of probes
            _, q_results = index.search(xq, k=k)

            # Add result to the output (dimensions, nprobes, precision)
            rec = (d, nprobes, mean_precision_at_k(correct_results, q_results, k))
            print(rec)
            precision.append(rec)

    # Load results into a dataframe for plotting
    df = pd.DataFrame(precision, columns=['vector_dim', 'probes', 'precision'])

    # Plot the results: one line per vector dimensionality
    sns.lineplot(
        data=df,
        x='probes',
        y='precision',
        hue='vector_dim'
    ).set(title='precision@k={} by probes'.format(k))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment