Created
December 13, 2023 10:03
-
-
Save masadcv/a201b8da9e82d71a287fd2f6b3152cdd to your computer and use it in GitHub Desktop.
OWL-ViT Object Detection
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTProcessor
# Load the pretrained OWL-ViT processor (text + image preprocessing) and the
# detection model for the same checkpoint.
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

# Download a sample COCO val2017 image.
# - The timeout stops the script from hanging forever on a stalled connection.
# - raise_for_status() reports HTTP errors instead of letting PIL choke on an
#   error page.
# - The `with` block closes the connection when we are done.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # Image.open is lazy; convert() forces the full decode while the stream is
    # still open, and normalizes the image to 3-channel RGB.
    image = Image.open(response.raw).convert("RGB")

# Free-text queries, nested as one list per image in the batch (a batch of one
# here). Predicted label ids index into the inner list.
texts = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(text=texts, images=image, return_tensors="pt")
# Inference only: disable autograd tracking to save memory and time.
with torch.no_grad():
    outputs = model(**inputs)
# Target image sizes (height, width) used to rescale box predictions,
# shape [batch_size, 2]. PIL's Image.size is (width, height), hence the
# reversal. torch.tensor with an explicit dtype replaces the legacy
# torch.Tensor constructor and keeps the same float32 values.
target_sizes = torch.tensor([image.size[::-1]], dtype=torch.float32)
# Convert raw outputs (bounding boxes and class logits) into per-image result
# dicts with "boxes", "scores" and "labels", filtered with threshold=0.1.
results = processor.post_process_object_detection(
    outputs=outputs, target_sizes=target_sizes, threshold=0.1
)

i = 0  # Retrieve predictions for the first image for the corresponding text queries
text = texts[i]  # Query strings for this image; predicted labels index into it.
boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]

boxes_to_draw = []
for box, score, label in zip(boxes, scores, labels):
    # Use a dedicated comprehension variable rather than reusing `i`, which
    # already names the image index above.
    box = [round(coord, 2) for coord in box.tolist()]
    print(f"Detected {text[int(label)]} with confidence {round(score.item(), 3)} at location {box}")
    boxes_to_draw.append(box)
# Plot the image and overlay each kept detection as an unfilled red rectangle,
# labelled with the text query it matched. Boxes are treated as
# (x_min, y_min, x_max, y_max): the corner plus a width/height span.
plt.imshow(image)
plt.axis("off")
ax = plt.gca()
# Distinct loop names, so the module-level `scores`/`labels` tensors are not
# rebound to their last element once the loop finishes.
for box, label in zip(boxes_to_draw, labels):
    x_min, y_min, x_max, y_max = box
    ax.add_patch(
        plt.Rectangle(
            (x_min, y_min),
            x_max - x_min,  # width
            y_max - y_min,  # height
            fill=False,
            edgecolor="r",
            linewidth=3,
        )
    )
    ax.text(
        x_min,
        y_min,
        s=text[int(label)],
        color="white",
        verticalalignment="top",
        bbox={"color": "red", "pad": 0},
    )
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment