Skip to content

Instantly share code, notes, and snippets.

@niranjanakella
Last active December 29, 2023 12:42
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save niranjanakella/6c0b6d0b8e696e7c381d54b5b643ac40 to your computer and use it in GitHub Desktop.
Save niranjanakella/6c0b6d0b8e696e7c381d54b5b643ac40 to your computer and use it in GitHub Desktop.
Exploring Personalized Shopping Experiences with Qdrant's Discovery API
import streamlit as st
from sentence_transformers import SentenceTransformer
from qdrant_client import models, QdrantClient
# Streamlit re-executes this whole script on every widget interaction, so the
# expensive resources below are created through st.cache_resource and reused
# across reruns instead of being rebuilt (and the model reloaded) each time.
@st.cache_resource
def _load_qdrant_client():
    """Return a Qdrant client connected to the server at localhost:6333."""
    client = QdrantClient("localhost", port=6333)
    print("[INFO] Client created...")
    return client


@st.cache_resource
def _load_encoder():
    """Return the 'thenlper/gte-small' SentenceTransformer encoder."""
    print("[INFO] Loading encoder model...")
    return SentenceTransformer('thenlper/gte-small')


# instantiate qdrant client (cached across reruns)
qdrant = _load_qdrant_client()
# importing the GTE model (cached across reruns)
encoder = _load_encoder()
print("[INFO] Running Streamlit Application!! ")

# Gadget names offered in the multiselect picker.
Gadget = ["Apple Macbook", "USB Drive", "Hard Drive",
          "Intel Processor", "Memory Card", "Graphics Card"]
def main():
    """Render the gadget picker and run personalization when requested.

    Gadgets the user selects are treated as liked; every other entry of
    ``Gadget`` is treated as disliked. Clicking "Personalize" passes both
    lists to ``personalize_function``.
    """
    st.title("Discovery API for Personalized Shopping Experience")
    st.subheader("By Niranjan Akella")
    # multi-select checkbox for gadgets
    choice_of_gadget = st.multiselect(
        "Pick gadgets you like",
        Gadget,
        default=["Intel Processor", "Graphics Card"],
        key="category1",
    )
    # Keep catalogue order rather than using a set difference: iteration order
    # of a set of strings depends on per-process hash randomization, which
    # would make the liked/disliked pairing in personalize_function vary
    # between runs.
    disliked_gadgets = [g for g in Gadget if g not in choice_of_gadget]
    # "Personalize" button to trigger a function
    if st.button("Personalize"):
        personalize_function(choice_of_gadget, disliked_gadgets)
# personalize fn to perform discovery search
def personalize_function(choice_of_gadget, disliked_gadgets):
    """Run a Qdrant discovery search from liked/disliked gadgets and show results.

    Each liked gadget is paired positionally with a disliked gadget via
    ``zip``. Any surplus entries in the longer list are therefore not used
    as context.

    Args:
        choice_of_gadget: Gadget names the user liked (positive examples).
        disliked_gadgets: Gadget names the user did not pick (negative examples).
    """
    pairs = list(zip(choice_of_gadget, disliked_gadgets))
    if not pairs:
        # With no pairs the discovery request would carry an empty context,
        # giving the search nothing to rank against.
        st.warning("Pick at least one gadget you like and leave at least one unselected.")
        return

    liked_texts = [liked for liked, _ in pairs]
    disliked_texts = [disliked for _, disliked in pairs]
    # One batched encode call per side instead of one model call per gadget.
    liked_vectors = encoder.encode(liked_texts)
    disliked_vectors = encoder.encode(disliked_texts)
    contexts = [
        models.ContextExamplePair(positive=pos.tolist(), negative=neg.tolist())
        for pos, neg in zip(liked_vectors, disliked_vectors)
    ]

    discovered_products = qdrant.discover(collection_name='e-shopping', context=contexts, limit=3)
    st.write("Top Recommended Products:")
    for product in discovered_products:
        st.write(f"Title: {product.payload['title_left']}")
        st.write(product.payload['description_left'])
        st.write("\n")
# Entry point: build the UI only when this file is executed as the main script.
if __name__ == "__main__":
    main()
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from qdrant_client import models, QdrantClient
from tqdm import tqdm
# If you face any SSL issues while downloading the dataset or model, uncomment
# the following lines to bypass certificate verification by clearing the CA
# bundle path taken from CURL_CA_BUNDLE.
# WARNING: this disables TLS certificate checks; only use it on a trusted network.
# import os
# os.environ['CURL_CA_BUNDLE'] = ''

# instantiate qdrant client
qdrant = QdrantClient("localhost", port=6333)
print("[INFO] Client created...")

# download dataset from hugging face
print("[INFO] Loading dataset...")
dataset = load_dataset("wdc/products-2017", split='train')

# process dataset: keep only rows that have a non-empty description,
# projected down to the fields stored as the Qdrant payload.
print("[INFO] Processing dataset...")
fields = ['title_left', 'description_left']
# Iterate rows directly instead of indexing dataset[i] once per row.
data = [
    {field: row[field] for field in fields}
    for row in tqdm(dataset, total=len(dataset))
    if row['description_left']
]

# importing the GTE model
print("[INFO] Loading encoder model...")
encoder = SentenceTransformer('thenlper/gte-small')

# creating data collection in qdrant (drops and recreates it if it exists)
print("[INFO] Creating a data collection...")
qdrant.recreate_collection(
    collection_name="e-shopping",
    vectors_config=models.VectorParams(
        size=encoder.get_sentence_embedding_dimension(),  # Vector size is defined by used model
        distance=models.Distance.COSINE,
    ),
)

# uploading vectors to data collection
print("[INFO] Uploading data to data collection...")
# Every sample in `data` already has a description (filtered above), so ids
# are simply the sample positions 0..len(data)-1.
descriptions = [sample["description_left"] for sample in data]
# Encode all descriptions in batched model calls rather than one call per product.
vectors = encoder.encode(descriptions, show_progress_bar=True)
records = [
    models.Record(id=idx, vector=vector.tolist(), payload=sample)
    for idx, (sample, vector) in enumerate(zip(data, vectors))
]
qdrant.upload_records(
    collection_name="e-shopping",
    records=records,
)
print("[INFO] Successfully uploaded data to datacollection!")
datasets==2.11.0
qdrant-client==1.7.0
sentence_transformers==2.2.2
tqdm==4.65.0
sentencepiece==0.1.99
transformers==4.36.2
streamlit==1.29.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment