@Y-T-G
Last active June 18, 2024 03:24
Patches the Ultralytics YOLO library to support an extra detection head for additional classes.
From e41525e084a6f77980665da715ec93070c397479 Mon Sep 17 00:00:00 2001
From: Y-T-G <>
Date: Sat, 16 Mar 2024 14:06:25 +0800
Subject: [PATCH] Add ConcatHead
---
 ultralytics/cfg/models/v8/yolov8n-2xhead.yaml | 55 ++++++++++++++++
 ultralytics/engine/trainer.py                 |  2 +-
 ultralytics/nn/modules/__init__.py            |  2 +
 ultralytics/nn/modules/conv.py                | 64 +++++++++++++++++++
 ultralytics/nn/tasks.py                       | 36 ++++++-----
 5 files changed, 142 insertions(+), 17 deletions(-)
create mode 100644 ultralytics/cfg/models/v8/yolov8n-2xhead.yaml
diff --git a/ultralytics/cfg/models/v8/yolov8n-2xhead.yaml b/ultralytics/cfg/models/v8/yolov8n-2xhead.yaml
new file mode 100644
index 0000000..a4a0b50
--- /dev/null
+++ b/ultralytics/cfg/models/v8/yolov8n-2xhead.yaml
@@ -0,0 +1,55 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv8 object detection model with P3-P5 outputs. For usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 82 # number of classes (80 original + 2 new)
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
+  s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
+  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
+  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
+  x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2f, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 12
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
+
+  - [-1, 1, Conv, [512, 3, 2]]
+  - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
+
+  # Layer 22. The first list gives the layers whose outputs feed this head.
+  # This head predicts the original 80 classes.
+  - [[15, 18, 21], 1, Detect, [80]] # Detect(P3, P4, P5)
+
+  # Layer 23. This head predicts 2 classes, because that's how many
+  # new classes we added.
+  - [[15, 18, 21], 1, Detect, [2]] # Detect(P3, P4, P5) #23 new classes
+
+  # ConcatHead takes in layers 22 and 23 and concatenates their outputs.
+  - [[22, 23], 1, ConcatHead, [80, 2]] # Concat #22 and #23
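
With the patch applied, training follows the normal Ultralytics workflow. A minimal sketch, assuming a dataset YAML with nc: 82 (the name data-82.yaml is hypothetical; class ids 0-79 are the original classes, 80-81 the new ones):

from ultralytics import YOLO

model = YOLO("yolov8n-2xhead.yaml")  # parse_model builds both Detect heads plus ConcatHead
model.load("yolov8n.pt")  # transfer the weights that match (backbone and original head)
model.train(data="data-82.yaml", epochs=100, imgsz=640)
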
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index f005f34..4a025a2 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -340,8 +340,8 @@ class BaseTrainer:
         epoch = self.start_epoch
         while True:
             self.epoch = epoch
-            self.run_callbacks("on_train_epoch_start")
             self.model.train()
+            self.run_callbacks("on_train_epoch_start")
             if RANK != -1:
                 self.train_loader.sampler.set_epoch(epoch)
             pbar = enumerate(self.train_loader)
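
This reorder matters for callbacks that toggle training mode: model.train() used to run after the on_train_epoch_start callbacks, flipping any module a callback had put into eval mode back to train mode. With the callback fired last, a hook can keep the original head frozen across epochs. A hedged sketch (the layer index 22 follows the YAML above; the freezing strategy is this example's choice, not part of the patch):

from ultralytics import YOLO

def freeze_original_head(trainer):
    """Freeze the 80-class Detect head so only the new 2-class head learns."""
    head = trainer.model.model[22]  # original Detect head (index assumed from the YAML)
    for p in head.parameters():
        p.requires_grad = False
    head.eval()  # keeps BatchNorm stats fixed; no longer undone by model.train()

model = YOLO("yolov8n-2xhead.yaml").load("yolov8n.pt")
model.add_callback("on_train_epoch_start", freeze_original_head)
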
diff --git a/ultralytics/nn/modules/__init__.py b/ultralytics/nn/modules/__init__.py
index d785c00..36208cc 100644
--- a/ultralytics/nn/modules/__init__.py
+++ b/ultralytics/nn/modules/__init__.py
@@ -51,6 +51,7 @@ from .conv import (
     CBAM,
     ChannelAttention,
     Concat,
+    ConcatHead,
     Conv,
     Conv2,
     ConvTranspose,
@@ -90,6 +91,7 @@ __all__ = (
"SpatialAttention",
"CBAM",
"Concat",
+ "ConcatHead",
"TransformerLayer",
"TransformerBlock",
"MLPBlock",
diff --git a/ultralytics/nn/modules/conv.py b/ultralytics/nn/modules/conv.py
index 399c422..24dc844 100644
--- a/ultralytics/nn/modules/conv.py
+++ b/ultralytics/nn/modules/conv.py
@@ -20,6 +20,7 @@ __all__ = (
"SpatialAttention",
"CBAM",
"Concat",
+ "ConcatHead",
"RepConv",
)
@@ -331,3 +332,66 @@ class Concat(nn.Module):
     def forward(self, x):
         """Forward pass for the YOLOv8 mask Proto module."""
         return torch.cat(x, self.d)
+
+
+class ConcatHead(nn.Module):
+    """Concatenation layer for Detect heads."""
+
+    def __init__(self, nc1=80, nc2=1, ch=()):
+        """Initializes the ConcatHead."""
+        super().__init__()
+        self.nc1 = nc1  # number of classes of head 1
+        self.nc2 = nc2  # number of classes of head 2
+
+    def forward(self, x):
+        """Concatenates and returns predicted bounding boxes and class probabilities."""
+
+        # x is a list of length 2.
+        # Each element is either a tuple or just the decoded features,
+        # depending on whether the model is being exported.
+        # The first element of the tuple is the decoded preds,
+        # the second is the feature maps for heatmap visualization.
+
+        if isinstance(x[0], tuple):
+            preds1 = x[0][0]
+            preds2 = x[1][0]
+        elif isinstance(x[0], list):  # when raw outputs are returned
+            # The shape is used for stride creation in tasks.py.
+            # Feature maps would have to be decoded individually if used, as they can't be merged.
+            return [torch.cat((x0, x1), dim=1) for x0, x1 in zip(x[0], x[1])]
+        else:
+            preds1 = x[0]
+            preds2 = x[1]
+
+        # Concatenate the new head outputs as extra outputs
+
+        # 1. Concatenate bbox outputs
+        # Shape changes from [N, 4, 6300] to [N, 4, 12600]
+        preds = torch.cat((preds1[:, :4, :], preds2[:, :4, :]), dim=2)
+
+        # 2. Concatenate class outputs
+        # Pad preds1 on the right with zeros for head 2's 6300 anchor slots
+        shape = list(preds1.shape)
+        shape[-1] *= 2
+
+        preds1_extended = torch.zeros(shape, device=preds1.device,
+                                      dtype=preds1.dtype)
+        preds1_extended[..., : preds1.shape[-1]] = preds1
+
+        # Pad preds2 on the left with zeros for head 1's 6300 anchor slots
+        shape = list(preds2.shape)
+        shape[-1] *= 2
+
+        preds2_extended = torch.zeros(shape, device=preds2.device,
+                                      dtype=preds2.dtype)
+        preds2_extended[..., preds2.shape[-1] :] = preds2
+
+        # Arrange the class probabilities in the order preds1, preds2, so the
+        # class indices of preds2 start after those of preds1
+        preds = torch.cat((preds, preds1_extended[:, 4:, :]), dim=1)
+        preds = torch.cat((preds, preds2_extended[:, 4:, :]), dim=1)
+
+        if isinstance(x[0], tuple):
+            return (preds, x[0][1])
+        else:
+            return preds
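
The zero-padding above is easier to see with concrete tensors. A standalone sketch (not part of the patch; the 6300-anchor count is just the example used in the comments) that mirrors the forward pass: each head keeps its own anchor columns, the other head's class rows are zero there, so class ids 0-79 stay with head 1 and 80-81 with head 2:

import torch

N, A = 1, 6300  # batch size, anchors per head (example value)
preds1 = torch.randn(N, 4 + 80, A)  # head 1: 4 box coords + 80 class scores
preds2 = torch.randn(N, 4 + 2, A)  # head 2: 4 box coords + 2 class scores

boxes = torch.cat((preds1[:, :4], preds2[:, :4]), dim=2)  # [N, 4, 12600]

cls1 = torch.zeros(N, 80, 2 * A)  # head 1 scores, zeroed over head 2's anchors
cls1[..., :A] = preds1[:, 4:]
cls2 = torch.zeros(N, 2, 2 * A)  # head 2 scores, zeroed over head 1's anchors
cls2[..., A:] = preds2[:, 4:]

out = torch.cat((boxes, cls1, cls2), dim=1)
print(out.shape)  # torch.Size([1, 86, 12600]): 4 box + 82 classes per anchor
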
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 64ee7f5..c8a3e14 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -25,6 +25,7 @@ from ultralytics.nn.modules import (
     C3x,
     Classify,
     Concat,
+    ConcatHead,
     Conv,
     Conv2,
     ConvTranspose,
@@ -230,11 +231,12 @@ class BaseModel(nn.Module):
             (BaseModel): An updated BaseModel object.
         """
         self = super()._apply(fn)
-        m = self.model[-1]  # Detect()
-        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
-            m.stride = fn(m.stride)
-            m.anchors = fn(m.anchors)
-            m.strides = fn(m.strides)
+        # Get all detect modules
+        for m in self.model:
+            if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
+                m.stride = fn(m.stride)
+                m.anchors = fn(m.anchors)
+                m.strides = fn(m.strides)
         return self
 
     def load(self, weights, verbose=True):
@@ -289,16 +291,18 @@ class DetectionModel(BaseModel):
         self.inplace = self.yaml.get("inplace", True)
 
         # Build strides
-        m = self.model[-1]  # Detect()
-        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
-            s = 256  # 2x min stride
-            m.inplace = self.inplace
-            forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
-            m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
-            self.stride = m.stride
-            m.bias_init()  # only run once
-        else:
-            self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR
+        # Get all Detect modules
+        last_m = len(self.model) - 1
+        for i, m in enumerate(self.model):
+            if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
+                s = 256  # 2x min stride
+                m.inplace = self.inplace
+                forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
+                m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
+                self.stride = m.stride
+                m.bias_init()  # only run once
+            elif i == last_m:
+                self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR
 
         # Init weights, biases
         initialize_weights(self)
@@ -895,7 +899,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
+        elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, ConcatHead):
             args.append([ch[x] for x in f])
             if m is Segment:
                 args[2] = make_divisible(min(args[2], max_channels) * width, 8)
--
2.35.1.windows.2
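
To use the patch, save it to a file and apply it to an Ultralytics source checkout with git apply or git am. After training, inference is unchanged; predictions simply span all 82 class ids. A minimal sketch, assuming a finished run left a checkpoint at runs/detect/train/weights/best.pt:

from ultralytics import YOLO

model = YOLO("runs/detect/train/weights/best.pt")  # assumed checkpoint path
results = model("image.jpg")  # placeholder image path
print(results[0].boxes.cls)  # ids 0-79 from the original head, 80-81 from the new one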