Skip to content

Instantly share code, notes, and snippets.

Created June 28, 2020 16:11
Show Gist options
  • Save fmassa/d443b1fdca2a5d5debca902218f3232b to your computer and use it in GitHub Desktop.
Save fmassa/d443b1fdca2a5d5debca902218f3232b to your computer and use it in GitHub Desktop.
Code to reproduce Fig. 7 in "End-to-End Object Detection with Transformers" (DETR, Carion et al., 2020)
# this file needs to be added to the root folder of detr github repo
import torch
import time
import torchvision
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from datasets import build_dataset
def get_dataset(coco_path):
    """Build the COCO ``val`` split used for computing detections.

    DETR's ``build_dataset`` takes an argparse-style ``args`` object; this
    helper fills in the three attributes the original script provided
    (``dataset_file``, ``coco_path``, ``masks``). Presumably these are all
    that ``build_dataset`` reads for COCO without masks — confirm if the
    DETR repo version changes.

    Args:
        coco_path: Root directory of the COCO dataset, forwarded unchanged as
            ``args.coco_path``.

    Returns:
        The dataset returned by ``build_dataset(image_set='val', args=...)``.
    """
    # SimpleNamespace is a stdlib attribute bag that stands in for
    # argparse.Namespace, so no throwaway class definition is needed.
    from types import SimpleNamespace

    args = SimpleNamespace(
        dataset_file="coco",
        coco_path=coco_path,
        masks=False,
    )
    dataset = build_dataset(image_set='val', args=args)
    return dataset
def compute_predictions(model, dataset):
    """Run ``model`` on every image of ``dataset`` and stack the predicted boxes.

    Each image is passed to the model as a one-element list, and the model's
    ``'pred_boxes'`` output is copied to the CPU. The per-image results are
    concatenated along dimension 0.

    ``nn.Module`` objects start in training mode, which would leave dropout
    active and make the predictions stochastic. The model is therefore put in
    eval mode for the loop, and its previous train/eval mode is restored
    afterwards, even if an error occurs.

    Args:
        model: DETR-style model whose output dict has a ``'pred_boxes'`` entry.
            Its parameters' device determines where input images are sent.
        dataset: Index-based dataset returning ``(image, target)`` pairs. The
            target is not used.

    Returns:
        The concatenated ``pred_boxes`` tensor. Presumably this is shaped
        ``(len(dataset), num_queries, 4)`` — TODO confirm against the model.
    """
    # Send inputs to wherever the model lives (CPU or GPU).
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    predictions = []
    try:
        with torch.no_grad():
            # Index-based loop: the dataset is accessed via __getitem__, and
            # range(len(...)) gives tqdm a known total.
            for i in tqdm.tqdm(range(len(dataset))):
                image, _ = dataset[i]
                out = model([image.to(device)])
                predictions.append(out['pred_boxes'].cpu())
    finally:
        model.train(was_training)

    preds = torch.cat(predictions, 0)
    return preds
# Set this to the root directory of your local COCO dataset.
PATH_TO_COCO = "/path/to/coco/"
dataset = get_dataset(PATH_TO_COCO)
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
# nn.Module instances start in training mode, which would keep dropout active
# during inference. Switch to eval mode, and use a GPU when one is available.
model.eval()
model.to("cuda" if torch.cuda.is_available() else "cpu")
preds = compute_predictions(model, dataset)
# For each of the first 2 * n object queries, scatter the predicted box
# centres over the whole dataset. There is one subplot per query; marker size
# and colour both encode the box size.
s = (20, 4)
fig = plt.figure(figsize=s)
n = 10
for idx, query in enumerate(range(n * 2), 1):
    ax = fig.add_subplot(2, n, idx)
    p = preds[:, query]
    # Boxes are unpacked as (cx, cy, w, h) and expected to lie in [0, 1].
    # Use an explicit check rather than assert, so it is not stripped
    # under `python -O`.
    if p.min() < 0 or p.max() > 1:
        raise ValueError(f"query {query}: box coordinates outside [0, 1]")
    cx, cy, w, h = p.unbind(-1)
    size = (w * h) ** 0.5  # geometric mean of box width and height
    area = size * 10
    # RGB per point: red = width, green = 1 - size, blue = height.
    color = torch.stack((w, 1 - size, h), 1)
    ax.scatter(cx, cy, c=color, s=area, alpha=0.75)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
fig.savefig('query_distribution.png', bbox_inches='tight')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment