@younesbelkada
Last active December 27, 2022 14:55
Evaluate HF object detection models on COCO using `evaluate`
import os

import evaluate
import torch
import torchvision
from tqdm import tqdm
from transformers import DetrFeatureExtractor, DetrForObjectDetection

COCO_DIR = os.path.join(os.getcwd(), "data")
path_img = os.path.join(COCO_DIR, "val2017")
path_anno = os.path.join(COCO_DIR, "annotations/instances_val2017.json")


class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, feature_extractor, ann_file):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self.feature_extractor = feature_extractor

    def __getitem__(self, idx):
        # read in PIL image and target in COCO format
        img, target = super(CocoDetection, self).__getitem__(idx)

        # preprocess image and target (converting target to DETR format, resizing + normalization of both image and target)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        encoding = self.feature_extractor(images=img, annotations=target, return_tensors="pt")
        pixel_values = encoding["pixel_values"].squeeze()  # remove batch dimension
        target = encoding["labels"][0]  # remove batch dimension

        return pixel_values, target


def collate_fn(batch):
    pixel_values = [item[0] for item in batch]
    encoding = feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")
    labels = [item[1] for item in batch]
    batch = {}
    batch['pixel_values'] = encoding['pixel_values']
    batch['pixel_mask'] = encoding['pixel_mask']
    batch['labels'] = labels
    return batch


feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to("cuda")

coco_dataset = CocoDetection(path_img, feature_extractor, path_anno)
module = evaluate.load("ybelkada/cocoevaluate", coco=coco_dataset.coco)

device = model.device
val_dataloader = torch.utils.data.DataLoader(
    coco_dataset, batch_size=8, shuffle=False, num_workers=4, collate_fn=collate_fn
)

with torch.no_grad():
    for idx, batch in enumerate(tqdm(val_dataloader)):
        # set to device
        pixel_values = batch["pixel_values"].to(device)
        pixel_mask = batch["pixel_mask"].to(device)
        labels = [{k: v.to(device) for k, v in t.items()} for t in batch["labels"]]  # these are in DETR format, resized + normalized

        # forward pass
        outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask)

        orig_target_sizes = torch.stack([target["orig_size"] for target in labels], dim=0)
        results = feature_extractor.post_process(outputs, orig_target_sizes)  # convert outputs of model to COCO api

        module.add(prediction=results, reference=labels)
        del batch

results = module.compute()
print(results)
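
Each dictionary in the list returned by post_process holds "scores", "labels" and "boxes" tensors, with boxes in absolute (x0, y0, x1, y1) pixel coordinates. As an optional debugging aid (not part of the gist; the 0.9 threshold and variable names below are assumptions), a few extra lines placed inside the loop right after the post_process call print the confident detections for the first image of a batch:

# hypothetical inspection step: print detections with confidence above 0.9
# for the first image of the current batch
first_image_results = results[0]
keep = first_image_results["scores"] > 0.9
for score, label, box in zip(
    first_image_results["scores"][keep],
    first_image_results["labels"][keep],
    first_image_results["boxes"][keep],
):
    print(f"{model.config.id2label[label.item()]}: {score:.3f} at {box.tolist()}")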
younesbelkada commented Dec 27, 2022

How to run the script?

You need two steps to run the evaluation script.

Step 1: install the dependencies

pip install pycocotools

The script also relies on evaluate, transformers, torch, torchvision, and tqdm, so make sure those are installed as well.

Step 2: Get the data

Download the data to your preferred location following this tutorial. For example, for the COCO 2017 validation split you will need to run:

wget http://images.cocodataset.org/zips/val2017.zip && unzip val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip && unzip annotations_trainval2017.zip

Finally, make sure the COCO_DIR variable points at the directory that contains val2017/ and annotations/instances_val2017.json (the script defaults to ./data).
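
If anything is off, a small sanity check (not part of the gist; it simply mirrors the paths the script builds from COCO_DIR) will catch it before the full evaluation starts:

import os

# the script expects:
#   <COCO_DIR>/val2017/                            (validation images)
#   <COCO_DIR>/annotations/instances_val2017.json  (validation annotations)
COCO_DIR = os.path.join(os.getcwd(), "data")

for relative_path in ("val2017", os.path.join("annotations", "instances_val2017.json")):
    full_path = os.path.join(COCO_DIR, relative_path)
    if not os.path.exists(full_path):
        raise FileNotFoundError(f"Missing {full_path}; move the downloaded files or adjust COCO_DIR.")
print("COCO_DIR layout looks correct.")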
