Skip to content

Instantly share code, notes, and snippets.

@mayukh18
Created January 4, 2022 06:37
Show Gist options
  • Save mayukh18/f2aea99e0fb04e7a075bf5fd95f0786b to your computer and use it in GitHub Desktop.
Save mayukh18/f2aea99e0fb04e7a075bf5fd95f0786b to your computer and use it in GitHub Desktop.
#!pip install faiss-gpu
import faiss
# Build an exact (brute-force) L2 nearest-neighbour index over the embeddings
# of every training image found under PATH_TRAIN/<subdir>/<file>.
# The dimension passed here must equal the length of the model's output
# vector; 1000 presumably matches a 1000-way classifier head -- TODO confirm.
faiss_index = faiss.IndexFlatL2(1000)
# im_indices[i] is the path of the i-th vector added to faiss_index, so an id
# returned by faiss_index.search() can be mapped back to its image file.
im_indices = []
# NOTE(review): embeddings are only deterministic if `model` is in eval mode
# (model.eval()); that call is not visible here -- confirm it happens earlier.
with torch.no_grad():
    # glob order is filesystem-dependent; sorting makes the id -> path mapping
    # reproducible across runs and machines.
    for f in sorted(glob.glob(os.path.join(PATH_TRAIN, '*/*'))):
        # The context manager closes the file handle PIL keeps open.
        # convert("RGB") guarantees 3 channels even for grayscale, RGBA or
        # palette images, which would otherwise break the transforms/model.
        with Image.open(f) as raw:
            img = raw.convert("RGB").resize((224, 224))
        # Add the batch dimension directly on the tensor produced by
        # val_transforms instead of round-tripping through numpy.
        im = val_transforms(img).unsqueeze(0).cuda()
        # Shape (1, D). faiss requires float32, C-contiguous host memory.
        preds = model(im).float().cpu().contiguous().numpy()
        faiss_index.add(preds)  # vector id == len(im_indices) before append
        im_indices.append(f)    # keep the path aligned with that id
# Retrieval: embed each query image in PATH_TEST and print the path of its
# nearest neighbour among the images previously added to faiss_index.
# An empty index would make faiss return id -1, which then fails (or silently
# misindexes) on im_indices; fail early with a clear message instead.
if faiss_index.ntotal == 0:
    raise RuntimeError("faiss_index is empty; build it before running retrieval")
# NOTE(review): `model` should be in eval mode (model.eval()) for stable
# embeddings; that call is not visible here -- confirm it happens earlier.
with torch.no_grad():
    # os.listdir order is arbitrary; sort so the output order is reproducible.
    for f in sorted(os.listdir(PATH_TEST)):
        path = os.path.join(PATH_TEST, f)
        if not os.path.isfile(path):
            continue  # skip sub-directories and other non-file entries
        # Close the file handle and force 3 channels (see the index build).
        with Image.open(path) as raw:
            img = raw.convert("RGB").resize((224, 224))
        im = val_transforms(img).unsqueeze(0).cuda()
        # Shape (1, D). faiss requires float32, C-contiguous host memory.
        test_embed = model(im).float().cpu().contiguous().numpy()
        # Only the single nearest neighbour is reported, so k=1 suffices.
        _, I = faiss_index.search(test_embed, 1)
        # Name the query too; otherwise output lines cannot be attributed.
        print("Query {} -> Retrieved Image: {}".format(f, im_indices[I[0][0]]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment