Code for the Medium post Link
Last active
February 27, 2022 12:32
-
-
Save stephenleo/521447cd041bca73bedc7b136b8e8e41 to your computer and use it in GitHub Desktop.
[Medium] Powering Semantic Similarity Search in Computer Vision with State of the Art Embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create a directory for notebooks and another to download data
mkdir -p semantic_similarity/notebooks semantic_similarity/data/cv
# cd into the data directory
cd semantic_similarity/data/cv
# Create and activate a conda environment.
# NOTE: `conda activate` only works in a shell where conda's shell hook is
# loaded (e.g. after `conda init`). When running these lines as a script, run
#   eval "$(conda shell.bash hook)"
# before the activate command.
conda create -n semantic_similarity python=3.8
conda activate semantic_similarity
## Create a virtual environment using venv if not using conda
# python -m venv semantic_similarity
# source semantic_similarity/bin/activate
# Pip install the necessary libraries
pip install jupyterlab kaggle pandas matplotlib scikit-learn tqdm ipywidgets
# Download the data using the Kaggle API (needs a configured Kaggle API token)
kaggle datasets download -d masouduut94/digikala-color-classification
# Unzip the data into a fashion/ directory
unzip digikala-color-classification.zip -d ./fashion
# Delete the zip file
rm digikala-color-classification.zip
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import shutil
import sys
from random import randint

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Register `progress_apply` on pandas objects so that .apply()-style calls
# show a tqdm progress bar.
tqdm.pandas()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def move_to_root_folder(root_path, cur_path):
    '''Recursively move every file below ``cur_path`` directly into ``root_path``.

    Each sub-directory is removed once it has been emptied, so calling this
    with ``root_path == cur_path`` flattens a directory tree in place.

    Adapted from
    https://stackoverflow.com/questions/8428954/move-child-folder-contents-to-parent-folder-in-python

    Args:
        root_path: Directory that receives all the files.
        cur_path: Directory currently being flattened.

    Raises:
        FileExistsError: if a file to be moved has the same name as an entry
            already in ``root_path``; moving it would silently overwrite data.
        OSError: if an entry is neither a regular file nor a directory
            (for example a broken symlink).
    '''
    for filename in os.listdir(cur_path):
        src = os.path.join(cur_path, filename)
        if os.path.isfile(src):
            dst = os.path.join(root_path, filename)
            # Files that already sit in root_path stay where they are.
            if os.path.abspath(src) == os.path.abspath(dst):
                continue
            # Refuse to clobber an existing file with the same name from a
            # different sub-folder.
            if os.path.exists(dst):
                raise FileExistsError(
                    f'Cannot move {src!r}: {dst!r} already exists')
            shutil.move(src, dst)
        elif os.path.isdir(src):
            move_to_root_folder(root_path, src)
        else:
            # Raise instead of sys.exit() so a notebook kernel / caller can
            # handle the error rather than being asked to terminate.
            raise OSError(f'Should never reach here: unsupported entry {src!r}')
    # Remove the now-empty sub-directory (never the root itself).
    if cur_path != root_path:
        os.rmdir(cur_path)
# Flatten the unzipped dataset so that every file sits directly inside fashion/
fashion_dir = '../data/cv/fashion'
move_to_root_folder(root_path=fashion_dir, cur_path=fashion_dir)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Path to all the downloaded images
img_path = '../data/cv/fashion'
# Only these extensions are loaded; any other file in the folder would make
# the image loaders used later (PIL / torchvision.io.read_image) fail.
# NOTE(review): extend this tuple if the dataset contains other image formats.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
# Find all image files in the path. os.listdir returns entries in arbitrary
# order, so sort them to make the DataFrame row order (and thus the row index
# used by the similarity helpers) reproducible.
images = sorted(
    os.path.join(img_path, f)
    for f in os.listdir(img_path)
    if os.path.isfile(os.path.join(img_path, f))
    and f.lower().endswith(IMAGE_EXTENSIONS)
)
# Load the file names to a dataframe (one row per image, default RangeIndex)
image_df = pd.DataFrame(images, columns=['img_path'])
print(image_df.shape)
image_df.head()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def flatten_pixels(img_path):
    '''Load an image as RGB and return its pixels as a 1-D array scaled to [0, 1].

    The vector length is height * width * 3, so images of different sizes
    produce vectors of different lengths that cannot be compared directly.

    Args:
        img_path: Path to the image file.

    Returns:
        numpy.ndarray: flattened pixel values divided by 255.
    '''
    # Context manager closes the file deterministically instead of relying on
    # Pillow / garbage collection to release the handle.
    with Image.open(img_path) as img:
        rgb = img.convert('RGB')
    # An RGB image converts to a (height, width, 3) 8-bit array; flatten it
    # and normalise to [0, 1].
    return np.asarray(rgb).reshape(-1) / 255.
# Apply the transformation to the dataframe.
# Warning! This runs on a random subset of only 1K rows of the data;
# your computer might crash if you run it on the entire dataset!
# Better not to run it at all: there are much better ways to generate embeddings!
SAMPLE_SIZE = 1_000
# Cap the sample at the number of rows: DataFrame.sample raises ValueError
# when asked for more rows than exist (sampling without replacement).
# reset_index already returns a new DataFrame, so no extra .copy() is needed.
pixels_df = image_df.sample(min(SAMPLE_SIZE, len(image_df))).reset_index(drop=True)
pixels_df['flattened_pixels'] = pixels_df['img_path'].progress_apply(flatten_pixels)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Activate the conda environment if you have not already done so
# conda activate semantic_similarity
pip install towhee torch torchvision
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from towhee import pipeline

# Build towhee's 'image-embedding' pipeline. Below, it is called with an image
# path and its (squeezed) output is stored as that image's embedding.
# NOTE(review): this is the legacy `towhee.pipeline()` API; pin the towhee
# version used for the post if newer releases no longer provide it.
embedding_pipeline = pipeline('image-embedding')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _towhee_embed(path):
    '''Return the towhee embedding of the image at ``path`` with singleton axes removed.'''
    # np.squeeze drops size-1 dimensions; presumably the pipeline returns a
    # (1, dim) array, so each row ends up holding a plain (dim,) vector.
    return np.squeeze(embedding_pipeline(path))


image_df['towhee_img_embedding'] = image_df['img_path'].progress_apply(_towhee_embed)
image_df.head()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot_similar(df, embedding_col, query_index, k_neighbors=5):
    '''Display a query image and its k nearest neighbours by cosine similarity.

    Args:
        df: DataFrame with an ``img_path`` column and an embedding column.
        embedding_col: Name of the column holding one embedding vector per row.
        query_index: Index label of the query row in ``df``.
        k_neighbors: Number of neighbours to show. The query itself is
            excluded, and the count is capped at ``len(df) - 1``.
    '''
    # Work in positions throughout so that the query and the numpy indices
    # below refer to the same rows even if df has a non-default index.
    query_pos = df.index.get_loc(query_index)
    embeddings = df[embedding_col].tolist()
    img_paths = df['img_path'].tolist()

    # Calculate pairwise cosine similarities between the query and all rows
    similarities = cosine_similarity([embeddings[query_pos]], embeddings)[0]

    # Exclude the query explicitly. If another row ties with it at similarity
    # 1.0 (e.g. a duplicate image), the query is not guaranteed to be among
    # the top k+1, and dropping it afterwards would leave k+1 neighbours.
    candidates = np.array(similarities, dtype=float)
    candidates[query_pos] = -np.inf
    k = min(k_neighbors, len(candidates) - 1)
    if k > 0:
        top_k = np.argpartition(candidates, -k)[-k:]
        # argpartition leaves the top k unordered; show most similar first.
        nearest_indices = top_k[np.argsort(candidates[top_k])[::-1]]
    else:
        nearest_indices = np.array([], dtype=int)

    # Plot the query image in its own figure (never into a previous figure)
    plt.figure()
    with Image.open(img_paths[query_pos]) as img:
        plt.imshow(img.convert('RGB'))
    plt.title(f'Query Product.\nIndex: {query_index}')

    if len(nearest_indices) == 0:
        return

    # Plot the nearest-neighbour images side by side
    plt.figure(figsize=(20, 4))
    plt.suptitle('Similar Products')
    for plot_num, neighbor in enumerate(nearest_indices, start=1):
        plt.subplot(1, len(nearest_indices), plot_num)
        with Image.open(img_paths[neighbor]) as img:
            plt.imshow(img.convert('RGB'))
        plt.title(f'Cosine Sim: {similarities[neighbor]:.3f}')
    plt.tight_layout()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# randint's upper bound is inclusive, so use len - 1 to stay inside the index.
plot_similar(df=image_df,
             embedding_col='towhee_img_embedding',
             query_index=randint(0, len(image_df) - 1),  # Query a random image
             k_neighbors=5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Activate the conda environment if you have not already done so
# conda activate semantic_similarity
pip install lightning-bolts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch.utils.data import DataLoader, Dataset
# NOTE: `io` here is torchvision.io and shadows the standard-library `io` module.
from torchvision import io, transforms

from pl_bolts.models.self_supervised import SimCLR

# Use GPU if it is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load a ResNet-50 pre-trained with SimCLR on ImageNet (pl_bolts checkpoint)
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
# NOTE(review): strict=False ignores missing/unexpected keys in the checkpoint,
# so a key mismatch would go unnoticed. Confirm that the encoder weights load.
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False, batch_size=32)
# Only the encoder (backbone) is needed for embeddings: move it to the device
# and switch it to eval mode (BatchNorm uses running statistics).
simclr_resnet50 = simclr.encoder.to(device)
# The trailing ';' suppresses Jupyter's echo of the module repr.
simclr_resnet50.eval();
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create a dataset for PyTorch
class FashionImageDataset(Dataset):
    '''Map-style dataset that yields one image tensor per row of ``df``.

    Args:
        df: DataFrame with an ``img_path`` column.
        transform: Optional callable applied to each loaded image tensor.
    '''

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # The DataLoader asks for positions 0..len-1, so look the row up by
        # position (.iloc) rather than by index label; the two only coincide
        # when df has a default RangeIndex.
        img_path = self.df['img_path'].iloc[idx]
        # Decode as 3-channel RGB (uint8, C x H x W) and scale to [0, 1] floats.
        image = io.read_image(img_path, mode=io.image.ImageReadMode.RGB) / 255.
        # Apply transformations
        if self.transform is not None:
            image = self.transform(image)
        return image
# Transforms
## Normalize: shift/scale each channel by the ImageNet mean and std so the
##   inputs have intensity distributions similar to ImageNet's.
## Resize: every image in a batch must have the same size to be stacked.
## NOTE(review): 64x64 is much smaller than typical ImageNet input sizes;
##   a larger size may give better embeddings at higher compute cost.
transformations = transforms.Compose([
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    transforms.Resize(size=(64, 64)),
])
# Create the DataLoader to load images in batches. shuffle=False (the default,
# made explicit) keeps batches in image_df row order, which is what lets the
# embeddings be assigned straight back to image_df.
emb_dataset = FashionImageDataset(df=image_df, transform=transformations)
emb_dataloader = DataLoader(emb_dataset, batch_size=32, shuffle=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create embeddings
embeddings = []
# Inference only: no_grad skips building the autograd graph, saving memory and time.
with torch.no_grad():
    for batch in tqdm(emb_dataloader):
        batch = batch.to(device)
        # NOTE(review): the encoder output is indexed with [0]; presumably it is
        # a list whose first element is the (batch, feature_dim) tensor.
        # Confirm against pl_bolts' ResNet.
        features = simclr_resnet50(batch)[0]
        embeddings.extend(features.cpu().tolist())
# Assign embeddings to a column in the dataframe (rows line up because the
# DataLoader above does not shuffle).
image_df['simclr_img_embeddings'] = embeddings
image_df.head()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# randint's upper bound is inclusive, so use len - 1 to stay inside the index.
plot_similar(df=image_df,
             embedding_col='simclr_img_embeddings',
             query_index=randint(0, len(image_df) - 1),
             k_neighbors=5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Activate the conda environment if you have not already done so
# conda activate semantic_similarity
pip install sentence_transformers ftfy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sentence_transformers import SentenceTransformer

# CLIP ViT-B/32 loaded through sentence-transformers. The same `model` is used
# below to encode both images and text queries.
model = SentenceTransformer('clip-ViT-B-32')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _load_rgb(path):
    '''Load the image at ``path`` as an in-memory RGB image and close the file.'''
    # convert() returns a new, fully loaded image, so the file can be closed
    # right away instead of relying on Pillow / garbage collection.
    with Image.open(path) as img:
        return img.convert('RGB')


# Initialize an empty list to collect embeddings
clip_embeddings = []
# Load and encode `step` images per iteration, so at most that many decoded
# images are held in memory at once.
step = 10_000
for start in range(0, len(image_df), step):
    chunk_images = [_load_rgb(p) for p in image_df['img_path'].iloc[start:start + step]]
    # Generate CLIP embeddings for the loaded images
    clip_embeddings.extend(model.encode(chunk_images, show_progress_bar=True).tolist())
# Assign the embeddings back to the dataframe (same row order as image_df)
image_df['clip_img_embedding'] = clip_embeddings
image_df.head()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# randint's upper bound is inclusive, so use len - 1 to stay inside the index.
plot_similar(df=image_df,
             embedding_col='clip_img_embedding',
             query_index=randint(0, len(image_df) - 1),
             k_neighbors=5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def text_image_search(text_query, df, img_emb_col, k=5):
    '''Display the k images whose embeddings are most similar to a text query.

    Uses the module-level ``model`` to embed the text, so ``img_emb_col``
    must hold image embeddings from the same embedding space (e.g. the CLIP
    column produced with that model).

    Args:
        text_query: Free-text description to search for.
        df: DataFrame with an ``img_path`` column and an image-embedding column.
        img_emb_col: Name of the column holding the image embeddings.
        k: Number of images to show (capped at the number of rows).
    '''
    # Calculate the text embedding
    text_emb = model.encode(text_query).tolist()
    img_embs = df[img_emb_col].tolist()
    img_paths = df['img_path'].tolist()
    # Cosine similarity between the text query and every image
    similarities = cosine_similarity([text_emb], img_embs)[0]

    print(f'Query Text: {text_query}')

    # Guard k: argpartition fails for k > n, and k == 0 would select everything.
    k = min(k, len(similarities))
    if k <= 0:
        return
    top_k = np.argpartition(similarities, -k)[-k:]
    # argpartition leaves the top k unordered; show most similar first.
    # These are positions, so the image paths are looked up positionally too.
    nearest_indices = top_k[np.argsort(similarities[top_k])[::-1]]

    # Plot the nearest-neighbour images side by side
    plt.figure(figsize=(20, 4))
    plt.suptitle('Similar Products')
    for plot_num, neighbor in enumerate(nearest_indices, start=1):
        plt.subplot(1, len(nearest_indices), plot_num)
        with Image.open(img_paths[neighbor]) as img:
            plt.imshow(img.convert('RGB'))
        plt.title(f'Cosine Sim: {similarities[neighbor]:.3f}')
    plt.tight_layout()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Search the CLIP image embeddings with a free-text query
text_query = "a photo of a women's dress"
text_image_search(text_query,
                  df=image_df,
                  img_emb_col='clip_img_embedding',
                  k=5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment