# Source: GitHub Gist xmodar/3efc617acac2fc7a686fbdee889d5169 by @xmodar,
# created February 27, 2021 18:04.
"""YOLOv3 object detector."""
import math
from pathlib import Path
from urllib.request import urlopen
from PIL import Image
from PIL import ImageColor, ImageOps
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
from torchvision.datasets.utils import download_url
# Human-readable class names, indexed by the integer label that the detector
# predicts (see `test_detect_image`, which does `COCO_CLASSES[label]`).
# NOTE(review): the spelling ('motorbike', 'aeroplane', 'tvmonitor', ...)
# presumably follows the class list shipped with the pretrained Darknet
# weights -- confirm before reordering or renaming any entry.
COCO_CLASSES = ('person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus',
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                'donut', 'cake', 'chair', 'sofa', 'pottedplant', 'bed',
                'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse',
                'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
                'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
                'scissors', 'teddy bear', 'hair drier', 'toothbrush')
def conv_block(ins, outs, kernel_size=3, stride=1, padding=1):
    """Build the basic Darknet-53 unit: Conv2d -> BatchNorm2d -> LeakyReLU.

    The convolution carries no bias term (the batch norm that follows
    provides the shift) and the activation uses a negative slope of 0.1.
    """
    convolution = nn.Conv2d(
        in_channels=ins,
        out_channels=outs,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    layers = [convolution, nn.BatchNorm2d(outs), nn.LeakyReLU(0.1)]
    return nn.Sequential(*layers)
class ResidualUnit(nn.Module):
    """Darknet-53 residual unit: 1x1 bottleneck, 3x3 expansion, skip add."""

    def __init__(self, channels):
        super().__init__()
        bottleneck = channels // 2
        squeeze = conv_block(channels, bottleneck, kernel_size=1, padding=0)
        expand = conv_block(bottleneck, channels)
        self.net = nn.Sequential(squeeze, expand)

    def forward(self, inputs):
        """Return the inputs plus the output of the two-conv branch."""
        residual = self.net(inputs)
        return residual + inputs

    @classmethod
    def repeated(cls, channels, times):
        """Stack `times` independent residual units in an nn.Sequential."""
        units = (cls(channels) for _ in range(times))
        return nn.Sequential(*units)
def darknet53(in_channels=3):
    """Assemble Darknet-53, the YOLOv3 feature extractor.

    The result is a flat nn.Sequential of 11 modules: one stem convolution
    followed by five (stride-2 convolution, residual stack) pairs. Callers
    slice it by index, so the module order must not change. The outputs of
    the last three residual stacks are the routes at 1/8, 1/16 and 1/32 of
    the input resolution.
    """
    # (output channels, number of residual units) for each downsampling stage
    stages = (
        (64, 1),
        (128, 2),
        (256, 8),  # <-- Route 1 (S = I / 8)
        (512, 8),  # <-- Route 2 (S = I / 16)
        (1024, 4),  # <-- Route 3 (S = I / 32)
    )
    channels = 32
    modules = [conv_block(in_channels, channels)]
    for width, units in stages:
        modules.append(conv_block(channels, width, stride=2))
        modules.append(ResidualUnit.repeated(width, times=units))
        channels = width
    return nn.Sequential(*modules)
class ConcatUnit(nn.Module):
    """Halve the channels, resize to the skip's size, and concatenate."""

    def __init__(self, in_channels):
        super().__init__()
        reduced = math.ceil(in_channels / 2)
        self.net = conv_block(in_channels, reduced, kernel_size=1, padding=0)

    def forward(self, inputs, skip):
        """Resize the reduced `inputs` to `skip`'s H x W and join on dim 1."""
        target_size = skip.shape[-2:]
        resized = F.interpolate(self.net(inputs), size=target_size)
        return torch.cat((resized, skip), dim=1)
def yolo_block(in_channels, out_channels):
    """Build the five-convolution YOLOv3 block.

    The block alternates 1x1 convolutions (to `out_channels`) with 3x3
    convolutions (to twice `out_channels`), starting and ending on a 1x1.
    """
    wide = 2 * out_channels
    pointwise = {'kernel_size': 1, 'padding': 0}
    layers = [conv_block(in_channels, out_channels, **pointwise)]
    for _ in range(2):
        layers.append(conv_block(out_channels, wide))
        layers.append(conv_block(wide, out_channels, **pointwise))
    return nn.Sequential(*layers)
def yolo_head(in_channels, mid_channels, num_anchors, num_classes):
    """Build a YOLOv3 detection head.

    A 3x3 conv block is followed by a plain (biased) 1x1 convolution that
    emits `(1 + 4 + num_classes)` values for each of the `num_anchors`.
    """
    values_per_anchor = 1 + 4 + num_classes
    expand = conv_block(in_channels, mid_channels)
    predict = nn.Conv2d(
        mid_channels,
        values_per_anchor * num_anchors,
        kernel_size=1,
    )
    return nn.Sequential(expand, predict)
class SpreadDim(nn.Module):
    """Spread one tensor dimension into multiple dimensions.

    The dimension at index `dim` is viewed as `(*dims, rest)` where `rest`
    is whatever factor remains after dividing by the product of `dims`;
    the trailing `rest` axis is omitted when it equals 1.

    Args:
        *dims: Leading sizes to split the dimension into.
        dim: Index of the dimension to split (negative values count from
            the end, as usual).
    """

    def __init__(self, *dims, dim):
        super().__init__()
        self.dims = dims
        self.dim = dim

    def forward(self, inputs):
        """Return a view of `inputs` with dimension `dim` spread out.

        Raises:
            IndexError: If `dim` is out of range for `inputs`.
            RuntimeError: (from `Tensor.view`) if the dimension size is not
                compatible with `dims`.
        """
        shape = list(inputs.shape)
        rank = len(shape)
        if not -rank <= self.dim < rank:
            raise IndexError(
                f'dim {self.dim} is out of range for a {rank}-D tensor')
        # Normalise to a non-negative index: slicing with a negative index
        # after removing an element would otherwise drop/misplace an axis.
        dim = self.dim % rank
        rest = shape[dim] // math.prod(self.dims)
        spread = list(self.dims) + ([] if rest == 1 else [rest])
        return inputs.view(shape[:dim] + spread + shape[dim + 1:])
class BaseYOLOv3(nn.Module):
    """The base YOLOv3 network: Darknet-53 backbone plus three heads.

    Sub-modules are registered in the same order as they are evaluated
    (backbone, coarse branch, medium branch, fine branch); the weight
    loader walks `modules()` and relies on that order.
    """

    def __init__(self, in_channels=3, num_anchors=3, num_classes=80):
        super().__init__()
        self.in_channels = in_channels
        self.num_anchors = num_anchors
        self.num_classes = num_classes

        # split the backbone at its three route outputs (1/8, 1/16, 1/32)
        trunk = darknet53(in_channels)
        self.feature1 = trunk[0:7]
        self.feature2 = trunk[7:9]
        self.feature3 = trunk[9:11]

        # coarsest scale
        self.route3 = yolo_block(1024, 512)
        self.head3 = yolo_head(512, 1024, num_anchors, num_classes)

        # middle scale: 256 reduced channels + 512 skip channels = 768
        self.concat2 = ConcatUnit(512)
        self.route2 = yolo_block(768, 256)
        self.head2 = yolo_head(256, 512, num_anchors, num_classes)

        # finest scale: 128 reduced channels + 256 skip channels = 384
        self.concat1 = ConcatUnit(256)
        self.route1 = yolo_block(384, 128)
        self.head1 = yolo_head(128, 256, num_anchors, num_classes)

        values_per_anchor = 1 + 4 + num_classes
        self.spread_dim = SpreadDim(values_per_anchor, dim=1)

    def forward(self, inputs):
        """Return the raw detections of the three scales, coarsest first.

        Each output has its channel dimension spread by `self.spread_dim`
        so that the `(1 + 4 + num_classes)` values form their own axis.
        """
        fine = self.feature1(inputs)
        medium = self.feature2(fine)
        coarse = self.route3(self.feature3(medium))
        medium = self.route2(self.concat2(coarse, medium))
        fine = self.route1(self.concat1(medium, fine))
        branches = (
            (self.head3, coarse),  # large receptive field
            (self.head2, medium),  # medium receptive field
            (self.head1, fine),  # small receptive field
        )
        return tuple(self.spread_dim(head(route)) for head, route in branches)
class YOLOv3(nn.Module):
    """The YOLOv3 object detector.

    Wraps `BaseYOLOv3` and converts its raw outputs into confidence scores,
    normalised centre-format boxes and class logits, with optional
    non-maximum suppression via `detect`.
    """

    original_image_size = 416
    # anchor (width, height) priors in pixels, one row per output scale,
    # ordered from the coarsest scale (largest boxes) to the finest
    anchor_priors = (
        ((116, 90), (156, 198), (373, 326)),
        ((30, 61), (62, 45), (59, 119)),
        ((10, 13), (16, 30), (33, 23)),
    )

    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        anchors = torch.tensor(self.anchor_priors, dtype=torch.float32)
        self.register_buffer('anchors', anchors)
        # anchors has shape [scales x K x 2]; the heads need K (anchors per
        # scale), not the number of scales that len(anchors) would give.
        self.yolo = BaseYOLOv3(in_channels, anchors.shape[1], num_classes)

    def forward(self, images):
        """Perform forward pass.

        Returns:
            A tuple `(confidence_scores, bounding_boxes, logits)` with
            shapes `[B x N]`, `[B x N x 4]` and `[B x N x C]`, where N is
            the total number of candidate boxes over the three scales.
            Boxes are `(cx, cy, w, h)` as fractions of the image size.
        """
        # pylint: disable=invalid-name
        y1, y2, y3 = self.yolo(images)
        a1, a2, a3 = self.anchors
        h, w = images.shape[-2:]  # image height and width
        s1, b1, c1 = self.to_bounding_boxes(y1, a1, w, h)  # large boxes
        s2, b2, c2 = self.to_bounding_boxes(y2, a2, w, h)  # medium boxes
        s3, b3, c3 = self.to_bounding_boxes(y3, a3, w, h)  # small boxes
        confidence_scores = torch.cat([s1, s2, s3], dim=1)  # shape [B x N]
        bounding_boxes = torch.cat([b1, b2, b3], dim=1)  # shape [B x N x 4]
        logits = torch.cat([c1, c2, c3], dim=1)  # shape [B x N x C]
        return confidence_scores, bounding_boxes, logits

    @torch.no_grad()
    def detect(self, images, iou_threshold=0.6):
        """Perform forward pass in eval mode with NMS.

        Note that this switches the module to eval mode and leaves it there.

        Returns:
            A tuple `(scores, bboxes, labels)` of per-image results that
            are padded along dimension 1 to the longest image in the batch.
        """
        self.train(False)
        nms = self.non_maximum_suppression
        scores, bboxes, labels = nms(*self.forward(images), iou_threshold)
        return scores, bboxes, labels

    @staticmethod
    def to_bounding_boxes(detections, anchors, width, height):
        """Convert the raw output of a YOLOv3 to bounding boxes.

        Args:
            detections: Raw head output, `[B x (1 + 4 + C) x K x Sy x Sx]`.
            anchors: Anchor priors of this scale in pixels, `[K x 2]`.
            width: Width of the network input image in pixels.
            height: Height of the network input image in pixels.

        Returns:
            A tuple `(scores, bboxes, logits)` with shapes
            `[B x (K * Sy * Sx)]`, `[B x (K * Sy * Sx) x 4]` and
            `[B x (K * Sy * Sx) x C]`. Boxes are `(cx, cy, w, h)` expressed
            as fractions of the image width/height.
        """
        splits = (1, 1, 1, 1, 1, detections.shape[1] - 5)
        scores, t_x, t_y, t_w, t_h, logits = detections.split(splits, dim=1)
        scores = scores.flatten(1).sigmoid()  # shape [B x (1 * K * Sy * Sx)]
        s_y, s_x = detections.shape[-2:]
        # Grid cell offsets. Broadcasting a column of row indices against a
        # row of column indices is equivalent to an 'ij' meshgrid, without
        # depending on torch.meshgrid's version-specific `indexing` default.
        c_y = torch.arange(s_y, dtype=t_y.dtype, device=t_y.device)
        c_y = c_y.view(-1, 1)  # shape [Sy x 1]
        c_x = torch.arange(s_x, dtype=t_x.dtype, device=t_x.device)  # [Sx]
        b_x = (t_x.sigmoid() + c_x) * (1 / s_x)
        b_y = (t_y.sigmoid() + c_y) * (1 / s_y)
        p_w, p_h = anchors.T[..., None, None]  # shape [2 x K x 1 x 1]
        b_w = t_w.exp() * (p_w / width)
        b_h = t_h.exp() * (p_h / height)
        bboxes = torch.stack([b_x, b_y, b_w, b_h], dim=-1).flatten(1, -2)
        logits = logits.flatten(2).transpose(1, 2)  # [B x (K * Sy * Sx) x C]
        return scores, bboxes, logits

    @staticmethod
    def non_maximum_suppression(scores, bboxes, logits, iou_threshold=0.6):
        """Apply NMS on a batch and sort the bboxes by adjusted scores.

        Each image is processed independently by `nms_single_image`; the
        variable-length results are then padded into batch tensors.
        """
        per_image = [
            YOLOv3.nms_single_image(*results, iou_threshold)
            for results in zip(scores, bboxes, logits)
        ]
        adjusted_scores, remaining_bboxes, labels = (
            nn.utils.rnn.pad_sequence(group, batch_first=True)
            for group in zip(*per_image)
        )
        return adjusted_scores, remaining_bboxes, labels

    @staticmethod
    def nms_single_image(scores, bboxes, logits, iou_threshold=0.6):
        """Apply NMS on a single image and sort by adjusted scores.

        The adjusted score of a box is its confidence multiplied by the
        softmax probability of its most likely class. Suppression is done
        per class (`batched_nms`), and the kept boxes are returned in the
        same `(cx, cy, w, h)` format they came in.
        """
        nms = torchvision.ops.batched_nms
        box_convert = torchvision.ops.box_convert
        max_probabilities, labels = logits.softmax(dim=1).max(dim=1)
        adjusted_scores = max_probabilities * scores
        bboxes = box_convert(bboxes, 'cxcywh', 'xyxy')
        indices = nms(bboxes, adjusted_scores, labels, iou_threshold)
        bboxes = box_convert(bboxes[indices], 'xyxy', 'cxcywh')
        return adjusted_scores[indices], bboxes, labels[indices]
def load_yolov3_weights(model, weights_path=None):
    """Load the pretrained Darknet YOLOv3 weights into `model` in place.

    Args:
        model: Module whose Conv2d/BatchNorm2d layers, in `modules()` order,
            match the layer order of the weights file.
        weights_path: Path of the `.weights` file. Defaults to
            `<torch hub dir>/checkpoints/yolov3.weights`; the file is
            downloaded there if it does not exist.

    Returns:
        The same `model`, with its parameters overwritten.

    Raises:
        RuntimeError: If the file ends before all parameters were read.
    """
    # layout of a pretrained (COCO) head in the file: K x (4 + 1 + C)
    file_anchors, file_classes = 3, 80
    file_head_channels = (1 + 4 + file_classes) * file_anchors

    # use a default path
    if weights_path is None:
        hub_dir = Path(torch.hub.get_dir())
        weights_path = hub_dir / 'checkpoints/yolov3.weights'

    # download the weights file if it doesn't exist
    weights_path = Path(weights_path)
    if not weights_path.exists():
        url = 'https://pjreddie.com/media/files/yolov3.weights'
        md5 = 'c84e5b99d0e52cd466ae710cadf6d84c'
        download_url(url, weights_path.parent, weights_path.name, md5)

    with open(weights_path, 'rb') as weights_file:
        # start by reading the header (version + num_training_images)
        np.fromfile(weights_file, np.int32, count=3)  # major, minor, revision
        np.fromfile(weights_file, np.int64, count=1)  # num_training_images

        # then read the rest of the weights as np.float32 values
        def read(layer, *parameters):  # order of parameters is important
            for parameter_name in parameters:
                tensor = getattr(layer, parameter_name).data
                count = tensor.numel()
                raw = np.fromfile(weights_file, np.float32, count)
                if raw.size != count:
                    # np.fromfile returns a short array at end of file;
                    # fail with a message that names the actual problem
                    raise RuntimeError(
                        f'{weights_path} is truncated: expected {count} '
                        f'values for {parameter_name!r} of '
                        f'{type(layer).__name__}, found {raw.size}')
                tensor.copy_(torch.from_numpy(raw).to(tensor).view_as(tensor))

        # in the file, BN weights are saved before Conv weights
        heads = []
        for layer in model.modules():
            if isinstance(layer, nn.Conv2d):
                heads.append(layer)  # keep it for after reading BN weights
                if layer.bias is not None:  # last Conv in yolo_head has bias
                    # if the number of classes is not 80 (COCO Dataset)
                    if layer.out_channels != file_head_channels:
                        # read into a throw-away layer so that the file
                        # position still advances past the pretrained head
                        layer = nn.Conv2d(
                            layer.in_channels, file_head_channels, 1)
                        heads.pop()
                    read(layer, 'bias', 'weight')
            elif isinstance(layer, nn.BatchNorm2d):
                read(layer, 'bias', 'weight', 'running_mean', 'running_var')
                read(heads.pop(), 'weight')  # every BN has a Conv behind it

    # organize the channel dimension for the heads from
    # the original order to ours K x (4 + 1 + C) -> (1 + 4 + C) x K
    if heads:
        indices = torch.arange(file_head_channels)
        indices = indices.view(file_anchors, 4 + 1 + file_classes).t()
        bboxes, scores, classes = indices.split((4, 1, file_classes))
        order = torch.cat([scores, bboxes, classes]).flatten()
        for layer in heads:
            # indexing with a tensor makes a copy, so copying back is safe
            layer.weight.data.copy_(layer.weight.data[order])
            layer.bias.data.copy_(layer.bias.data[order])
    return model
def plot_image_with_boxes(image, bboxes, labels, show=True, ax=None):
    """Plot an image with object bounding boxes and their labels.

    Args:
        image: Image tensor whose last two dimensions are height and width
            (converted for display with `TF.to_pil_image`).
        bboxes: Iterable of `(cx, cy, w, h)` boxes given as fractions of
            the image width/height.
        labels: Iterable of label strings, one per box.
        show: Whether to call `plt.show()` after drawing.
        ax: Existing matplotlib axes to draw on; a new figure is created
            when omitted.

    Returns:
        The axes that were drawn on.
    """
    # pylint: disable=invalid-name
    height, width = image.shape[-2:]
    if ax is None:
        # keep the figure width at least 1 inch for very tall images
        fig_width = max(1, int(width / height * 10))
        _, ax = plt.subplots(figsize=(fig_width, 10))
    ax.imshow(TF.to_pil_image(image))
    num_colors = 10
    # plt.get_cmap is available in both old and current Matplotlib releases
    colors = plt.get_cmap('hsv', num_colors)
    for index, (bbox, label) in enumerate(zip(bboxes, labels)):
        color = colors(index % num_colors)
        # plain Python floats, regardless of the device the boxes live on
        b_x, b_y, b_w, b_h = torch.as_tensor(bbox).tolist()
        top = (b_y - b_h / 2) * height
        left = (b_x - b_w / 2) * width
        ax.add_patch(
            patches.Rectangle(
                (left, top),
                b_w * width,
                b_h * height,
                linewidth=2,
                fill=False,
                color=color,
            ))
        ax.text(
            left,
            top,
            label,
            color='black',
            verticalalignment='top',
            fontsize='x-large',
            bbox={
                'color': color,
                'pad': 1
            },
        )
    ax.axis('off')
    if show:
        # no figure argument: `figure` did not exist when `ax` was passed
        # in, and plt.show() accepts only the keyword argument `block`
        plt.show()
    return ax
def square_image(image, size=None, color=(0, 0, 0)):
    """Return a square copy of a PIL.Image, padded with `color`.

    The side length is `size` when given, otherwise the longer side of
    `image`. The RGB `color` is translated to the image's own mode first.
    """
    side = size if size is not None else max(image.size)
    fill = ImageColor.getcolor(f'rgb{tuple(color)}', image.mode)
    return ImageOps.pad(image, (side, side), Image.NEAREST, fill)
def test_detect_image(image_path, min_score=0.9, iou_threshold=0.6, square=False):
    """Apply COCO detection on an image with YOLOv3 and plot the result.

    Args:
        image_path: Local path of the image, or a URL starting with 'http'.
        min_score: Only detections scoring above this value are plotted.
        iou_threshold: IoU threshold passed to the detector's NMS.
        square: Whether to pad the image to a square with gray first.
    """
    # Decode inside the `with` blocks so that the HTTP response / file
    # handle is released; convert() loads the pixels and returns a new
    # image, and guarantees the 3 channels that the default model expects.
    if str(image_path).startswith('http'):
        with urlopen(image_path) as response:
            image = Image.open(response).convert('RGB')
    else:
        with Image.open(image_path) as opened:
            image = opened.convert('RGB')
    if square:
        image = square_image(image, color=[128] * 3)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    images = TF.to_tensor(image).unsqueeze(0).to(device)
    model = load_yolov3_weights(YOLOv3()).to(images.device)
    detections = model.detect(images, iou_threshold)
    for tensor, scores, bboxes, labels in zip(images, *detections):
        valid = scores > min_score
        captions = [
            f'{COCO_CLASSES[label]} ({score.item() * 100:.2f}%)'
            for score, label in zip(scores[valid], labels[valid])
        ]
        plot_image_with_boxes(tensor, bboxes[valid], captions)
if __name__ == '__main__':
    # demo: detect objects in the sample image from the Darknet repository
    DEMO_IMAGE = 'https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg'
    test_detect_image(DEMO_IMAGE)
# (End of gist source; the GitHub comment-form text that followed was removed.)