
Supercharged Semantic Similarity Search in Production

Code for the Medium post Link

# Create the necessary directories
mkdir -p milvus_image_search/notebooks milvus_image_search/data milvus_image_search/milvus
# CD into the data directory
cd milvus_image_search/data
# Create and activate a conda environment
conda create -n image_sim python=3.8 -y
conda activate image_sim
## Create Virtual Environment using venv if not using conda
# python -m venv image_sim
# source image_sim/bin/activate
# Pip install the necessary libraries
pip install jupyterlab kaggle matplotlib scikit-learn tqdm ipywidgets
pip install pandas==1.3.5 pymilvus==2.0.0
pip install sentence_transformers ftfy
# Download data using the kaggle API
kaggle competitions download -c h-and-m-personalized-fashion-recommendations
# Unzip the data into the local directory
unzip h-and-m-personalized-fashion-recommendations.zip
# Delete the Zip file
rm h-and-m-personalized-fashion-recommendations.zip
from matplotlib import pyplot as plt
import pandas as pd
from pathlib import Path
from PIL import Image
from sklearn.preprocessing import normalize
import time
from tqdm import tqdm
# Path to all the downloaded images
img_path = Path('../data/images')
# Find list of all files in the path
images = [path for path in img_path.glob('**/*.jpg')]
# Load the file names to a dataframe
image_df = pd.DataFrame(images, columns=['img_path'])
image_df['article_id'] = image_df['img_path'].apply(lambda x: int(x.stem))
print(image_df.shape)
image_df.head()
# Create a product mapping dict with product names
articles_df = pd.read_csv('../data/articles.csv')
product_mapping = image_df.merge(articles_df[['article_id', 'prod_name']],
                                 on='article_id')
product_mapping = product_mapping.set_index('article_id')
product_mapping = product_mapping.to_dict(orient='index')
print(product_mapping[554541045])
# CD into the milvus directory
cd ../milvus
# Download the docker-compose.yml
wget https://github.com/milvus-io/milvus/releases/download/v2.0.0/milvus-standalone-docker-compose.yml -O docker-compose.yml
# Start the milvus standalone server
sudo docker-compose up -d
# Make sure a Milvus server is already running
from pymilvus import connections, utility
from pymilvus import Collection, CollectionSchema, FieldSchema, DataType
# Connect to Milvus server
connections.connect(alias='default', host='localhost', port='19530')
# Collection name
collection_name = 'hnm_fashion_images'
# Embedding size
emb_dim = 512
## Check for existing collection and drop if it exists
# if utility.has_collection(collection_name):
#     print(utility.list_collections())
#     utility.drop_collection(collection_name)
# Create a schema for the collection
article_id = FieldSchema(name='article_id', dtype=DataType.INT64, is_primary=True)
img_embedding = FieldSchema(name='img_embedding', dtype=DataType.FLOAT_VECTOR, dim=emb_dim)
# Set the Collection Schema
fields = [article_id, img_embedding]
schema = CollectionSchema(fields=fields, description='H&M Fashion Products')
# Create a collection with the schema
collection = Collection(name=collection_name, schema=schema, using='default', shards_num=10)
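As a quick optional sanity check (not part of the original gist), the utility module imported above can confirm that the new collection is registered on the Milvus server:

# Optional sanity check: confirm the collection exists on the server
print(utility.has_collection(collection_name))  # Expect: True
print(collection.schema)                        # Should list article_id and img_embedding
print(collection.num_entities)                  # 0 before any data is inserted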
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('clip-ViT-B-32')
# Generate embeddings for batches of images on each iteration
batch_size = 512
clip_embeddings = []
for idx in tqdm(range(0, len(image_df), batch_size)):
    subset_df = image_df.iloc[idx:idx+batch_size]
    # Primary Key
    data = [subset_df['article_id'].values.tolist()]
    # Embedding
    ## Load the `batch_size` number of images
    images = [
        Image.open(img_path).convert('RGB')
        for img_path in subset_df['img_path']
    ]
    ## Generate CLIP embeddings for the loaded images
    raw_embeddings = model.encode(images)
    ## Normalize the embeddings to use IP distance
    ## https://milvus.io/docs/v2.0.0/metric.md#Inner-product-IP
    norm_embeddings = normalize(raw_embeddings, axis=1).tolist()
    ## Append the embeddings
    data.append(norm_embeddings)
    clip_embeddings += norm_embeddings
    # Insert the batch of data into Milvus
    collection.insert(data)
# Save the embeddings to a file
image_df['clip_embeddings'] = clip_embeddings
image_df.to_pickle('../data/image_embeddings_df.pkl')
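Milvus buffers inserts in memory segments, so the server-side row count can lag briefly behind the client. An optional check (not in the original gist) that the insert loop covered every image, assuming the pymilvus 2.0 num_entities property:

# Optional: compare the server-side entity count with the number of embedded images
# (the count may lag until Milvus flushes its in-memory segments)
print(f'Entities in Milvus: {collection.num_entities}')
print(f'Images embedded: {len(image_df)}')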
# Add an ANN index to the collection
index_params = {
    "metric_type": "IP",
    "index_type": "IVF_PQ",
    "params": {"nlist": 1024, "m": 8}
}
collection.create_index(field_name='img_embedding', index_params=index_params)
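Index building runs asynchronously on the server. If you want to monitor or wait for it, the pymilvus utility module exposes helpers for this; the sketch below assumes the pymilvus 2.0 helper names, so verify them against your installed version:

# Optional: monitor the asynchronous index build
print(utility.index_building_progress(collection_name))    # e.g. indexed vs. total rows
utility.wait_for_index_building_complete(collection_name)  # Block until the index is ready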
# Load the collection into memory
collection = Collection(collection_name)
collection.load()
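Loading is also asynchronous, so searches issued immediately after load() may wait on the server. An optional progress check, again assuming the pymilvus 2.0 utility helpers:

# Optional: check how much of the collection has been loaded into memory
print(utility.loading_progress(collection_name))
utility.wait_for_loading_complete(collection_name)  # Block until fully loaded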
def query_and_display(query_text, collection, product_mapping, num_results=10):
    # Embed the Query Text
    raw_embeddings = [model.encode(query_text)]
    ## Normalize the embedding so that IP distance is equivalent to Cosine Similarity
    ## https://milvus.io/docs/v2.0.0/metric.md#Inner-product-IP
    query_emb = normalize(raw_embeddings, axis=1).tolist()
    # Search Params
    search_params = {"metric_type": "IP", "params": {"nprobe": 20}}
    # Text to Image Milvus ANN Search
    query_start = time.time()
    results = collection.search(data=query_emb,
                                anns_field='img_embedding',
                                param=search_params,
                                limit=num_results,
                                expr=None)
    query_end = time.time()
    # Convert search results to products
    result_products = [product_mapping[item] for item in results[0].ids]
    result_similarities = results[0].distances
    # Plot search results in a grid with 5 columns
    ncols = 5
    nrows = -(-len(result_products)//ncols)  # Ceiling division
    fig = plt.figure(figsize=(20, 5*nrows))
    for idx, product in enumerate(result_products):
        plt.subplot(nrows, ncols, idx+1)
        img = Image.open(product['img_path']).convert('RGB')
        plt.imshow(img)
        plt.title(f'Product Name: {product["prod_name"]}\nCosine Similarity: {result_similarities[idx]:.3f}')
    plt.suptitle(f'Query Text: {query_text}. Query returned in {(query_end-query_start):.3f} seconds')
    plt.tight_layout()
# Query for products that match "floral top" search term
query_and_display('floral top', collection, product_mapping, num_results=10)
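Because CLIP embeds text and images into the same vector space, the same collection also supports image-to-image search. The sketch below is not part of the original post; it reuses the model, normalization, and search parameters from above, querying with one of the catalogue images instead of a text string:

# Optional image-to-image search: embed a catalogue image and use it as the query
query_img = Image.open(image_df['img_path'].iloc[0]).convert('RGB')
img_emb = normalize(model.encode([query_img]), axis=1).tolist()
img_results = collection.search(data=img_emb,
                                anns_field='img_embedding',
                                param={"metric_type": "IP", "params": {"nprobe": 20}},
                                limit=5,
                                expr=None)
print([product_mapping[item]['prod_name'] for item in img_results[0].ids])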
# Release the collection from memory when it's not needed anymore
collection.release()