Skip to content

Instantly share code, notes, and snippets.

Created July 19, 2024 11:09
Script to perform image similarity search using pgvector
import os
import psycopg2
import numpy as np
from PIL import Image
from scipy.spatial.distance import cosine
import torch
from torchvision import models, transforms
# Database configuration. Each setting can be overridden through the environment
# variable of the same name so real credentials need not be committed to source;
# the literals below are placeholders used only when the variable is unset.
DB_NAME = os.environ.get("DB_NAME", "database")
DB_USER = os.environ.get("DB_USER", "user")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "password")
DB_HOST = os.environ.get("DB_HOST", "host")
DB_PORT = os.environ.get("DB_PORT", "port")
# Load a ResNet-50 pre-trained on ImageNet. IMAGENET1K_V1 is the checkpoint the
# deprecated ``pretrained=True`` flag selected, so embeddings stay comparable with
# any already stored in the database.
try:
    _weights = models.ResNet50_Weights.IMAGENET1K_V1
except AttributeError:  # torchvision < 0.13 has no weights enums
    model = models.resnet50(pretrained=True)
else:
    model = models.resnet50(weights=_weights)
# Drop the final classification layer so the network outputs pooled features
# (used as the embedding) instead of class logits.
model = torch.nn.Sequential(*list(model.children())[:-1])
# Switch to inference mode (affects BatchNorm/Dropout). Done after wrapping so the
# Sequential container itself also reports eval mode, not just its children.
model.eval()
# Image preprocessing: the standard ImageNet evaluation pipeline for ResNet-50.
# ToTensor must precede Normalize, because Normalize operates on tensors, not PIL images.
# NOTE(review): this must match the pipeline used to populate image_embeddings,
# otherwise query and stored embeddings are not comparable — confirm.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def get_image_embedding(image_path):
    """Compute an embedding for one image using the truncated ResNet model.

    Args:
        image_path: Path of the image file (anything ``PIL.Image.open`` accepts).

    Returns:
        A ``numpy.ndarray`` with the model output for the image, with all
        singleton dimensions (batch and pooled spatial dims) squeezed out.
    """
    # Open inside a context manager so the file handle is released. Convert to RGB
    # so grayscale/RGBA/palette images get the 3 channels Normalize expects.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # Add a leading batch dimension of size 1 for the model.
    image_tensor = preprocess(image).unsqueeze(0)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        embedding = model(image_tensor)
    return embedding.squeeze().numpy()
def find_most_similar_images(query_embedding, top_n=3):
    """Return the ``top_n`` stored images nearest to ``query_embedding``.

    Ranks rows of ``image_embeddings`` by pgvector's ``<->`` operator (L2 distance)
    between their ``embedding`` column and the query vector.

    Args:
        query_embedding: Array-like of numbers, e.g. the output of
            ``get_image_embedding``; it is flattened before being sent.
        top_n: Maximum number of rows to return.

    Returns:
        A list of ``(image_path, distance)`` tuples, ordered by ascending distance.
    """
    # Serialise to pgvector's text input format '[x1,x2,...]'.
    vector_literal = "[" + ",".join(str(float(x)) for x in np.ravel(query_embedding)) + "]"
    # Values are bound as query parameters instead of being formatted into the SQL
    # text. The distance expression is repeated in ORDER BY (the form pgvector
    # documents for nearest-neighbour search) so an index can serve the ordering.
    query = (
        "SELECT image_path, embedding <-> %(query)s::vector AS distance "
        "FROM image_embeddings "
        "ORDER BY embedding <-> %(query)s::vector "
        "LIMIT %(limit)s;"
    )
    conn = psycopg2.connect(
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT
    )
    try:
        with conn.cursor() as cursor:
            cursor.execute(query, {"query": vector_literal, "limit": top_n})
            return cursor.fetchall()
    finally:
        # psycopg2's connection context manager only ends the transaction and does
        # not close the connection, so close it explicitly.
        conn.close()
def main(query_image_path, top_n=3):
    """Embed the query image and print the closest stored images.

    Args:
        query_image_path: Path of the image to search with.
        top_n: Number of matches to report (default 3).
    """
    query_embedding = get_image_embedding(query_image_path)
    similar_images = find_most_similar_images(query_embedding, top_n=top_n)
    print("Most similar images:")
    if not similar_images:
        print("No images found.")
    # The second column is an L2 distance: smaller means more similar.
    for image_path, distance in similar_images:
        print(f"Image: {image_path}, Distance: {distance}")
if __name__ == "__main__":
    import sys

    # Take the query image path from the command line; fall back to the placeholder.
    query_image_path = sys.argv[1] if len(sys.argv) > 1 else "path/to/your/query/image.jpg"
    main(query_image_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment