by @niranjanakella · created December 26, 2023
Semantic Search Over Satellite Images Using Qdrant
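
This gist contains two Python scripts and a pinned dependency list: first, a Gradio app that embeds a text prompt with CLIP and looks up the closest satellite image in a Qdrant collection; second, the ingestion script that builds that collection by embedding the RSICD dataset and uploading the vectors.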
# Search app: a Gradio interface that embeds a text prompt with CLIP and
# retrieves the closest satellite image from the Qdrant collection.
import gradio as gr
from PIL import Image
from transformers import AutoTokenizer, AutoProcessor, AutoModelForZeroShotImageClassification
from qdrant_client import QdrantClient

client = QdrantClient("localhost", port=6333)
print("[INFO] Client created...")

# Loading the model
print("[INFO] Loading the model...")
model_name = "openai/clip-vit-base-patch32"
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForZeroShotImageClassification.from_pretrained(model_name)

def process_text(text):
    # Embed the prompt with CLIP's text encoder
    inp = tokenizer(text, return_tensors="pt")
    text_embeddings = model.get_text_features(**inp).detach().numpy().tolist()[0]

    # Fetch the nearest image vector from the collection
    hits = client.search(
        collection_name="satellite_img_db",
        query_vector=text_embeddings,
        limit=1,
    )

    # Rebuild the image from the pixel data stored in the payload
    for hit in hits:
        img_size = tuple(hit.payload['img_size'])
        pixel_lst = hit.payload['pixel_lst']
        new_image = Image.new("RGB", img_size)
        new_image.putdata([tuple(px) for px in pixel_lst])
        return new_image

# Gradio interface
iface = gr.Interface(
    title="Semantic Search Over Satellite Images Using Qdrant Vector Database",
    description="by Niranjan Akella",
    fn=process_text,
    inputs=gr.Textbox(label="Input prompt"),
    outputs=gr.Image(type="pil", label="Satellite Image"),
)
iface.launch()
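
Before launching the UI, a quick way to sanity-check the collection is to query it directly with the same client. A minimal sketch, reusing the tokenizer, model, and client defined above; the query string is just an example:

# Hypothetical sanity check: embed a query and inspect the raw hits.
inp = tokenizer("a photo of an airport", return_tensors="pt")
vec = model.get_text_features(**inp).detach().numpy().tolist()[0]
hits = client.search(collection_name="satellite_img_db", query_vector=vec, limit=3)
for hit in hits:
    print(hit.id, hit.score, hit.payload["captions"][0])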
# Ingestion script: embed the RSICD satellite-image dataset with CLIP and
# index the vectors in a Qdrant collection.
import datasets
from transformers import AutoTokenizer, AutoProcessor, AutoModelForZeroShotImageClassification
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm

client = QdrantClient("localhost", port=6333)
print("[INFO] Client created...")

# Loading the dataset
print("[INFO] Loading dataset...")
ds = datasets.load_dataset('arampacha/rsicd', split='train')

# Loading the model
print("[INFO] Loading the model...")
model_name = "openai/clip-vit-base-patch32"
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForZeroShotImageClassification.from_pretrained(model_name)

# Creating a Qdrant collection to store the image embeddings
# (CLIP ViT-B/32 produces 512-dimensional vectors)
print("[INFO] Creating qdrant data collection...")
client.create_collection(
    collection_name="satellite_img_db",
    vectors_config=models.VectorParams(size=512, distance=models.Distance.COSINE),
)

# Creating records/vectors: one record per image, keeping the raw pixels,
# image size, and captions in the payload so the image can be rebuilt at query time
print("[INFO] Creating data records...")
records = []
for idx, sample in tqdm(enumerate(ds), total=len(ds)):
    processed_img = processor(text=None, images=sample['image'], return_tensors="pt")['pixel_values']
    img_embds = model.get_image_features(processed_img).detach().numpy().tolist()[0]
    img_px = list(sample['image'].getdata())
    img_size = sample['image'].size
    records.append(models.Record(
        id=idx,
        vector=img_embds,
        payload={"pixel_lst": img_px, "img_size": img_size, "captions": sample['captions']},
    ))

# Uploading the records to the collection; it's better to upload
# chunks of data to the vector DB than one giant request
print("[INFO] Uploading data records to data collection...")
for i in range(0, len(records), 30):
    client.upload_records(
        collection_name="satellite_img_db",
        records=records[i:i + 30],
    )
    print(f"finished {min(i + 30, len(records))}")
print("[INFO] Successfully uploaded data records to data collection!")
Dependencies:

datasets==2.11.0
qdrant-client==1.7.0
sentence_transformers==2.2.2
tqdm==4.65.0
sentencepiece==0.1.99
transformers==4.36.2
gradio==4.12.0
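
Both scripts assume a Qdrant instance is reachable on localhost:6333. If you don't have one running, the simplest way to start it locally is with Docker, e.g. `docker run -p 6333:6333 qdrant/qdrant`.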