Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save stephenleo/12e9540843af00293a1ed7fb71641b34 to your computer and use it in GitHub Desktop.
Save stephenleo/12e9540843af00293a1ed7fb71641b34 to your computer and use it in GitHub Desktop.
[Medium] Supercharged Semantic Similarity Search in Production

Supercharged Semantic Similarity Search in Production

Code for the Medium post "Supercharged Semantic Similarity Search in Production".

# Create the necessary directories (notebooks, raw data, Milvus server files)
mkdir -p milvus_image_search/notebooks milvus_image_search/data milvus_image_search/milvus
# CD into the data directory
cd milvus_image_search/data
# Create and activate a conda environment
conda create -n image_sim python=3.8 -y
conda activate image_sim
## Create Virtual Environment using venv if not using conda
# python -m venv image_sim
# source image_sim/bin/activate
# Pip install the necessary libraries
pip install jupyterlab kaggle matplotlib scikit-learn tqdm ipywidgets
pip install pandas==1.3.5 pymilvus==2.0.0
pip install sentence_transformers ftfy
# Download data using the kaggle API
# (requires Kaggle API credentials and having accepted the competition rules)
kaggle competitions download -c h-and-m-personalized-fashion-recommendations
# Unzip the data into the local directory
unzip -q h-and-m-personalized-fashion-recommendations.zip -d .
# Delete the Zip file
rm h-and-m-personalized-fashion-recommendations.zip
from matplotlib import pyplot as plt
import pandas as pd
from pathlib import Path
from PIL import Image
from sklearn.preprocessing import normalize
import time
from tqdm import tqdm
# Path to all the downloaded images
img_path = Path('../data/images')

# Find list of all files in the path (searched recursively)
images = list(img_path.glob('**/*.jpg'))

# Load the file names to a dataframe; the article id is the numeric file stem
image_df = pd.DataFrame(images, columns=['img_path'])
image_df['article_id'] = image_df['img_path'].apply(lambda x: int(x.stem))

# Create a product mapping dict with product names:
#   {article_id: {'img_path': Path, 'prod_name': str}}
articles_df = pd.read_csv('../data/articles.csv')
# Left join so that every image keeps an entry even if its article metadata
# is missing; later lookups by article_id therefore never miss an image.
product_mapping = image_df.merge(
    articles_df[['article_id', 'prod_name']],
    on='article_id',
    how='left',
)
product_mapping = product_mapping.set_index('article_id')
product_mapping = product_mapping.to_dict(orient='index')
# CD into the milvus directory
cd ../milvus
# Download the docker-compose.yml for Milvus standalone.
# v2.0.0 is used here to match the pinned pymilvus==2.0.0 client.
wget https://github.com/milvus-io/milvus/releases/download/v2.0.0/milvus-standalone-docker-compose.yml -O docker-compose.yml
# Start the milvus standalone server (detached)
sudo docker-compose up -d
# Make sure a Milvus server is already running
from pymilvus import connections, utility
from pymilvus import Collection, CollectionSchema, FieldSchema, DataType
# Open the 'default' connection to the Milvus server
connections.connect(
    alias='default',
    host='localhost',
    port='19530',
)

# Name of the collection and size of each stored embedding vector
collection_name = 'hnm_fashion_images'
emb_dim = 512

## Optional: drop a previously created collection with the same name
# if utility.has_collection(collection_name):
#     print(utility.list_collections())
#     utility.drop_collection(collection_name)

# Schema fields: the article id is the primary key, the image embedding is
# a float vector of length `emb_dim`
article_id = FieldSchema(
    name='article_id',
    dtype=DataType.INT64,
    is_primary=True,
)
img_embedding = FieldSchema(
    name='img_embedding',
    dtype=DataType.FLOAT_VECTOR,
    dim=emb_dim,
)

# Assemble the collection schema from the fields
fields = [article_id, img_embedding]
schema = CollectionSchema(
    fields=fields,
    description='H&M Fashion Products',
)

# Create the collection on the 'default' connection using that schema
collection = Collection(
    name=collection_name,
    schema=schema,
    using='default',
    shards_num=10,
)
from sentence_transformers import SentenceTransformer

# CLIP model used to embed the product images
model = SentenceTransformer('clip-ViT-B-32')

# Generate embeddings for batches of images on each iteration
batch_size = 512
clip_embeddings = []

for idx in tqdm(range(0, len(image_df), batch_size)):
    subset_df = image_df.iloc[idx:idx + batch_size]

    # Primary Key
    article_ids = subset_df['article_id'].values.tolist()

    # Embedding
    ## Load the `batch_size` number of images
    batch_images = [
        Image.open(path).convert('RGB')
        for path in subset_df['img_path']
    ]

    ## Generate CLIP embeddings for the loaded images
    raw_embeddings = model.encode(batch_images)

    ## Normalize the embeddings to use IP distance
    ## (inner product of unit vectors equals cosine similarity)
    norm_embeddings = normalize(raw_embeddings, axis=1).tolist()

    ## Append the embeddings
    clip_embeddings.extend(norm_embeddings)

    # Insert data to milvus: one list per schema field, in schema order
    # (article_id, img_embedding)
    data = [article_ids, norm_embeddings]
    collection.insert(data)

# Keep the embeddings next to the image metadata
image_df['clip_embeddings'] = clip_embeddings
# NOTE(review): the original comment here mentions saving the embeddings to
# a file, but no write statement is present in this copy of the code.
# Persist `image_df` explicitly (e.g. with `image_df.to_pickle(...)`) if the
# embeddings need to be reused without recomputation.
# Add an ANN index to the collection.
# `metric_type` must match the metric used at search time (IP); the
# `nlist` / `m` build parameters are those of an IVF_PQ index, and `m`
# must evenly divide the vector dimension.
index_params = {
    "metric_type": "IP",
    "index_type": "IVF_PQ",
    "params": {"nlist": 1024, "m": 8},
}
collection.create_index(field_name='img_embedding', index_params=index_params)

# Load the collection into memory so that it can be searched
collection = Collection(collection_name)
collection.load()
def query_and_display(query_text, collection, product_mapping, num_results=10):
    """Search the image collection with a text query and plot the matches.

    The query text is embedded with the module-level `model`, L2-normalized,
    and used for an inner-product ANN search against the `img_embedding`
    field. Each hit is drawn in a 5-column grid titled with its product name
    and similarity score.

    Args:
        query_text: Free-text search phrase, e.g. ``'floral top'``.
        collection: Milvus collection to search; it must already be loaded.
        product_mapping: Dict keyed by article id whose values are dicts
            with at least the keys ``'img_path'`` and ``'prod_name'``.
        num_results: Maximum number of hits to retrieve and display.

    Returns:
        None. The figure is drawn with matplotlib; this function does not
        call ``plt.show()``.
    """
    # Embed the Query Text (wrapped in a list to form a single-row batch)
    raw_embeddings = [model.encode(query_text)]

    ## Normalize the embeddings to use Cosine Similarity
    query_emb = normalize(raw_embeddings, axis=1).tolist()

    # Search Params
    search_params = {"metric_type": "IP", "params": {"nprobe": 20}}

    # Text to Image Milvus ANN Search (timed)
    query_start = time.time()
    results = collection.search(
        data=query_emb,
        anns_field='img_embedding',
        param=search_params,
        limit=num_results,
        expr=None,
    )
    query_end = time.time()

    # Convert search results to products; results[0] holds the hits for the
    # single query vector
    hits = results[0]
    result_products = [product_mapping[item] for item in hits.ids]
    result_similarities = hits.distances

    # Plot search results in a grid of `ncols` columns
    ncols = 5
    # Ceiling division; keep at least one row so the figure height stays
    # positive when the search returns nothing
    nrows = max(1, -(-len(result_products) // ncols))
    plt.figure(figsize=(20, 5 * nrows))

    for idx, (product, similarity) in enumerate(
        zip(result_products, result_similarities)
    ):
        plt.subplot(nrows, ncols, idx + 1)
        img = Image.open(product['img_path']).convert('RGB')
        plt.imshow(img)
        plt.title(f'Product Name: {product["prod_name"]}\nCosine Similarity:{similarity:.3f}')
        plt.axis('off')

    plt.suptitle(f'Query Text: {query_text}. Query returned in {(query_end-query_start):.3f} seconds')
# Query for products that match "floral top" search term
query_and_display('floral top', collection, product_mapping, num_results=10)
# Release the collection from memory when it's not needed anymore
collection.release()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment