Code for the Medium post Link
Last active
April 1, 2024 12:45
-
-
Save stephenleo/12e9540843af00293a1ed7fb71641b34 to your computer and use it in GitHub Desktop.
[Medium] Supercharged Semantic Similarity Search in Production
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create the necessary directories
mkdir -p milvus_image_search/notebooks milvus_image_search/data milvus_image_search/milvus

# CD into the data directory
cd milvus_image_search/data

# Create and activate a conda environment
conda create -n image_sim python=3.8 -y
conda activate image_sim

## Create Virtual Environment using venv if not using conda
# python -m venv image_sim
# source image_sim/bin/activate

# Pip install the necessary libraries
pip install jupyterlab kaggle matplotlib scikit-learn tqdm ipywidgets
pip install pandas==1.3.5 pymilvus==2.0.0
pip install sentence_transformers ftfy

# Download data using the kaggle API
kaggle competitions download -c h-and-m-personalized-fashion-recommendations

# Unzip the data into the local directory and delete the Zip file
# only if the extraction succeeded, so a failed unzip never loses the download
unzip h-and-m-personalized-fashion-recommendations.zip \
  && rm h-and-m-personalized-fashion-recommendations.zip
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from matplotlib import pyplot as plt | |
import pandas as pd | |
from pathlib import Path | |
from PIL import Image | |
from sklearn.preprocessing import normalize | |
import time | |
from tqdm import tqdm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Path to all the downloaded images
img_path = Path('../data/images')

# Find all image files under the path (searched recursively).
# Sorting makes the row order reproducible across runs and filesystems.
images = sorted(img_path.glob('**/*.jpg'))

# Load the file names to a dataframe
image_df = pd.DataFrame(images, columns=['img_path'])

# The file stem (name without extension) is parsed as the integer article id
image_df['article_id'] = [int(path.stem) for path in image_df['img_path']]

print(image_df.shape)
image_df.head()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create a product mapping dict with product names:
# {article_id: {'img_path': ..., 'prod_name': ...}}
articles_df = pd.read_csv('../data/articles.csv')
product_mapping = (
    image_df
    # Default (inner) merge: only articles that have an image on disk are kept
    .merge(articles_df[['article_id', 'prod_name']], on='article_id')
    .set_index('article_id')
    .to_dict(orient='index')
)

# Show one sample entry; .get() prints None instead of raising KeyError
# when this particular article is not present in the local data
sample_article_id = 554541045
print(product_mapping.get(sample_article_id))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CD into the milvus directory
cd ../milvus

# Download the docker-compose.yml for Milvus standalone v2.0.0
wget https://github.com/milvus-io/milvus/releases/download/v2.0.0/milvus-standalone-docker-compose.yml -O docker-compose.yml

# Start the milvus standalone server in the background
sudo docker-compose up -d
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Make sure a Milvus server is already running
from pymilvus import connections, utility
from pymilvus import Collection, CollectionSchema, FieldSchema, DataType

# Connect to Milvus server
connections.connect(alias='default', host='localhost', port='19530')

# Collection name
collection_name = 'hnm_fashion_images'

# Embedding size; used as the `dim` of the vector field in the schema,
# so it must match the length of the vectors that get inserted
emb_dim = 512

## Check for existing collection and drop if exists
# if utility.has_collection(collection_name):
#     print(utility.list_collections())
#     utility.drop_collection(collection_name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create a schema for the collection:
# an integer primary key plus one fixed-size float vector per article
article_id = FieldSchema(name='article_id', dtype=DataType.INT64, is_primary=True)
img_embedding = FieldSchema(name='img_embedding', dtype=DataType.FLOAT_VECTOR, dim=emb_dim)

# Set the Collection Schema
fields = [article_id, img_embedding]
schema = CollectionSchema(fields=fields, description='H&M Fashion Products')

# Create a collection with the schema
collection = Collection(name=collection_name, schema=schema, using='default', shards_num=10)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sentence_transformers import SentenceTransformer

# CLIP model used below to embed both the product images and the text queries
model = SentenceTransformer('clip-ViT-B-32')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generate embeddings for batches of images on each iteration
batch_size = 512
clip_embeddings = []

for idx in tqdm(range(0, len(image_df), batch_size)):
    subset_df = image_df.iloc[idx:idx + batch_size]

    # Primary Key
    article_ids = subset_df['article_id'].values.tolist()

    # Embedding
    ## Load the `batch_size` number of images.
    ## The context manager closes each file handle right away; `convert`
    ## returns an independent in-memory image, so it stays usable afterwards.
    batch_images = []
    for path in subset_df['img_path']:
        with Image.open(path) as img:
            batch_images.append(img.convert('RGB'))

    ## Generate CLIP embeddings for the loaded images
    raw_embeddings = model.encode(batch_images)

    ## Normalize the embeddings to use IP distance
    ## https://milvus.io/docs/v2.0.0/metric.md#Inner-product-IP
    norm_embeddings = normalize(raw_embeddings, axis=1).tolist()

    ## Keep a local copy of the embeddings (same row order as image_df)
    clip_embeddings.extend(norm_embeddings)

    # Insert data to milvus: one list per schema field, in schema order
    data = [article_ids, norm_embeddings]
    collection.insert(data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Save the embeddings to a file.
# clip_embeddings was accumulated in image_df row order, so a plain column
# assignment lines each embedding up with its image.
image_df['clip_embeddings'] = clip_embeddings
image_df.to_pickle('../data/image_embeddings_df.pkl')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Add an ANN index to the collection
# NOTE(review): for IVF_PQ the vector dim (512 here) is expected to be
# divisible by `m` -- re-check against the Milvus index docs if emb_dim changes.
index_params = {
    "metric_type": "IP",
    "index_type": "IVF_PQ",
    "params": {"nlist": 1024, "m": 8},
}
collection.create_index(field_name='img_embedding', index_params=index_params)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the collection into memory so it can be searched
collection = Collection(collection_name)
collection.load()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def query_and_display(query_text, collection, product_mapping, num_results=10):
    """Run a text-to-image search against Milvus and plot the matching products.

    The query text is embedded with the module-level ``model``, L2-normalized
    (so the inner-product score is a cosine similarity), and searched against
    the ``img_embedding`` field of ``collection``. Hits are drawn in a grid of
    5 columns, each titled with the product name and its similarity score.

    Args:
        query_text: Search term to embed and look up.
        collection: Milvus collection to search; it must provide ``search``
            and contain the ``img_embedding`` vector field.
        product_mapping: Dict keyed by article id whose values are dicts with
            at least the ``img_path`` and ``prod_name`` keys.
        num_results: Maximum number of hits requested from Milvus.

    Returns:
        None. The results are rendered on a new matplotlib figure.

    Raises:
        KeyError: If a returned article id is missing from ``product_mapping``.
    """
    # Embed the Query Text
    raw_embeddings = [model.encode(query_text)]
    ## Normalize the embeddings to use Cosine Similarity
    ## https://milvus.io/docs/v2.0.0/metric.md#Inner-product-IP
    query_emb = normalize(raw_embeddings, axis=1).tolist()

    # Search Params
    search_params = {"metric_type": "IP", "params": {"nprobe": 20}}

    # Text to Image Milvus ANN Search.
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # short durations; wall-clock time.time() can jump if the clock is adjusted.
    query_start = time.perf_counter()
    results = collection.search(data=query_emb,
                                anns_field='img_embedding',
                                param=search_params,
                                limit=num_results,
                                expr=None)
    query_seconds = time.perf_counter() - query_start

    # Convert search results to products (one query was sent, so use results[0])
    hits = results[0]
    result_products = [product_mapping[item] for item in hits.ids]
    result_similarities = hits.distances

    headline = f'Query Text: {query_text}. Query returned in {query_seconds:.3f} seconds'

    # With no hits the grid would have zero rows and there is nothing to draw
    if not result_products:
        print(f'{headline} (no results)')
        return

    # Plot search results
    ncols = 5
    nrows = -(-len(result_products) // ncols)  # ceiling division
    plt.figure(figsize=(20, 5 * nrows))
    for position, (product, similarity) in enumerate(
            zip(result_products, result_similarities), start=1):
        plt.subplot(nrows, ncols, position)
        # Close the file handle as soon as the pixels have been copied out
        with Image.open(product['img_path']) as img:
            plt.imshow(img.convert('RGB'))
        plt.title(f'Product Name: {product["prod_name"]}\nCosine Similarity:{similarity:.3f}')
    plt.suptitle(headline)
    plt.tight_layout()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Query for products that match "floral top" search term
query_and_display('floral top', collection, product_mapping, num_results=10)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Release the collection from memory when it's not needed anymore
collection.release()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment