Last active
December 29, 2023 12:42
-
-
Save niranjanakella/6c0b6d0b8e696e7c381d54b5b643ac40 to your computer and use it in GitHub Desktop.
Exploring Personalized Shopping Experiences with Qdrant's Discovery API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st | |
from sentence_transformers import SentenceTransformer | |
from qdrant_client import models, QdrantClient | |
# Streamlit re-executes this whole script on every widget interaction. The Qdrant
# client and the encoder are therefore built inside st.cache_resource functions:
# they are created on the first run and the same objects are reused on later
# reruns, instead of reloading the embedding model on every button click.
@st.cache_resource
def _load_qdrant_client():
    """Return a Qdrant client for the server at localhost:6333."""
    client = QdrantClient("localhost", port=6333)
    # Log only after the client object actually exists.
    print("[INFO] Client created...")
    return client


@st.cache_resource
def _load_encoder():
    """Return the GTE-small sentence-embedding model used to encode gadget names."""
    print("[INFO] Loading encoder model...")
    return SentenceTransformer('thenlper/gte-small')


qdrant = _load_qdrant_client()
encoder = _load_encoder()
print("[INFO] Running Streamlit Application!! ")

# Gadget catalogue offered in the multiselect widget.
Gadget = ["Apple Macbook", "USB Drive", "Hard Drive",
          "Intel Processor", "Memory Card", "Graphics Card"]
def _unpicked_gadgets(all_gadgets, picked):
    """Return the items of ``all_gadgets`` not in ``picked``, in catalogue order."""
    picked = set(picked)
    return [gadget for gadget in all_gadgets if gadget not in picked]


def main():
    """Render the gadget picker and run the personalized search on demand.

    Gadgets the user selects are treated as liked. Every other gadget in
    ``Gadget`` is treated as disliked. Both lists are passed to
    ``personalize_function`` when the "Personalize" button is pressed.
    """
    st.title("Discovery API for Personalized Shopping Experience")
    st.subheader("By Niranjan Akella")
    # Multi-select checkbox for gadgets.
    choice_of_gadget = st.multiselect(
        "Pick gadgets you like",
        Gadget,
        default=["Intel Processor", "Graphics Card"],
        key="category1",
    )
    # Build the complement in catalogue order. A set difference iterates in
    # hash order, which changes between Python processes for strings. That
    # would make the liked/disliked pairing (and so the results) vary between runs.
    disliked_gadgets = _unpicked_gadgets(Gadget, choice_of_gadget)
    # "Personalize" button to trigger the discovery search.
    if st.button("Personalize"):
        personalize_function(choice_of_gadget, disliked_gadgets)
def _context_label_pairs(liked, disliked):
    """Return every (liked, disliked) combination, liked-major, in input order.

    Every like is paired with every dislike instead of zipping the two lists.
    ``zip`` stops at the shorter list and would silently drop the extra likes
    or dislikes.
    """
    return [(like, dislike) for like in liked for dislike in disliked]


# Personalize fn to perform discovery search.
def personalize_function(choice_of_gadget, disliked_gadgets):
    """Run a Qdrant discovery search steered by the user's gadget preferences.

    Each (liked, disliked) gadget pair becomes a ``models.ContextExamplePair``.
    Its positive/negative vectors are the encoder embeddings of the gadget
    names. The top 3 points from the ``e-shopping`` collection are written to
    the page, showing their ``title_left`` and ``description_left`` payload fields.

    Args:
        choice_of_gadget: gadget names the user selected as liked.
        disliked_gadgets: gadget names treated as disliked.
    """
    pairs = _context_label_pairs(choice_of_gadget, disliked_gadgets)
    if not pairs:
        # No target is passed to discover(), so without at least one pair there
        # is nothing to steer the search. Tell the user instead of sending an
        # empty context.
        st.warning("Pick at least one gadget, and leave at least one unpicked, to personalize.")
        return
    # Embed each distinct gadget name once, even if it appears in several pairs.
    vectors = {
        name: encoder.encode(name).tolist()
        for name in {*choice_of_gadget, *disliked_gadgets}
    }
    contexts = [
        models.ContextExamplePair(positive=vectors[like], negative=vectors[dislike])
        for like, dislike in pairs
    ]
    discovered_products = qdrant.discover(collection_name='e-shopping', context=contexts, limit=3)
    st.write("Top Recommended Products:")
    for product in discovered_products:
        st.write(f"Title: {product.payload['title_left']}")
        st.write(product.payload['description_left'])
        st.write("\n")


# Entry point: render the app when this file is executed as a script.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer | |
from qdrant_client import models, QdrantClient | |
from tqdm import tqdm | |
# If you hit SSL errors while downloading the dataset or the model, uncomment
# the two lines below to bypass curl certificate verification.
# WARNING: this turns off verification of the server's TLS certificate for the
# downloads. Only use it on a network you trust, and comment it out again afterwards.
# import os
# os.environ['CURL_CA_BUNDLE'] = ''

# Instantiate the Qdrant client for the local server on port 6333.
qdrant = QdrantClient("localhost", port=6333)
# Log only after the client object actually exists.
print("[INFO] Client created...")
# Download the dataset from the Hugging Face Hub.
print("[INFO] Loading dataset...")
dataset = load_dataset("wdc/products-2017", split='train')

# Keep only the title/description fields, and skip rows whose description is
# empty or missing: the description is the text that gets embedded.
# Iterating the dataset directly reads each row once. The previous
# index-based loop looked up dataset[i] up to three times per row.
print("[INFO] Processing dataset...")
fields = ['title_left', 'description_left']
data = [
    {field: row[field] for field in fields}
    for row in tqdm(dataset, total=len(dataset))
    if row['description_left']
]
# Load the GTE-small sentence-embedding model used to vectorise descriptions.
print("[INFO] Loading encoder model...")
encoder = SentenceTransformer('thenlper/gte-small')

# (Re)create the "e-shopping" collection. Its vector size comes from the
# encoder's embedding dimension, and vectors are compared by cosine distance.
print("[INFO] Creating a data collection...")
embedding_dim = encoder.get_sentence_embedding_dimension()
collection_vectors = models.VectorParams(
    size=embedding_dim,
    distance=models.Distance.COSINE,
)
qdrant.recreate_collection(
    collection_name="e-shopping",
    vectors_config=collection_vectors,
)
# Embed every description and upload it to the collection with its
# title/description dict as the payload.
print("[INFO] Uploading data to data collection...")
# Keep each sample's position in `data` as its point id, and skip empty
# descriptions, exactly as before.
indexed_samples = [
    (idx, sample) for idx, sample in enumerate(data) if sample['description_left']
]
# Encode all descriptions in one call so SentenceTransformer can batch them,
# instead of running one encode() call per row.
embeddings = encoder.encode(
    [sample["description_left"] for _, sample in indexed_samples],
    show_progress_bar=True,
)
records = [
    models.Record(id=idx, vector=embedding.tolist(), payload=sample)
    for (idx, sample), embedding in zip(indexed_samples, embeddings)
]
qdrant.upload_records(
    collection_name="e-shopping",
    records=records,
)
print("[INFO] Successfully uploaded data to datacollection!")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
datasets==2.11.0 | |
qdrant-client==1.7.0 | |
sentence_transformers==2.2.2 | |
tqdm==4.65.0 | |
sentencepiece==0.1.99 | |
transformers==4.36.2 | |
streamlit==1.29.0 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment