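# bboxes are sorted by score in decreasing order
# iou() and iou_threshold are assumed to be defined elsewhere
# init a vector keep with ones
keep = [1] * len(bboxes)
for i in range(len(bboxes)):
    # was suppressed
    if keep[i] == 0:
        continue
    # compare with all the lower-scoring boxes that follow i
    # (earlier boxes were already handled, and comparing i with itself would suppress it)
    for j in range(i + 1, len(bboxes)):
        if keep[j]:
            if iou(bboxes[i], bboxes[j]) > iou_threshold:
                # too much overlap with a higher-scoring box: suppress j
                keep[j] = 0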
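The pseudocode relies on an `iou` helper that is not shown. Below is a minimal, dependency-free sketch of both pieces. It assumes boxes in `(x1, y1, x2, y2)` format with `x1 <= x2` and `y1 <= y2`. The names `iou` and `nms`, and the sorting step inside `nms`, are illustrative and not taken from the original code. On tensors, `torchvision.ops.nms(boxes, scores, iou_threshold)` performs the same greedy suppression.

```python
from typing import List, Sequence

Box = Sequence[float]  # (x1, y1, x2, y2), with x1 <= x2 and y1 <= y2


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    # corners of the intersection rectangle
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    # clamp at zero so disjoint boxes have an empty intersection
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(bboxes: List[Box], scores: List[float], iou_threshold: float) -> List[int]:
    """Greedy NMS: return indices of the kept boxes, highest score first."""
    # visit boxes in decreasing score order, as the pseudocode assumes
    order = sorted(range(len(bboxes)), key=lambda k: scores[k], reverse=True)
    keep = [True] * len(order)  # keep[p] refers to box order[p]
    for p, i in enumerate(order):
        if not keep[p]:
            continue  # already suppressed by a higher-scoring box
        for q in range(p + 1, len(order)):
            if keep[q] and iou(bboxes[i], bboxes[order[q]]) > iou_threshold:
                keep[q] = False
    return [i for p, i in enumerate(order) if keep[p]]


boxes = [[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]]
print(nms(boxes, [0.9, 0.8, 0.7], iou_threshold=0.5))  # → [0, 2]
```

Box 1 overlaps box 0 with an IoU of 0.81, so it is suppressed. Box 2 is disjoint from both and survives.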
plot_bboxes(img, bboxes,
            colors=["yellow" if el.item() == 0 else "blue" for el in labels],
            labels=["head" if el.item() == 0 else "mic" for el in labels]
)
bboxes = bboxes[perm]
scores = scores[perm]
labels = labels[perm]
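The NMS pseudocode expects boxes sorted by score in decreasing order, so shuffled data has to be re-sorted before suppression. The key point is to apply one permutation to boxes, scores and labels together so that each box keeps its own score and label. Here is a plain-Python sketch, with lists standing in for the tensors; the helper name `sort_by_score` is hypothetical.

```python
from typing import List, Sequence, Tuple


def sort_by_score(
    bboxes: Sequence, scores: Sequence[float], labels: Sequence[int]
) -> Tuple[List, List[float], List[int]]:
    # indices of the scores from highest to lowest
    perm = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
    # apply the SAME permutation to every sequence to keep them aligned
    return (
        [bboxes[k] for k in perm],
        [scores[k] for k in perm],
        [labels[k] for k in perm],
    )


boxes = ["h0", "h1", "m0", "m1"]  # stand-ins for box coordinates
scores = [0.5, 0.98, 0.1, 1.0]
labels = [0, 0, 1, 1]
print(sort_by_score(boxes, scores, labels))
# → (['m1', 'h1', 'h0', 'm0'], [1.0, 0.98, 0.5, 0.1], [1, 0, 0, 1])
```

With tensors, the same permutation can be obtained with `scores.argsort(descending=True)` and applied by indexing, just like `bboxes[perm]` in the snippet.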
perm = torch.randperm(scores.shape[0])
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
scores = torch.tensor([
    0.98, 0.85, 0.5, 0.2,  # for head
    1, 0.92, 0.3, 0.1      # for mic
])
max_bboxes = 4  # boxes generated per object, one per (scaling, shifting) pair
scaling = torch.tensor([1, .96, .97, 1.02])
shifting = torch.tensor([0, 0.001, 0.002, -0.002])
# broadcasting magic: (2, 1, 4) * (4, 1) + (4, 1) -> (2, 4, 4), then flattened to (8, 4)
bboxes = (original_bboxes[:, None, :] * scaling[..., None] + shifting[..., None]).view(-1, 4)
plot_bboxes(img, bboxes, colors=[*["yellow"] * 4, *["blue"] * 4], labels=[*["head"] * 4, *["mic"] * 4])
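The broadcast above is equivalent to a double loop. Every original box (dimension 0) is combined with every (scale, shift) pair (dimension 1), and the same scale and shift are applied to all four coordinates. `.view(-1, 4)` then flattens object-major, which is why the first four boxes are heads and the last four are mics. Here is a plain-Python sketch of that equivalence; the function name `jitter` is hypothetical.

```python
from typing import List


def jitter(
    original: List[List[float]], scaling: List[float], shifting: List[float]
) -> List[List[float]]:
    out = []
    for box in original:                      # dimension 0: objects
        for s, t in zip(scaling, shifting):   # dimension 1: variants
            # same s and t for all 4 coordinates (the trailing size-1 axis)
            out.append([c * s + t for c in box])
    return out  # row index = object * len(scaling) + variant, like .view(-1, 4)


print(jitter([[1, 2, 3, 4], [5, 6, 7, 8]], [1, 2], [0, 1]))
# → [[1, 2, 3, 4], [3, 5, 7, 9], [5, 6, 7, 8], [11, 13, 15, 17]]
```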
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_tensor
from PIL import Image
import torch
import matplotlib.pyplot as plt

def plot_bboxes(img: Image.Image, bboxes: torch.Tensor, *args, **kwargs) -> plt.Figure:
    # PIL's Image.size is (width, height)
    w, h = img.size
    # from [0, 1] to image size: x coordinates scale with the width, y with the height
    bboxes = bboxes.clone()
    bboxes[..., [0, 2]] *= w
    bboxes[..., [1, 3]] *= h
    # draw_bounding_boxes expects a uint8 (C, H, W) image tensor
    img_tensor = (to_tensor(img) * 255).to(torch.uint8)
    img_with_bboxes = draw_bounding_boxes(img_tensor, bboxes, *args, **kwargs)
    fig = plt.figure()
    # matplotlib wants (H, W, C)
    plt.imshow(img_with_bboxes.permute(1, 2, 0).numpy())
    plt.axis("off")
    return fig

original_bboxes = torch.tensor([
    # head
    [565, 73, 862, 373],
    # mic
    [807, 309, 865, 434]
]).float()
w, h = img.size
# we need them in range [0, 1]: divide x by the width and y by the height
original_bboxes[..., [0, 2]] /= w
original_bboxes[..., [1, 3]] /= h
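The conversion between pixel and relative coordinates is just a per-axis division and multiplication. Here is a minimal sketch, assuming boxes in `(x1, y1, x2, y2)` pixel format. The x coordinates are divided by the image width and the y coordinates by the height (PIL's `Image.size` returns `(width, height)`). The helper names `to_relative` and `to_absolute` are hypothetical.

```python
from typing import List


def to_relative(box: List[float], width: int, height: int) -> List[float]:
    # pixels -> [0, 1]: x by width, y by height
    x1, y1, x2, y2 = box
    return [x1 / width, y1 / height, x2 / width, y2 / height]


def to_absolute(box: List[float], width: int, height: int) -> List[float]:
    # [0, 1] -> pixels: the exact inverse of to_relative
    x1, y1, x2, y2 = box
    return [x1 * width, y1 * height, x2 * width, y2 * height]


rel = to_relative([100, 50, 200, 150], width=400, height=200)
print(rel)                             # → [0.25, 0.25, 0.5, 0.75]
print(to_absolute(rel, 400, 200))      # → [100.0, 50.0, 200.0, 150.0]
```

Using relative coordinates keeps the jittered boxes independent of the image resolution. Multiplying back by the same width and height recovers the pixel boxes for drawing.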
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np
# credit https://i0.wp.com/craffic.co.in/wp-content/uploads/2021/02/ai-remastered-rick-astley-never-gonna-give-you-up.jpg?w=1600&ssl=1
img = Image.open("./samples/never-gonna-give-you-up.webp")
img