Created
April 7, 2023 21:36
-
-
Save richardliaw/0b04d28254a5fc2cf8c72fbf1b63aa82 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import DetrImageProcessor, DetrForObjectDetection | |
import torch | |
from PIL import Image | |
import requests | |
import boto3 | |
import io | |
import time | |
import pandas as pd | |
import numpy as np | |
from ray.train.huggingface import HuggingFacePredictor | |
import ray | |
# Source images live under this S3 prefix; the dataset is limited to 10 rows.
data_url = "s3://air-example-data/AnimalDetection/JPEGImages"
ds = ray.data.read_images(data_url).limit(10)

# Image processor and detection model, both loaded from the same
# pretrained checkpoint name.
processor: DetrImageProcessor = DetrImageProcessor.from_pretrained(
    "facebook/detr-resnet-50"
)
model: DetrForObjectDetection = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50"
)
# preprocess_single_image(test_image) | |
# OK, it seems the input is IMAGES
def preprocess_detr(image):
    """Run DETR object detection on a batch and return bounding boxes.

    Despite the name, this performs the whole pipeline: preprocessing with
    the module-level ``processor``, a forward pass through the module-level
    ``model``, and post-processing of the raw model outputs.

    Args:
        image: A batch mapping whose ``"image"`` entry is an array of
            images. Each image is assumed to be an ndarray laid out as
            (height, width[, channels]) -- TODO confirm against the reader.

    Returns:
        A list of boxes, each a list of four floats rounded to two
        decimals, for detections scoring above the 0.9 threshold.

        NOTE(review): only the FIRST image's detections are returned (the
        ``[0]`` below); results for the rest of the batch are discarded.
        Kept as-is to preserve the return shape callers currently get.
    """
    # Iterating the array (rather than calling ``.tolist()``) keeps every
    # image an ndarray, including when the batch is one dense array; nested
    # Python lists would have no ``.shape`` for the size computation below.
    images = list(image["image"])
    if not images:
        return []

    inputs = processor.preprocess(images=images, return_tensors="pt")

    # Inference only: no need to record operations for autograd.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing expects (height, width) per image. For an ndarray those
    # are the first two dimensions; reversing the full shape (the PIL
    # ``size[::-1]`` idiom) would put the channel count first.
    target_sizes = [img.shape[:2] for img in images]
    results = processor.post_process_object_detection(
        outputs, threshold=0.9, target_sizes=target_sizes)[0]

    return [
        [round(coord, 2) for coord in box.tolist()]
        for box in results["boxes"]
    ]
# Apply the DETR function to the dataset, handing it NumPy-format batches.
preprocessed = ds.map_batches(
    preprocess_detr,
    batch_format="numpy",
)

# Pull a single row from the result and show it.
_first_rows = preprocessed.take(1)
print(f'Preprocessed is: {_first_rows}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment