Created
February 27, 2021 18:04
-
-
Save xmodar/3efc617acac2fc7a686fbdee889d5169 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""YOLOv3 object detector.""" | |
import math | |
from pathlib import Path | |
from urllib.request import urlopen | |
from PIL import Image | |
from PIL import ImageColor, ImageOps | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision | |
import torchvision.transforms.functional as TF | |
from torchvision.datasets.utils import download_url | |
# Names of the 80 classes, indexed by the label ids the detector returns.
COCO_CLASSES = (
    # people and vehicles
    'person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train',
    'truck', 'boat',
    # street furniture
    'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
    # animals
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
    'zebra', 'giraffe',
    # accessories
    'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
    # sports
    'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    # kitchenware
    'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    # food
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake',
    # furniture
    'chair', 'sofa', 'pottedplant', 'bed', 'diningtable', 'toilet',
    # electronics and appliances
    'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator',
    # indoor objects
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush',
)
def conv_block(ins, outs, kernel_size=3, stride=1, padding=1):
    """Build the basic Darknet-53 unit: Conv2d, BatchNorm2d, LeakyReLU(0.1).

    The convolution has no bias term because the batch normalization that
    follows it supplies its own shift.
    """
    convolution = nn.Conv2d(
        ins,
        outs,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    layers = [convolution, nn.BatchNorm2d(outs), nn.LeakyReLU(0.1)]
    return nn.Sequential(*layers)
class ResidualUnit(nn.Module):
    """Darknet-53 residual unit: a 1x1 squeeze, a 3x3 expand and a skip."""

    def __init__(self, channels):
        super().__init__()
        bottleneck = channels // 2
        squeeze = conv_block(channels, bottleneck, kernel_size=1, padding=0)
        expand = conv_block(bottleneck, channels)
        self.net = nn.Sequential(squeeze, expand)

    def forward(self, inputs):
        """Return the block output added to its own input."""
        residual = self.net(inputs)
        return residual + inputs

    @classmethod
    def repeated(cls, channels, times):
        """Chain `times` independent residual units of the same width."""
        units = (cls(channels) for _ in range(times))
        return nn.Sequential(*units)
def darknet53(in_channels=3):
    """Assemble Darknet-53, the feature extractor used by YOLOv3.

    The result is a flat `nn.Sequential` of 11 modules: a stem convolution
    followed by five (stride-2 convolution, residual stack) pairs. The
    outputs of the last three residual stacks (modules 6, 8 and 10) are the
    routes at 1/8, 1/16 and 1/32 of the input resolution.
    """
    # (output channels, number of residual units) for each downsampling stage
    stages = ((64, 1), (128, 2), (256, 8), (512, 8), (1024, 4))
    width = 32
    layers = [conv_block(in_channels, width)]
    for channels, repeats in stages:
        layers.append(conv_block(width, channels, stride=2))
        layers.append(ResidualUnit.repeated(channels, times=repeats))
        width = channels
    return nn.Sequential(*layers)
class ConcatUnit(nn.Module):
    """Halve the channels, resize to a skip tensor and concatenate with it."""

    def __init__(self, in_channels):
        super().__init__()
        halved = math.ceil(in_channels / 2)
        self.net = conv_block(in_channels, halved, kernel_size=1, padding=0)

    def forward(self, inputs, skip):
        """Return `inputs` (projected and resized) stacked on top of `skip`."""
        target_size = skip.shape[-2:]
        # F.interpolate is called with its default mode, as in the original
        resized = F.interpolate(self.net(inputs), target_size)
        return torch.cat((resized, skip), dim=1)
def yolo_block(in_channels, out_channels):
    """Build the five alternating 1x1 / 3x3 conv blocks of a YOLOv3 route.

    The 1x1 blocks produce `out_channels` and the 3x3 blocks produce twice
    that, so the stack ends with `out_channels` channels.
    """
    wide = out_channels * 2

    def squeeze(ins):
        return conv_block(ins, out_channels, kernel_size=1, padding=0)

    def expand():
        return conv_block(out_channels, wide)

    return nn.Sequential(
        squeeze(in_channels),
        expand(),
        squeeze(wide),
        expand(),
        squeeze(wide),
    )
def yolo_head(in_channels, mid_channels, num_anchors, num_classes):
    """Build a YOLOv3 detection head.

    A 3x3 conv block is followed by a biased 1x1 convolution that emits
    `(1 + 4 + num_classes)` values for each of the `num_anchors` anchors.
    """
    per_anchor = 1 + 4 + num_classes  # objectness + box + class logits
    features = conv_block(in_channels, mid_channels)
    predictor = nn.Conv2d(
        mid_channels,
        per_anchor * num_anchors,
        kernel_size=1,
    )
    return nn.Sequential(features, predictor)
class SpreadDim(nn.Module):
    """Spread one dimension of a tensor into multiple dimensions.

    Args:
        *dims: Leading sizes that dimension `dim` is split into.
        dim: Index of the dimension to split; negative values count from
            the end, as usual.

    The remainder `size // prod(dims)` is appended as one extra trailing
    dimension unless it equals 1, e.g. `SpreadDim(4, dim=1)` turns a
    `[2 x 12]` tensor into `[2 x 4 x 3]` and a `[2 x 4]` tensor into
    `[2 x 4]`.
    """

    def __init__(self, *dims, dim):
        super().__init__()
        self.dims = dims
        self.dim = dim

    def forward(self, inputs):
        """Return a view of `inputs` with dimension `dim` spread out."""
        shape = list(inputs.shape)
        # Normalize a negative index first: after `pop` the list is one
        # element shorter, so slicing it with the raw negative index would
        # insert the new dimensions at the wrong position.
        dim = self.dim + len(shape) if self.dim < 0 else self.dim
        rest = shape.pop(dim) // math.prod(self.dims)
        dims = list(self.dims) + ([] if rest == 1 else [rest])
        return inputs.view(shape[:dim] + dims + shape[dim:])
class BaseYOLOv3(nn.Module):
    """The raw YOLOv3 network: Darknet-53 backbone plus three heads."""

    def __init__(self, in_channels=3, num_anchors=3, num_classes=80):
        super().__init__()
        self.in_channels = in_channels
        self.num_anchors = num_anchors
        self.num_classes = num_classes

        def head(width):
            return yolo_head(width, 2 * width, num_anchors, num_classes)

        # NOTE: sub-modules are registered in a fixed order on purpose;
        # anything that walks `modules()` sequentially relies on it.
        backbone = darknet53(in_channels)
        self.feature1 = backbone[:7]  # up to the 1/8 resolution route
        self.feature2 = backbone[7:9]  # up to the 1/16 resolution route
        self.feature3 = backbone[9:11]  # up to the 1/32 resolution route
        self.route3 = yolo_block(1024, 512)
        self.head3 = head(512)
        self.concat2 = ConcatUnit(512)
        self.route2 = yolo_block(768, 256)
        self.head2 = head(256)
        self.concat1 = ConcatUnit(256)
        self.route1 = yolo_block(384, 128)
        self.head1 = head(128)
        self.spread_dim = SpreadDim(1 + 4 + num_classes, dim=1)

    def forward(self, inputs):
        """Return the raw detections of the three scales, coarsest first.

        Each output has its channel dimension split by `spread_dim` into
        `(1 + 4 + num_classes)` and the remaining anchor dimension.
        """
        fine = self.feature1(inputs)
        medium = self.feature2(fine)
        coarse = self.feature3(medium)

        # top-down pathway: refine the coarse map, then merge it downwards
        coarse = self.route3(coarse)
        medium = self.route2(self.concat2(coarse, medium))
        fine = self.route1(self.concat1(medium, fine))

        scales = (
            (self.head3, coarse),  # large receptive field
            (self.head2, medium),  # medium receptive field
            (self.head1, fine),  # small receptive field
        )
        return tuple(
            self.spread_dim(head(features)) for head, features in scales)
class YOLOv3(nn.Module):
    """The YOLOv3 object detector.

    Wraps `BaseYOLOv3` and decodes its raw outputs into confidence scores,
    normalized center-format bounding boxes and class logits.
    """

    # not referenced anywhere else in this class
    original_image_size = 416
    # (width, height) anchor priors in pixels: one row per detection scale
    # (coarsest first), one entry per anchor of that scale
    anchor_priors = (
        ((116, 90), (156, 198), (373, 326)),
        ((30, 61), (62, 45), (59, 119)),
        ((10, 13), (16, 30), (33, 23)),
    )

    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        anchors = torch.tensor(self.anchor_priors, dtype=torch.float32)
        self.register_buffer('anchors', anchors)
        # `anchors` is [scales x K x 2]; each head predicts K boxes per
        # cell, so the head width depends on K, not on the scale count.
        self.yolo = BaseYOLOv3(in_channels, anchors.shape[1], num_classes)

    def forward(self, images):
        """Run the network and decode all three scales.

        Args:
            images: Batch of images whose last two dimensions are height
                and width.

        Returns:
            A tuple `(confidence_scores, bounding_boxes, logits)` with
            shapes `[B x N]`, `[B x N x 4]` and `[B x N x C]`, where the N
            candidates of the three scales are concatenated.
        """
        # pylint: disable=invalid-name
        y1, y2, y3 = self.yolo(images)
        a1, a2, a3 = self.anchors
        h, w = images.shape[-2:]  # image height and width
        s1, b1, c1 = self.to_bounding_boxes(y1, a1, w, h)  # large boxes
        s2, b2, c2 = self.to_bounding_boxes(y2, a2, w, h)  # medium boxes
        s3, b3, c3 = self.to_bounding_boxes(y3, a3, w, h)  # small boxes
        confidence_scores = torch.cat([s1, s2, s3], dim=1)  # [B x N]
        bounding_boxes = torch.cat([b1, b2, b3], dim=1)  # [B x N x 4]
        logits = torch.cat([c1, c2, c3], dim=1)  # [B x N x C]
        return confidence_scores, bounding_boxes, logits

    @torch.no_grad()
    def detect(self, images, iou_threshold=0.6):
        """Perform forward pass in eval mode with NMS.

        The module is switched to eval mode and left in it. Returns the
        `(scores, bboxes, labels)` tuple of `non_maximum_suppression`.
        """
        self.train(False)
        nms = self.non_maximum_suppression
        scores, bboxes, labels = nms(*self.forward(images), iou_threshold)
        return scores, bboxes, labels

    @staticmethod
    def to_bounding_boxes(detections, anchors, width, height):
        """Convert the raw output of a YOLOv3 head to bounding boxes.

        Args:
            detections: Tensor of shape `[B x (1 + 4 + C) x K x Sy x Sx]`.
            anchors: Tensor of shape `[K x 2]` holding (width, height).
            width: Image width that the anchor widths are divided by.
            height: Image height that the anchor heights are divided by.

        Returns:
            `scores` of shape `[B x N]` (sigmoid of the first channel),
            `bboxes` of shape `[B x N x 4]` as (center x, center y, width,
            height) in units of the image size, and `logits` of shape
            `[B x N x C]`, with `N = K * Sy * Sx`.
        """
        splits = (1, 1, 1, 1, 1, detections.shape[1] - 5)
        scores, t_x, t_y, t_w, t_h, logits = detections.split(splits, dim=1)
        scores = scores.flatten(1).sigmoid()  # [B x (1 * K * Sy * Sx)]
        s_y, s_x = detections.shape[-2:]
        # Cell offsets as broadcastable vectors: `c_x` runs along the last
        # (Sx) dimension and `c_y` along the one before it (Sy). This is
        # equivalent to an 'ij'-indexed meshgrid without calling
        # `torch.meshgrid`, whose default indexing is deprecated.
        c_x = torch.arange(s_x, dtype=t_x.dtype, device=t_x.device)
        c_y = torch.arange(s_y, dtype=t_y.dtype, device=t_y.device)
        c_y = c_y.unsqueeze(-1)
        b_x = (t_x.sigmoid() + c_x) * (1 / s_x)
        b_y = (t_y.sigmoid() + c_y) * (1 / s_y)
        p_w, p_h = anchors.T[..., None, None]  # shape [2 x K x 1 x 1]
        b_w = t_w.exp() * (p_w / width)
        b_h = t_h.exp() * (p_h / height)
        bboxes = torch.stack([b_x, b_y, b_w, b_h], dim=-1).flatten(1, -2)
        logits = logits.flatten(2).transpose(1, 2)  # [B x (K*Sy*Sx) x C]
        return scores, bboxes, logits

    @staticmethod
    def non_maximum_suppression(scores, bboxes, logits, iou_threshold=0.6):
        """Apply NMS on a batch and sort the bboxes by adjusted scores.

        Images keep different numbers of boxes, so the per-image results
        are zero-padded to a common length. Padded entries therefore have
        score 0 and label 0.

        Returns:
            `(adjusted_scores, remaining_bboxes, labels)` with shapes
            `[B x M]`, `[B x M x 4]` and `[B x M]`.
        """
        results = [
            YOLOv3.nms_single_image(*candidates, iou_threshold)
            for candidates in zip(scores, bboxes, logits)
        ]
        if not results:  # empty batch: nothing to pad or unpack
            empty_scores = scores.new_zeros((0, 0))
            empty_bboxes = bboxes.new_zeros((0, 0, 4))
            empty_labels = scores.new_zeros((0, 0), dtype=torch.long)
            return empty_scores, empty_bboxes, empty_labels
        adjusted_scores, remaining_bboxes, labels = (
            nn.utils.rnn.pad_sequence(list(column), batch_first=True)
            for column in zip(*results))
        return adjusted_scores, remaining_bboxes, labels

    @staticmethod
    def nms_single_image(scores, bboxes, logits, iou_threshold=0.6):
        """Apply NMS on a single image and sort by adjusted scores.

        Args:
            scores: Confidence scores of shape `[N]`.
            bboxes: Center-format boxes of shape `[N x 4]`.
            logits: Class logits of shape `[N x C]`.
            iou_threshold: Overlap threshold handed to `batched_nms`.

        Returns:
            `(adjusted_scores, bboxes, labels)` of the kept boxes, where an
            adjusted score is the confidence times the top class
            probability and boxes are again in center format.
        """
        nms = torchvision.ops.batched_nms
        box_convert = torchvision.ops.box_convert
        # NOTE(review): class probabilities come from a softmax here; the
        # reference Darknet implementation appears to use independent
        # sigmoids for its class outputs. Confirm this is intended.
        max_probabilities, labels = logits.softmax(dim=1).max(dim=1)
        adjusted_scores = max_probabilities * scores
        bboxes = box_convert(bboxes, 'cxcywh', 'xyxy')
        indices = nms(bboxes, adjusted_scores, labels, iou_threshold)
        bboxes = box_convert(bboxes[indices], 'xyxy', 'cxcywh')
        return adjusted_scores[indices], bboxes, labels[indices]
def load_yolov3_weights(model, weights_path=None):
    """Load Darknet-format YOLOv3 weights into `model` in place.

    Args:
        model: Module whose `nn.Conv2d` / `nn.BatchNorm2d` layers, visited
            in `model.modules()` order, match the layout of the file.
        weights_path: Path of the `.weights` file. Defaults to
            `checkpoints/yolov3.weights` inside the torch hub directory.
            The official COCO weights are downloaded to this path when the
            file does not exist.

    Returns:
        The same `model`, for chaining.

    Raises:
        RuntimeError: If the file ends before every tensor was filled.

    Biased convolutions (the last layer of each detection head) that do
    not have the COCO width of `(1 + 4 + 80) * 3` channels are left
    untouched; their values in the file are skipped.
    """
    coco_head_channels = (1 + 4 + 80) * 3

    # use a default path
    if weights_path is None:
        hub_dir = Path(torch.hub.get_dir())
        weights_path = hub_dir / 'checkpoints/yolov3.weights'

    # download the weights file if it doesn't exist
    weights_path = Path(weights_path)
    if not weights_path.exists():
        url = 'https://pjreddie.com/media/files/yolov3.weights'
        md5 = 'c84e5b99d0e52cd466ae710cadf6d84c'
        download_url(url, weights_path.parent, weights_path.name, md5)

    heads = []  # biased head convolutions that were filled from the file
    with open(weights_path, 'rb') as weights_file:

        def take(dtype, count):
            """Read exactly `count` values of `dtype` or fail loudly."""
            values = np.fromfile(weights_file, dtype, count)
            if values.size != count:
                raise RuntimeError(
                    f'{weights_path} is truncated or does not match the '
                    f'model: expected {count} more values of '
                    f'{np.dtype(dtype).name}, got {values.size}')
            return values

        def read(layer, *parameters):  # order of parameters is important
            for parameter_name in parameters:
                tensor = getattr(layer, parameter_name).data
                raw = take(np.float32, tensor.numel())
                tensor.copy_(torch.from_numpy(raw).to(tensor).view_as(tensor))

        # header: version (major, minor, revision) + num_training_images;
        # Darknet writes the image counter as a 64-bit value from format
        # version 0.2 on and as a 32-bit value in older files
        major, minor, _ = take(np.int32, 3).tolist()
        wide_counter = major * 10 + minor >= 2 and major < 1000 > minor
        take(np.int64 if wide_counter else np.int32, 1)

        # the rest of the file is np.float32 values; for each Conv-BN pair
        # the BN values are saved before the Conv weights
        pending = []  # bias-free convolutions waiting for their BN
        for layer in model.modules():
            if isinstance(layer, nn.Conv2d):
                if layer.bias is None:
                    pending.append(layer)
                elif layer.out_channels == coco_head_channels:
                    read(layer, 'bias', 'weight')
                    heads.append(layer)
                else:
                    # the number of classes is not 80 (COCO Dataset):
                    # skip the bias and the 1x1 kernel of the COCO head
                    take(np.float32, coco_head_channels)
                    take(np.float32, coco_head_channels * layer.in_channels)
            elif isinstance(layer, nn.BatchNorm2d):
                read(layer, 'bias', 'weight', 'running_mean', 'running_var')
                read(pending.pop(), 'weight')  # every BN follows a Conv

    # organize the channel dimension for the heads from
    # the original order to ours K x (4 + 1 + C) -> (1 + 4 + C) x K
    if heads:
        indices = torch.arange(coco_head_channels).view(3, 4 + 1 + 80).t()
        bboxes, scores, classes = indices.split((4, 1, 80))
        order = torch.cat([scores, bboxes, classes]).flatten()
        for layer in heads:
            # indexing with a tensor copies, so the in-place copy is safe
            layer.weight.data.copy_(layer.weight.data[order])
            layer.bias.data.copy_(layer.bias.data[order])
    return model
def plot_image_with_boxes(image, bboxes, labels, show=True, ax=None):
    """Plot an image with object bounding boxes and their labels.

    Args:
        image: Image tensor whose last two dimensions are height and
            width; it is converted with `TF.to_pil_image`.
        bboxes: Iterable of boxes given as (center x, center y, width,
            height) in units of the image size.
        labels: Iterable of label texts, one per box.
        show: Whether to call `plt.show()` after drawing.
        ax: Axes to draw on; a new figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    # pylint: disable=invalid-name
    height, width = image.shape[-2:]
    if ax is None:
        # keep the width at least 1 inch so very tall images still plot
        figure_width = max(1, int(width / height * 10))
        _, ax = plt.subplots(figsize=(figure_width, 10))
    ax.imshow(TF.to_pil_image(image))
    num_colors = 10
    colors = plt.get_cmap('hsv', num_colors)
    for index, (bbox, label) in enumerate(zip(bboxes, labels)):
        color = colors(index % num_colors)  # cycle through the palette
        b_x, b_y, b_w, b_h = bbox.tolist()
        # boxes are center based; matplotlib wants the top-left corner
        top = (b_y - b_h / 2) * height
        left = (b_x - b_w / 2) * width
        ax.add_patch(
            patches.Rectangle(
                (left, top),
                b_w * width,
                b_h * height,
                linewidth=2,
                fill=False,
                color=color,
            ))
        ax.text(
            left,
            top,
            label,
            color='black',
            verticalalignment='top',
            fontsize='x-large',
            bbox={
                'color': color,
                'pad': 1
            },
        )
    ax.axis('off')
    if show:
        # no figure argument: none exists when `ax` was supplied by the
        # caller, and `plt.show` only takes the keyword `block`
        plt.show()
    return ax
def square_image(image, size=None, color=(0, 0, 0)):
    """Return a square copy of a PIL.Image, padded with `color`.

    The side length is `size`, or the longer side of `image` when `size`
    is None. Resizing uses nearest-neighbour resampling.
    """
    side = size if size is not None else max(image.size)
    fill = ImageColor.getcolor(f'rgb{tuple(color)}', image.mode)
    return ImageOps.pad(image, (side, side), Image.NEAREST, fill)
def test_detect_image(image_path, min_score=0.9, iou_threshold=0.6, square=False):
    """Apply COCO detection on an image with YOLOv3 and plot the result.

    Args:
        image_path: Local path or `http(s)` URL of the image.
        min_score: Detections with a score not above this are dropped.
        iou_threshold: Overlap threshold used by the NMS step.
        square: Whether to pad the image to a square with gray first.
    """
    # Decode fully inside the `with` blocks so that the connection / file
    # handle is released, and force three channels because the default
    # model only accepts RGB input.
    if str(image_path).startswith('http'):
        with urlopen(image_path) as response:
            image = Image.open(response).convert('RGB')
    else:
        with Image.open(image_path) as opened:
            image = opened.convert('RGB')
    if square:
        image = square_image(image, color=[128] * 3)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    images = TF.to_tensor(image).unsqueeze(0).to(device)
    model = load_yolov3_weights(YOLOv3()).to(images.device)
    detections = model.detect(images, iou_threshold)
    for picture, scores, bboxes, classes in zip(images, *detections):
        valid = scores > min_score  # also removes zero-score padding
        captions = [
            f'{COCO_CLASSES[label]} ({score.item() * 100:.2f}%)'
            for score, label in zip(scores[valid], classes[valid])
        ]
        plot_image_with_boxes(picture, bboxes[valid], captions)
if __name__ == '__main__':
    # demo: run the detector on the sample image of the Darknet repository
    test_detect_image(
        image_path='https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg',
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment