Patches the Ultralytics YOLO library to support an extra Detect head for additional classes.
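To try it, save the patch to a file inside a checkout of the ultralytics repository at the matching commit and apply it with git am (or git apply). A minimal sketch of building the patched model, assuming the patch applied cleanly:

from ultralytics import YOLO

# Build the two-head model from the patched config (nc: 82 = 80 original + 2 new)
model = YOLO("ultralytics/cfg/models/v8/yolov8n-2xhead.yaml")
model.info()  # the printed summary should show two Detect layers plus a ConcatHead
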
From e41525e084a6f77980665da715ec93070c397479 Mon Sep 17 00:00:00 2001
From: Y-T-G <>
Date: Sat, 16 Mar 2024 14:06:25 +0800
Subject: [PATCH] Add ConcatHead

---
 ultralytics/cfg/models/v8/yolov8n-2xhead.yaml | 55 ++++++++++++++++
 ultralytics/engine/trainer.py                 |  2 +-
 ultralytics/nn/modules/__init__.py            |  2 +
 ultralytics/nn/modules/conv.py                | 64 +++++++++++++++++++++++
 ultralytics/nn/tasks.py                       | 36 ++++++-----
 5 files changed, 142 insertions(+), 17 deletions(-)
 create mode 100644 ultralytics/cfg/models/v8/yolov8n-2xhead.yaml

diff --git a/ultralytics/cfg/models/v8/yolov8n-2xhead.yaml b/ultralytics/cfg/models/v8/yolov8n-2xhead.yaml
new file mode 100644
index 0000000..a4a0b50
--- /dev/null
+++ b/ultralytics/cfg/models/v8/yolov8n-2xhead.yaml
@@ -0,0 +1,55 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv8 object detection model with P3-P5 outputs. For usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 82 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
+  s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
+  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
+  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
+  x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2f, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 12
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
+
+  - [-1, 1, Conv, [512, 3, 2]]
+  - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
+
+  # Layer 22. The first list names the layers whose outputs are used as input.
+  # This head predicts the original 80 classes.
+  - [[15, 18, 21], 1, Detect, [80]] # 22: Detect(P3, P4, P5)
+
+  # Layer 23. This head predicts 2 classes, because that's how many new
+  # classes we added.
+  - [[15, 18, 21], 1, Detect, [2]] # 23: Detect(P3, P4, P5), new classes
+
+  # ConcatHead takes the outputs of layers 22 and 23 and concatenates them.
+  - [[22, 23], 1, ConcatHead, [80, 2]] # 24: concat #22 and #23
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index f005f34..4a025a2 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -340,8 +340,8 @@ class BaseTrainer:
         epoch = self.start_epoch
         while True:
             self.epoch = epoch
-            self.run_callbacks("on_train_epoch_start")
             self.model.train()
+            self.run_callbacks("on_train_epoch_start")
             if RANK != -1:
                 self.train_loader.sampler.set_epoch(epoch)
             pbar = enumerate(self.train_loader)
diff --git a/ultralytics/nn/modules/__init__.py b/ultralytics/nn/modules/__init__.py
index d785c00..36208cc 100644
--- a/ultralytics/nn/modules/__init__.py
+++ b/ultralytics/nn/modules/__init__.py
@@ -51,6 +51,7 @@ from .conv import (
     CBAM,
     ChannelAttention,
     Concat,
+    ConcatHead,
     Conv,
     Conv2,
     ConvTranspose,
@@ -90,6 +91,7 @@ __all__ = (
     "SpatialAttention",
     "CBAM",
     "Concat",
+    "ConcatHead",
     "TransformerLayer",
     "TransformerBlock",
     "MLPBlock",
diff --git a/ultralytics/nn/modules/conv.py b/ultralytics/nn/modules/conv.py
index 399c422..24dc844 100644
--- a/ultralytics/nn/modules/conv.py
+++ b/ultralytics/nn/modules/conv.py
@@ -20,6 +20,7 @@ __all__ = (
     "SpatialAttention",
     "CBAM",
     "Concat",
+    "ConcatHead",
     "RepConv",
 )
@@ -331,3 +332,66 @@ class Concat(nn.Module):
     def forward(self, x):
         """Forward pass for the YOLOv8 mask Proto module."""
         return torch.cat(x, self.d)
+
+
+class ConcatHead(nn.Module):
+    """Concatenation layer for Detect heads."""
+
+    def __init__(self, nc1=80, nc2=1, ch=()):
+        """Initializes the ConcatHead."""
+        super().__init__()
+        self.nc1 = nc1  # number of classes of head 1
+        self.nc2 = nc2  # number of classes of head 2
+
+    def forward(self, x):
+        """Concatenates and returns predicted bounding boxes and class probabilities."""
+
+        # x is a list of length 2.
+        # Each element is either a tuple or just the decoded features,
+        # depending on whether the model is being exported.
+        # The first element of the tuple holds the decoded preds;
+        # the second holds the feature maps for heatmap visualization.
+
+        if isinstance(x[0], tuple):
+            preds1 = x[0][0]
+            preds2 = x[1][0]
+        elif isinstance(x[0], list):  # when raw outputs are returned
+            # The shapes are used for stride creation in tasks.py.
+            # Feature maps can't be merged, so decode them individually if needed.
+            return [torch.cat((x0, x1), dim=1) for x0, x1 in zip(x[0], x[1])]
+        else:
+            preds1 = x[0]
+            preds2 = x[1]
+
+        # Concatenate the new head's outputs as extra outputs
+
+        # 1. Concatenate bbox outputs
+        # The shape changes from, e.g., [N, 4, 6300] to [N, 4, 12600]
+        preds = torch.cat((preds1[:, :4, :], preds2[:, :4, :]), dim=2)
+
+        # 2. Concatenate class outputs
+        # Extend preds1 with zero outputs for the preds2 anchor positions
+        shape = list(preds1.shape)
+        shape[-1] *= 2
+
+        preds1_extended = torch.zeros(shape, device=preds1.device,
+                                      dtype=preds1.dtype)
+        preds1_extended[..., : preds1.shape[-1]] = preds1
+
+        # Prepend preds2 with zero outputs for the preds1 anchor positions
+        shape = list(preds2.shape)
+        shape[-1] *= 2
+
+        preds2_extended = torch.zeros(shape, device=preds2.device,
+                                      dtype=preds2.dtype)
+        preds2_extended[..., preds2.shape[-1] :] = preds2
+
+        # Arrange the class probabilities in the order preds1, preds2, so the
+        # class indices of preds2 start after those of preds1
+        preds = torch.cat((preds, preds1_extended[:, 4:, :]), dim=1)
+        preds = torch.cat((preds, preds2_extended[:, 4:, :]), dim=1)
+
+        if isinstance(x[0], tuple):
+            return (preds, x[0][1])
+        else:
+            return preds
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 64ee7f5..c8a3e14 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -25,6 +25,7 @@ from ultralytics.nn.modules import (
     C3x,
     Classify,
     Concat,
+    ConcatHead,
     Conv,
     Conv2,
     ConvTranspose,
@@ -230,11 +231,12 @@ class BaseModel(nn.Module):
             (BaseModel): An updated BaseModel object.
         """
         self = super()._apply(fn)
-        m = self.model[-1]  # Detect()
-        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
-            m.stride = fn(m.stride)
-            m.anchors = fn(m.anchors)
-            m.strides = fn(m.strides)
+        # Get all Detect modules, not just the last layer
+        for m in self.model:
+            if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
+                m.stride = fn(m.stride)
+                m.anchors = fn(m.anchors)
+                m.strides = fn(m.strides)
         return self

     def load(self, weights, verbose=True):
@@ -289,16 +291,18 @@ class DetectionModel(BaseModel):
         self.inplace = self.yaml.get("inplace", True)

         # Build strides
-        m = self.model[-1]  # Detect()
-        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
-            s = 256  # 2x min stride
-            m.inplace = self.inplace
-            forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
-            m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
-            self.stride = m.stride
-            m.bias_init()  # only run once
-        else:
-            self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR
+        # Get all Detect modules and build strides for each
+        last_m = len(self.model) - 1
+        for i, m in enumerate(self.model):
+            if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
+                s = 256  # 2x min stride
+                m.inplace = self.inplace
+                forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
+                m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
+                self.stride = m.stride
+                m.bias_init()  # only run once
+            elif i == last_m:
+                self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR

         # Init weights, biases
         initialize_weights(self)
@@ -895,7 +899,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
+        elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, ConcatHead):
             args.append([ch[x] for x in f])
             if m is Segment:
                 args[2] = make_divisible(min(args[2], max_channels) * width, 8)
--
2.35.1.windows.2
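
To see what ConcatHead produces, here is a minimal sketch of the merged output layout once the patch is applied. The anchor count 6300 is just an example size, matching the comments in the patch; both heads always see the same anchors:

import torch
from ultralytics.nn.modules import ConcatHead  # available after applying the patch

# Decoded predictions from the two Detect heads: 4 box coords plus
# per-class scores for each anchor
preds1 = torch.rand(1, 4 + 80, 6300)  # original 80-class head
preds2 = torch.rand(1, 4 + 2, 6300)   # new 2-class head

head = ConcatHead(nc1=80, nc2=2)
out = head([preds1, preds2])
print(out.shape)  # torch.Size([1, 86, 12600]): 4 box + 82 classes, doubled anchors

The zero-padding means the first head's anchors can only score the original 80 classes and the second head's anchors only the 2 new ones, so a single NMS pass over the merged tensor keeps the two label spaces separate.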
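
The trainer.py reorder, which fires on_train_epoch_start after model.train(), plausibly exists so that a callback can re-freeze parts of the model each epoch without train() flipping them back to training mode. A hypothetical sketch of training only the new head this way (the layer index 22 and the dataset yaml are assumptions, not part of the patch):

from ultralytics import YOLO

def freeze_original_head(trainer):
    """Hypothetical callback: keep the original 80-class Detect head frozen."""
    head = trainer.model.model[22]  # first Detect in yolov8n-2xhead.yaml
    for p in head.parameters():
        p.requires_grad = False
    head.eval()  # keep BatchNorm stats fixed too; effective because the patched
                 # trainer runs this callback after model.train()

model = YOLO("ultralytics/cfg/models/v8/yolov8n-2xhead.yaml").load("yolov8n.pt")  # transfer matching weights
model.add_callback("on_train_epoch_start", freeze_original_head)
model.train(data="extended-coco.yaml", epochs=100)  # hypothetical 82-class dataset yaml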