@Y-T-G
Created May 27, 2024 11:47
YOLOv8 port of CBAM and Involution modules by @aash1999
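The patch below is a standard git format-patch and can be applied to an ultralytics checkout with `git apply` or `git am` from the repository root (e.g. `git am 0001-port-cbam-involution.patch`; the filename is arbitrary). Note that it replaces the stock `ChannelAttention` and `CBAM` implementations, changing the `CBAM` signature from `(c1, kernel_size)` to `(c1, c2, kernel_size, shortcut, g, e, ratio)`, so any existing config that instantiates `CBAM` with the old signature will need its args updated.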
From 1bc48504b6ecebcdc19242f9e99adf5079e7d568 Mon Sep 17 00:00:00 2001
From: Y-T-G <>
Date: Mon, 27 May 2024 19:44:28 +0800
Subject: [PATCH] Port CBAM and Involution by @aash1999
---
.../models/v8/yolov8-cbam-involution-p2.yaml | 57 ++++++++
.../cfg/models/v8/yolov8-cbam-involution.yaml | 49 +++++++
ultralytics/nn/modules/__init__.py | 2 +
ultralytics/nn/modules/conv.py | 123 +++++++++++++++---
ultralytics/nn/tasks.py | 4 +
5 files changed, 219 insertions(+), 16 deletions(-)
create mode 100644 ultralytics/cfg/models/v8/yolov8-cbam-involution-p2.yaml
create mode 100644 ultralytics/cfg/models/v8/yolov8-cbam-involution.yaml
diff --git a/ultralytics/cfg/models/v8/yolov8-cbam-involution-p2.yaml b/ultralytics/cfg/models/v8/yolov8-cbam-involution-p2.yaml
new file mode 100644
index 0000000..5acad99
--- /dev/null
+++ b/ultralytics/cfg/models/v8/yolov8-cbam-involution-p2.yaml
@@ -0,0 +1,57 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv8 object detection model with P2-P5 outputs. For usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+ # [depth, width, max_channels]
+ n: [0.33, 0.25, 1024]
+ s: [0.33, 0.50, 1024]
+ m: [0.67, 0.75, 768]
+ l: [1.00, 1.00, 512]
+ x: [1.00, 1.25, 512]
+
+# YOLOv8.0 backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+ - [-1, 3, CBAM, [1024, 3]]
+ - [-1, 1, SPPF, [1024, 5]] # 10
+
+# YOLOv8.0-p2 head
+head:
+ - [-1, 1, Involution, [1024, 1, 1]]
+ - [-1, 1, Conv, [1024, 1, 1]]
+ - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+ - [-1, 3, C2f, [512]] # 15
+
+ - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+ - [-1, 3, C2f, [256]] # 18 (P3/8-small)
+
+ - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+ - [[-1, 2], 1, Concat, [1]] # cat backbone P2
+ - [-1, 3, C2f, [128]] # 21 (P2/4-xsmall)
+
+ - [-1, 1, Conv, [128, 3, 2]]
+ - [[-1, 18], 1, Concat, [1]] # cat head P3
+ - [-1, 3, C2f, [256]] # 24 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 15], 1, Concat, [1]] # cat head P4
+ - [-1, 3, C2f, [512]] # 27 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+ - [-1, 3, C2f, [1024]] # 30 (P5/32-large)
+
+ - [[21, 24, 27, 30], 1, Detect, [nc]] # Detect(P2, P3, P4, P5)
diff --git a/ultralytics/cfg/models/v8/yolov8-cbam-involution.yaml b/ultralytics/cfg/models/v8/yolov8-cbam-involution.yaml
new file mode 100644
index 0000000..6b6f053
--- /dev/null
+++ b/ultralytics/cfg/models/v8/yolov8-cbam-involution.yaml
@@ -0,0 +1,49 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv8 object detection model with P3-P5 outputs. For usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+ # [depth, width, max_channels]
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
+
+# YOLOv8.0n backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+ - [-1, 3, CBAM, [1024, 3]]
+ - [-1, 1, SPPF, [1024, 5]] # 10
+
+# YOLOv8.0n head
+head:
+ - [-1, 1, Involution, [1024, 1, 1]]
+ - [-1, 1, Conv, [1024, 1, 1]]
+ - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+ - [-1, 3, C2f, [512]] # 15
+
+ - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+ - [-1, 3, C2f, [256]] # 18 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 15], 1, Concat, [1]] # cat head P4
+ - [-1, 3, C2f, [512]] # 21 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+ - [-1, 3, C2f, [1024]] # 24 (P5/32-large)
+
+ - [[18, 21, 24], 1, Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/nn/modules/__init__.py b/ultralytics/nn/modules/__init__.py
index 5104417..5309dfe 100644
--- a/ultralytics/nn/modules/__init__.py
+++ b/ultralytics/nn/modules/__init__.py
@@ -61,6 +61,7 @@ from .conv import (
LightConv,
RepConv,
SpatialAttention,
+ Involution,
)
from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect
from .transformer import (
@@ -135,4 +136,5 @@ __all__ = (
"CBFuse",
"CBLinear",
"Silence",
+ "Involution",
)
diff --git a/ultralytics/nn/modules/conv.py b/ultralytics/nn/modules/conv.py
index 6b51813..45695a5 100644
--- a/ultralytics/nn/modules/conv.py
+++ b/ultralytics/nn/modules/conv.py
@@ -7,6 +7,8 @@ import numpy as np
import torch
import torch.nn as nn
+import warnings
+
__all__ = (
"Conv",
"Conv2",
@@ -21,6 +23,7 @@ __all__ = (
"CBAM",
"Concat",
"RepConv",
+ "Involution",
)
@@ -275,19 +278,38 @@ class RepConv(nn.Module):
self.__delattr__("id_tensor")
+# contributed by @aash1999
class ChannelAttention(nn.Module):
- """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
- def __init__(self, channels: int) -> None:
- """Initializes the class and sets the basic configurations and instance variables required."""
+ def __init__(self, in_planes, ratio=16):
+ """
+ Initialize the Channel Attention module.
+ Args:
+ in_planes (int): Number of input channels.
+ ratio (int): Reduction ratio for the hidden channels in the channel attention block.
+ """
super().__init__()
- self.pool = nn.AdaptiveAvgPool2d(1)
- self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
- self.act = nn.Sigmoid()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
+ self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
+ self.relu = nn.ReLU()
+ self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
+ self.sigmoid = nn.Sigmoid()
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""
- return x * self.act(self.fc(self.pool(x)))
+ def forward(self, x):
+ """
+ Forward pass of the Channel Attention module.
+ Args:
+ x (torch.Tensor): Input tensor.
+ Returns:
+ out (torch.Tensor): Output tensor after applying channel attention.
+ """
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
+ max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
+ out = self.sigmoid(avg_out + max_out)
+ return out
class SpatialAttention(nn.Module):
@@ -306,18 +328,43 @@ class SpatialAttention(nn.Module):
return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
+# contributed by @aash1999
class CBAM(nn.Module):
- """Convolutional Block Attention Module."""
-
- def __init__(self, c1, kernel_size=7):
- """Initialize CBAM with given input channel (c1) and kernel size."""
+ # ch_in, ch_out, kernel_size, shortcut, groups, expansion, ratio
+ def __init__(self, c1, c2, kernel_size=3, shortcut=True, g=1, e=0.5, ratio=16):
+ """
+ Initialize the CBAM (Convolutional Block Attention Module) bottleneck.
+ Args:
+ c1 (int): Number of input channels.
+ c2 (int): Number of output channels.
+ kernel_size (int): Size of the convolutional kernel.
+ shortcut (bool): Whether to use a shortcut connection.
+ g (int): Number of groups for grouped convolutions.
+ e (float): Expansion factor for hidden channels.
+ ratio (int): Reduction ratio for the hidden channels in the channel attention block.
+ """
super().__init__()
- self.channel_attention = ChannelAttention(c1)
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_, c2, 3, 1, g=g)
+ self.add = shortcut and c1 == c2
+ self.channel_attention = ChannelAttention(c2, ratio)
self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x):
- """Applies the forward pass through C1 module."""
- return self.spatial_attention(self.channel_attention(x))
+ """
+ Forward pass of the CBAM bottleneck.
+ Args:
+ x (torch.Tensor): Input tensor.
+ Returns:
+ out (torch.Tensor): Output tensor after applying the CBAM bottleneck.
+ """
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ x2 = self.cv2(self.cv1(x))
+ out = self.channel_attention(x2) * x2
+ out = self.spatial_attention(out)  # SpatialAttention already multiplies its input by the attention map
+ return x + out if self.add else out
class Concat(nn.Module):
@@ -331,3 +378,47 @@ class Concat(nn.Module):
def forward(self, x):
"""Forward pass for the YOLOv8 mask Proto module."""
return torch.cat(x, self.d)
+
+# contributed by @aash1999
+class Involution(nn.Module):
+
+ def __init__(self, c1, c2, kernel_size, stride):
+ """
+ Initialize the Involution module.
+ Args:
+ c1 (int): Number of input channels.
+ c2 (int): Number of output channels.
+ kernel_size (int): Size of the involution kernel.
+ stride (int): Stride for the involution operation.
+ """
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.c1 = c1
+ reduction_ratio = 1
+ self.group_channels = 16
+ self.groups = self.c1 // self.group_channels
+ self.conv1 = Conv(c1, c1 // reduction_ratio, 1)
+ self.conv2 = Conv(c1 // reduction_ratio, kernel_size ** 2 * self.groups, 1, 1)
+
+ if stride > 1:
+ self.avgpool = nn.AvgPool2d(stride, stride)
+ self.unfold = nn.Unfold(kernel_size, 1, (kernel_size - 1) // 2, stride)
+
+ def forward(self, x):
+ """
+ Forward pass of the Involution module.
+ Args:
+ x (torch.Tensor): Input tensor.
+ Returns:
+ out (torch.Tensor): Output tensor after applying the involution operation.
+ """
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ weight = self.conv2(self.conv1(x if self.stride == 1 else self.avgpool(x)))  # pool first when stride > 1
+ b, c, h, w = weight.shape
+ weight = weight.view(b, self.groups, self.kernel_size ** 2, h, w).unsqueeze(2)
+ out = self.unfold(x).view(b, self.groups, self.group_channels, self.kernel_size ** 2, h, w)
+ out = (weight * out).sum(dim=3).view(b, self.c1, h, w)
+
+ return out
\ No newline at end of file
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index cf3816f..c0a218b 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -49,6 +49,8 @@ from ultralytics.nn.modules import (
Segment,
Silence,
WorldDetect,
+ CBAM,
+ Involution,
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
@@ -885,6 +887,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
DWConvTranspose2d,
C3x,
RepC3,
+ CBAM,
+ Involution,
}:
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
--
2.34.1
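
Once the patch is applied, the new configs load like any other Ultralytics model YAML: the scale letter in the filename (n/s/m/l/x) selects the row under `scales:`. A minimal usage sketch, assuming the patched package is on the path (`coco8.yaml` is the stock Ultralytics demo dataset):

from ultralytics import YOLO

# Build the P3-P5 variant at the "n" scale; swap in
# "yolov8n-cbam-involution-p2.yaml" for the P2-P5 variant.
model = YOLO("yolov8n-cbam-involution.yaml")
model.info()  # layer summary should list the CBAM and Involution modules

# Training then works as usual, e.g.:
# model.train(data="coco8.yaml", epochs=100, imgsz=640)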
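
The modules can also be smoke-tested in isolation to confirm the wiring; a minimal sketch, again assuming the patched ultralytics package is importable. Note the Involution port hard-codes a group size of 16, so its input channel count must be divisible by 16:

import torch
from ultralytics.nn.modules import CBAM, Involution

x = torch.randn(1, 64, 32, 32)

# CBAM bottleneck: with c1 == c2 and the default shortcut=True the block is residual.
cbam = CBAM(64, 64, kernel_size=3)
print(cbam(x).shape)  # torch.Size([1, 64, 32, 32])

# Involution: stride=1 preserves spatial resolution; 64 channels / 16 per group = 4 groups.
inv = Involution(64, 64, kernel_size=3, stride=1)
print(inv(x).shape)  # torch.Size([1, 64, 32, 32])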