-
-
Save niranjanakella/372a89d2f3d82784617f690ca5a28c84 to your computer and use it in GitHub Desktop.
Semantic Search Over Satellite Images Using Qdrant
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gradio as gr
from PIL import Image
import datasets
from transformers import AutoTokenizer, AutoProcessor, AutoModelForZeroShotImageClassification
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
import numpy as np

# Connection to the Qdrant server on localhost:6333 that is queried by the
# search function below.
client = QdrantClient("localhost", port=6333)
print("[INFO] Client created...")

# Load the CLIP checkpoint once at import time so every search request reuses
# the same tokenizer/model objects instead of reloading them.
print("[INFO] Loading the model...")
model_name = "openai/clip-vit-base-patch32"
tokenizer, processor, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoProcessor.from_pretrained(model_name),
    AutoModelForZeroShotImageClassification.from_pretrained(model_name),
)
def process_text(text):
    """Return the stored satellite image that best matches a text prompt.

    The prompt is embedded with CLIP's text encoder. That embedding is used
    as the query vector for a nearest-neighbour search in the
    ``satellite_img_db`` Qdrant collection. The top hit's payload
    (``pixel_lst`` and ``img_size``) is rebuilt into a PIL RGB image.

    Args:
        text: Free-text search prompt (from the Gradio textbox).

    Returns:
        A ``PIL.Image.Image`` for the best match, or ``None`` when the search
        returns no hits (for example, an empty collection). Gradio renders
        ``None`` as an empty image output.
    """
    # The CLIP text encoder has a fixed maximum sequence length; truncate so
    # long prompts are searched instead of raising inside the model.
    inp = tokenizer(text, truncation=True, return_tensors="pt")
    text_embeddings = model.get_text_features(**inp).detach().numpy().tolist()[0]
    hits = client.search(
        collection_name="satellite_img_db",
        query_vector=text_embeddings,
        limit=1,
    )
    if not hits:
        # Without this guard the image variable would never be bound.
        return None

    payload = hits[0].payload
    new_image = Image.new("RGB", tuple(payload['img_size']))
    # The payload round-trips through JSON, so each pixel comes back as a
    # list; putdata expects tuples.
    new_image.putdata([tuple(px) for px in payload['pixel_lst']])
    return new_image
# Gradio UI: one text prompt in, the best-matching satellite image out
# (as returned by process_text).
search_prompt = gr.Textbox(label="Input prompt")
result_image = gr.Image(type="pil", label="Satellite Image")

iface = gr.Interface(
    fn=process_text,
    inputs=search_prompt,
    outputs=result_image,
    title="Semantic Search Over Satellite Images Using Qdrant Vector Database",
    description="by Niranjan Akella",
)
iface.launch()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datasets
from transformers import AutoTokenizer, AutoProcessor, AutoModelForZeroShotImageClassification
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
import numpy as np

# Connection to the Qdrant server on localhost:6333 that will store the
# image embeddings.
client = QdrantClient("localhost", port=6333)
print("[INFO] Client created...")

# Source data: the train split of the 'arampacha/rsicd' dataset. Each sample
# is expected to provide an 'image' and its 'captions'.
print("[INFO] Loading dataset...")
ds = datasets.load_dataset('arampacha/rsicd', split='train')

# The same CLIP checkpoint the search app uses, so image and text embeddings
# live in the same vector space.
print("[INFO] Loading the model...")
model_name = "openai/clip-vit-base-patch32"
tokenizer, processor, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoProcessor.from_pretrained(model_name),
    AutoModelForZeroShotImageClassification.from_pretrained(model_name),
)
# Create the Qdrant collection that stores one 512-d image embedding per
# dataset sample, compared by cosine distance. The size must match the
# output dimension of model.get_image_features for the loaded checkpoint.
print("[INFO] Creating qdrant data collection...")
existing_collections = {c.name for c in client.get_collections().collections}
if "satellite_img_db" in existing_collections:
    # create_collection fails when the name is already taken, which made the
    # script impossible to re-run. Reuse the collection instead;
    # upload_records upserts by id, so re-uploaded samples overwrite
    # their previous records.
    print("[INFO] Collection 'satellite_img_db' already exists, reusing it...")
else:
    client.create_collection(
        collection_name="satellite_img_db",
        vectors_config=models.VectorParams(size=512, distance=models.Distance.COSINE),
    )
# Build one Qdrant record per dataset sample and upload them in batches of
# BATCH_SIZE. Each record carries:
#   - vector:  the CLIP image embedding of the sample;
#   - payload: the raw RGB pixels and size of the image (so the search app can
#              rebuild it) plus the sample's captions.
# Each batch is uploaded as soon as it fills up, rather than after the whole
# dataset has been embedded. This way the large pixel payloads of every image
# are never held in memory at once.
BATCH_SIZE = 30  # It's better to upload chunks of data to the VectorDB


def _flush(batch):
    """Upload *batch* (a list of ``models.Record``) and return how many were sent."""
    client.upload_records(
        collection_name="satellite_img_db",
        records=batch,
    )
    return len(batch)


print("[INFO] Creating a data collection...")
print("[INFO] Uploading data records to data collection...")
batch = []
n_uploaded = 0
for idx, sample in tqdm(enumerate(ds), total=len(ds)):
    # The search app rebuilds images with Image.new("RGB", ...), so store
    # every pixel as an RGB triple regardless of the source image's mode.
    image = sample['image'].convert("RGB")
    processed_img = processor(text=None, images=image, return_tensors="pt")['pixel_values']
    img_embds = model.get_image_features(processed_img).detach().numpy().tolist()[0]
    batch.append(models.Record(
        id=idx,
        vector=img_embds,
        payload={
            "pixel_lst": list(image.getdata()),
            "img_size": image.size,
            "captions": sample['captions'],
        },
    ))
    if len(batch) == BATCH_SIZE:
        n_uploaded += _flush(batch)
        batch = []

# Upload the final, possibly partial batch so no trailing records are dropped.
if batch:
    n_uploaded += _flush(batch)

print(f"[INFO] Successfully uploaded {n_uploaded} data records to data collection!")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
datasets==2.11.0
qdrant-client==1.7.0
sentence_transformers==2.2.2
tqdm==4.65.0
sentencepiece==0.1.99
transformers==4.36.2
gradio==4.12.0
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment