Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save stephenleo/521447cd041bca73bedc7b136b8e8e41 to your computer and use it in GitHub Desktop.
Save stephenleo/521447cd041bca73bedc7b136b8e8e41 to your computer and use it in GitHub Desktop.
[Medium] Powering Semantic Similarity Search in Computer Vision with State of the Art Embeddings

Powering Semantic Similarity Search in Computer Vision with State of the Art Embeddings

Code for the accompanying Medium post.

# Create a directory for notebooks and another to hold the downloaded data
mkdir -p semantic_similarity/notebooks semantic_similarity/data/cv
# Change into the data directory so the dataset is downloaded and unzipped there
cd semantic_similarity/data/cv
# Create and activate a conda environment pinned to Python 3.8
conda create -n semantic_similarity python=3.8
conda activate semantic_similarity
## Alternative: create a virtual environment with venv if you are not using conda
# python -m venv semantic_similarity
# source semantic_similarity/bin/activate
# Pip install the libraries needed by the notebooks
pip install jupyterlab kaggle pandas matplotlib scikit-learn tqdm ipywidgets
# Download the dataset archive using the kaggle command-line tool
# NOTE(review): the kaggle CLI presumably needs API credentials configured beforehand -- confirm
kaggle datasets download -d masouduut94/digikala-color-classification
# Unzip the archive into a fashion/ directory next to it
unzip digikala-color-classification.zip -d ./fashion
# Delete the zip file now that its contents have been extracted
rm digikala-color-classification.zip
from matplotlib import pyplot as plt
import numpy as np
import os
import pandas as pd
from PIL import Image
from random import randint
import shutil
from sklearn.metrics.pairwise import cosine_similarity
import sys
from tqdm import tqdm
tqdm.pandas()
def move_to_root_folder(root_path, cur_path):
    '''Recursively move every file found under ``cur_path`` into ``root_path``.

    Sub-folders are walked depth-first; once a sub-folder's contents have been
    moved, the (now empty) sub-folder is removed. The root folder itself is
    never removed.

    Args:
        root_path: Folder that should end up containing all the files.
        cur_path: Folder currently being flattened. Pass the same folder as
            ``root_path`` for the initial call.

    Note:
        Files are moved to ``root_path/<filename>``, so two files with the
        same name in different sub-folders target the same destination.
        NOTE(review): depending on the platform the later one may overwrite
        the earlier one or raise -- confirm file names are unique.
    '''
    # Code adapted from https://stackoverflow.com/questions/8428954/move-child-folder-contents-to-parent-folder-in-python
    for filename in os.listdir(cur_path):
        source = os.path.join(cur_path, filename)
        if os.path.isfile(source):
            shutil.move(source, os.path.join(root_path, filename))
        elif os.path.isdir(source):
            move_to_root_folder(root_path, source)
        else:
            sys.exit("Should never reach here.")
    # Remove the emptied sub-folder, but never the root. Compare normalized
    # absolute paths so that an equivalent spelling of the root (for example
    # with a trailing separator) is still recognised as the root.
    if os.path.abspath(cur_path) != os.path.abspath(root_path):
        os.rmdir(cur_path)
# Path to all the downloaded images (single source of truth for every use below)
img_path = '../data/cv/fashion'
# Flatten the extracted archive so every image sits directly inside img_path
move_to_root_folder(root_path=img_path, cur_path=img_path)
# Find list of all files in the path.
# os.listdir returns entries in arbitrary order, so sort them to give each
# image the same dataframe row index on every run.
images = [
    f'{img_path}/{f}'
    for f in sorted(os.listdir(img_path))
    if os.path.isfile(os.path.join(img_path, f))
]
# Load the file names to a dataframe
image_df = pd.DataFrame(images, columns=['img_path'])
print(image_df.shape)
image_df.head()
def flatten_pixels(img_path):
    '''Return the pixels of the image at ``img_path`` as a flat, scaled array.

    The image is opened, converted to RGB, turned into a numpy array,
    reshaped to one dimension and divided by 255.
    '''
    rgb_image = Image.open(img_path).convert('RGB')
    pixel_array = np.array(rgb_image)
    # Collapse all axes into one, then scale the values by 1/255
    return pixel_array.reshape(-1) / 255.
# Apply the transformation to the dataframe.
# Warning! This runs only on a random subset of 1K rows of the data;
# your computer might crash if you run it on the entire dataset!
# It is better not to run it at all: there are much better ways to generate embeddings.
pixels_df = image_df.sample(n=1_000)
pixels_df = pixels_df.reset_index(drop=True).copy()
pixels_df['flattened_pixels'] = pixels_df['img_path'].progress_apply(
    flatten_pixels
)
# Activate the conda environment if you have not already done so
# conda activate semantic_similarity
# Install towhee together with torch and torchvision
pip install towhee torch torchvision
from towhee import pipeline

# Build the towhee 'image-embedding' pipeline once and reuse it for every row
embedding_pipeline = pipeline('image-embedding')
# Run the pipeline on each image path and squeeze the result with numpy
image_df['towhee_img_embedding'] = image_df['img_path'].progress_apply(
    lambda path: np.squeeze(embedding_pipeline(path))
)
image_df.head()
def plot_similar(df, embedding_col, query_index, k_neighbors=5):
    '''Display a query image and its k nearest neighbors by cosine similarity.

    Args:
        df: DataFrame holding an 'img_path' column and the embedding column.
            Row positions produced by the similarity ranking are used
            directly as index labels, so a default 0..n-1 index is expected.
        embedding_col: Name of the column that stores one embedding per row.
        query_index: Index of the row to use as the query.
        k_neighbors: Number of neighbors to display. If fewer other rows
            exist, all of them are shown.
    '''
    # Calculate pairwise cosine similarities between query and all rows
    similarities = cosine_similarity([df[embedding_col][query_index]], df[embedding_col].values.tolist())[0]
    # Rank every row from most to least similar, drop the query itself, then
    # keep the first k. This shows exactly k neighbors in descending order of
    # similarity and also works when the dataframe has fewer than k+1 rows.
    ranked = np.argsort(similarities)[::-1]
    nearest_indices = ranked[ranked != query_index][:max(k_neighbors, 0)]
    # Plot input image on its own figure so it never lands on a figure left
    # over from an earlier call
    plt.figure()
    img = Image.open(df['img_path'][query_index]).convert('RGB')
    plt.imshow(img)
    plt.title(f'Query Product.\nIndex: {query_index}')
    # Plot nearest neighbors images side by side
    plt.figure(figsize=(20,4))
    plt.suptitle('Similar Products')
    for idx, neighbor in enumerate(nearest_indices):
        plt.subplot(1, len(nearest_indices), idx+1)
        img = Image.open(df['img_path'][neighbor]).convert('RGB')
        plt.imshow(img)
        plt.title(f'Cosine Sim: {similarities[neighbor]:.3f}')
    plt.tight_layout()
# random.randint includes BOTH endpoints, so the upper bound must be
# len(image_df) - 1 to always land on an existing row.
plot_similar(df=image_df,
             embedding_col='towhee_img_embedding',
             query_index=randint(0, len(image_df) - 1),  # Query a random image
             k_neighbors=5)
# Activate the conda environment if you have not already done so
# conda activate semantic_similarity
# Install lightning-bolts, which provides the pl_bolts package imported below
pip install lightning-bolts
from pl_bolts.models.self_supervised import SimCLR
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import io, transforms
# Pick the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Load the SimCLR checkpoint (resnet50 pre-trained using SimCLR on imagenet)
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False, batch_size=32)
# Move the SimCLR encoder to the device and put it in eval mode;
# both calls return the module, so they can be chained
simclr_resnet50 = simclr.encoder.to(device).eval()
# Create a dataset for Pytorch
class FashionImageDataset(Dataset):
    '''Dataset that reads the images listed in a dataframe's 'img_path' column.

    Args:
        df: DataFrame with an 'img_path' column of image file paths.
        transform: Optional callable applied to each loaded image tensor.
    '''

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        # One sample per dataframe row
        return len(self.df)

    def __getitem__(self, idx):
        '''Load the image at row position ``idx``, scaled by 1/255.'''
        # Look the row up by POSITION: samplers hand out 0..len-1, which only
        # coincides with index labels when the dataframe has a default index.
        img_path = self.df['img_path'].iloc[idx]
        # Read the file as RGB and divide by 255
        image = io.read_image(img_path, mode=io.ImageReadMode.RGB)/255.
        # Apply Transformations
        if self.transform:
            image = self.transform(image)
        return image
# Transforms, applied in the order listed:
## 1. Normalize, so the images have similar intensity distributions as ImageNet
## 2. Resize, so all images in a batch have the same size (64 x 64)
transformations = transforms.Compose([
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    transforms.Resize((64, 64)),
])
# Create the DataLoader to load images in batches
emb_dataset = FashionImageDataset(df=image_df, transform=transformations)
emb_dataloader = DataLoader(emb_dataset, batch_size=32)
# Create embeddings
embeddings = []
# This is inference only: disable autograd so no gradient bookkeeping is
# built for each batch.
with torch.no_grad():
    for batch in tqdm(emb_dataloader):
        batch = batch.to(device)
        # NOTE(review): the encoder output is indexed with [0]; presumably it
        # returns a sequence whose first item is the per-image feature batch -- confirm
        embeddings.extend(simclr_resnet50(batch)[0].tolist())
# Assign embeddings to a column in the dataframe
image_df['simclr_img_embeddings'] = embeddings
image_df.head()
# random.randint includes BOTH endpoints, so the upper bound must be
# len(image_df) - 1 to always land on an existing row.
plot_similar(df=image_df,
             embedding_col='simclr_img_embeddings',
             query_index=randint(0, len(image_df) - 1),  # Query a random image
             k_neighbors=5)
# Activate the conda environment if you have not already done so
# conda activate semantic_similarity
# Install sentence_transformers (imported below to load the CLIP model) and ftfy
pip install sentence_transformers ftfy
from sentence_transformers import SentenceTransformer

# Load the 'clip-ViT-B-32' model through sentence_transformers
model = SentenceTransformer('clip-ViT-B-32')
# Embeddings are collected here chunk by chunk
clip_embeddings = []
# Number of images loaded and encoded per iteration
step = 10_000
for idx in range(0, len(image_df), step):
    # Paths of the current chunk of at most `step` images
    chunk_paths = image_df['img_path'].iloc[idx:idx+step]
    # Open every image of the chunk as RGB
    images = [Image.open(path).convert('RGB') for path in chunk_paths]
    # Encode the chunk with CLIP and store the vectors as plain lists
    chunk_embeddings = model.encode(images, show_progress_bar=True)
    clip_embeddings.extend(chunk_embeddings.tolist())
# Write the collected embeddings back to the dataframe
image_df['clip_img_embedding'] = clip_embeddings
image_df.head()
# random.randint includes BOTH endpoints, so the upper bound must be
# len(image_df) - 1 to always land on an existing row.
plot_similar(df=image_df,
             embedding_col='clip_img_embedding',
             query_index=randint(0, len(image_df) - 1),  # Query a random image
             k_neighbors=5)
def text_image_search(text_query, df, img_emb_col, k=5):
    '''Display the k images whose embeddings are most similar to a text query.

    The text is encoded with the module-level ``model`` and compared against
    every row's image embedding using cosine similarity.

    Args:
        text_query: Text to search with.
        df: DataFrame holding an 'img_path' column and the embedding column.
            Row positions produced by the similarity ranking are used
            directly as index labels, so a default 0..n-1 index is expected.
        img_emb_col: Name of the column that stores one image embedding per row.
        k: Number of images to display. If the dataframe has fewer rows,
            all of them are shown.
    '''
    # Calculate the text embeddings
    text_emb = model.encode(text_query).tolist()
    # Calculate the pairwise cosine similarities between text query and images from all rows
    similarities = cosine_similarity([text_emb], df[img_emb_col].values.tolist())[0]
    # Rank every row from most to least similar and keep the first k, so the
    # results are shown best-first and k may exceed the number of rows
    nearest_indices = np.argsort(similarities)[::-1][:max(k, 0)]
    # Print Query Text
    print(f'Query Text: {text_query}')
    # Plot nearest neighbors images side by side
    plt.figure(figsize=(20,4))
    plt.suptitle('Similar Products')
    for idx, neighbor in enumerate(nearest_indices):
        plt.subplot(1, len(nearest_indices), idx+1)
        img = Image.open(df['img_path'][neighbor]).convert('RGB')
        plt.imshow(img)
        plt.title(f'Cosine Sim: {similarities[neighbor]:.3f}')
    plt.tight_layout()
# Free-text description of the product to look for
text_query = "a photo of a women's dress"
# Show the 5 images whose CLIP embeddings best match the text
text_image_search(
    text_query=text_query,
    df=image_df,
    img_emb_col='clip_img_embedding',
    k=5,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment