Hourglass PyTorch
"""
Hourglass network, the backbone part.
Implemented according to the CornetNet paper. The ArXiv version does not have the backbone description.
Reference:
- https://link.springer.com/article/10.1007/s11263-019-01204-1
- https://sci-hub.se/https://link.springer.com/article/10.1007/s11263-019-01204-1 (I won't use 40EUR much money *just* to read a paper. Where I live one can buy 10 books with that.)
- https://arxiv.org/abs/1808.01244 (not really related to this implementation, just in case someone is curious about CornetNet)
Public API:
- Hourglass104
- build_hourglass
"""
from typing import List
from torch import nn


def ConvBR(*args, relu=True, **kwargs):
    """Conv2d, BatchNorm, (maybe) ReLU"""
    conv = nn.Conv2d(*args, **kwargs)
    norm = nn.BatchNorm2d(conv.out_channels)
    if relu:
        relu = nn.ReLU(True)
    else:
        relu = nn.Identity()
    return nn.Sequential(conv, norm, relu)
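
# Example (added, not part of the original gist): ConvBR returns an
# nn.Sequential, so it can be used as a module directly, or unpacked with `*`
# to flatten its layers into an enclosing nn.Sequential, as ResidualBlock
# does below:
#
#   block = ConvBR(3, 16, 3, padding=1)  # Conv -> BN -> ReLU as one module
#   flat = nn.Sequential(*ConvBR(3, 16, 3, padding=1), nn.MaxPool2d(2))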


class ResidualBlock(nn.Module):
    """
    This is actually the Bottleneck Block from the ResNet paper.
    It can be replaced with other Residual Block variants.
    """

    def __init__(self, in_channels, out_channels, stride: int = 1, reduction: int = 4):
        super().__init__()
        hid_channels = in_channels // reduction
        self.conv = nn.Sequential(
            *ConvBR(in_channels, hid_channels, 1),
            *ConvBR(hid_channels, hid_channels, 3, stride, padding=1),
            *ConvBR(hid_channels, out_channels, 1),
        )
        if in_channels != out_channels or stride != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=stride)
        else:
            self.skip = nn.Identity()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def __repr__(self):
        """Trust me, this makes the network repr much more readable"""
        return f"ResidualBlock({self.in_channels}, {self.out_channels}, stride={self.stride})"

    def forward(self, x):
        return self.conv(x) + self.skip(x)
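
# Quick shape check (added example, not part of the original gist): the 1x1
# skip convolution matches both the channel count and the stride of the main
# path, so the addition in `forward` is always shape-compatible:
#
#   blk = ResidualBlock(256, 384, stride=2)
#   blk(torch.randn(1, 256, 64, 64)).shape  # -> (1, 384, 32, 32)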


class HGLCore(nn.Module):
    """The core of the hourglass, see `build_hourglass`"""

    def __init__(self, hidden_size: int, num_layers: int = 4):
        super().__init__()
        layers = [ResidualBlock(hidden_size, hidden_size) for _ in range(num_layers)]
        self.core = nn.Sequential(*layers)
        self.hidden_size = hidden_size

    def forward(self, x):
        return self.core(x)

    def get_hidden_size(self):
        return self.hidden_size


class HGLCoreAtn(nn.Module):
    """The core of the hourglass, see `build_hourglass`.

    Since the spatial size at the core is so small,
    it *might* be a good idea to use a transformer encoder.
    """

    def __init__(self, hidden_size: int, num_layers: int = 4):
        super().__init__()
        tfm = nn.TransformerEncoderLayer(hidden_size, 4, batch_first=True)
        self.core = nn.TransformerEncoder(tfm, num_layers)
        self.hidden_size = hidden_size

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(-2).transpose(-2, -1)
        x = self.core(x)
        x = x.transpose(-2, -1).reshape(B, C, H, W)
        return x

    def get_hidden_size(self):
        return self.hidden_size
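
# Shape walkthrough for HGLCoreAtn.forward (added note, not part of the
# original gist): a (B, C, H, W) feature map is flattened into H*W tokens of
# size C, run through the encoder with batch_first=True, then reshaped back,
# e.g. (1, 512, 4, 4) -> (1, 16, 512) -> encoder -> (1, 16, 512) -> (1, 512, 4, 4).
# Note that `hidden_size` must be divisible by the 4 attention heads.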


class HGLLevel(nn.Module):
    """A wrapper that adds one level around the core (or the previous level) of the hourglass.

    See `build_hourglass` for details.
    """

    def __init__(self, core: nn.Module, in_channels: int):
        super().__init__()
        self.in_channels = in_channels
        hidden_size = core.get_hidden_size()
        self.pre_core = nn.Sequential(
            ResidualBlock(in_channels, hidden_size, stride=2),
            ResidualBlock(hidden_size, hidden_size),
        )
        self.core = core
        self.post_core = nn.Sequential(
            ResidualBlock(hidden_size, hidden_size),
            ResidualBlock(hidden_size, in_channels),
            nn.Upsample(scale_factor=2, mode="bilinear"),
        )
        self.skip = nn.Sequential(
            ResidualBlock(in_channels, in_channels),
            ResidualBlock(in_channels, in_channels),
        )

    def get_hidden_size(self):
        return self.in_channels

    def forward(self, x):
        skip = self.skip(x)
        x = self.pre_core(x)
        x = self.core(x)
        x = self.post_core(x)
        x = x + skip
        return x


def build_hourglass(
    hidden_sizes: List[int] = [256, 256, 256, 384, 384, 512],
    attention_core: bool = False,
):
    """Build an hourglass module.

    The idea is to wrap the levels of the hourglass one by one.
    The innermost level is referred to as the `core`.
    We wrap the core with some projection layers and a skip connection:
    ```
    wrapper(hourglass-core, x):
        skip <- skip-connection(x)
        x <- input-projection(x)
        x <- hourglass-core(x)
        x <- output-projection(x)
        x <- x + skip
        return x
    ```
    The first core is `HGLCore`, and the wrapper layer is `HGLLevel`.
    After wrapping a core, the result becomes a core itself.
    This function basically does the following:
    ```
    hourglass = HGLLevel(...HGLLevel(HGLCore(...))...)
    ```

    Args:
        hidden_sizes (List[int]):
            The channels for each level of the hourglass.
            The last entry in `hidden_sizes` will be the hidden_size
            in the middle of the hourglass.
        attention_core (bool):
            If true, `HGLCoreAtn` will be used instead of `HGLCore`.
            Default: false.
    """
    # The real hourglass is the nested network we build along the way
    hourglass = None

    # Which core to use
    if attention_core:
        Core = HGLCoreAtn
    else:
        Core = HGLCore

    # Build from the innermost level outwards
    for i, hidden in enumerate(reversed(hidden_sizes)):
        if i == 0:
            hourglass = Core(hidden)
        else:
            hourglass = HGLLevel(hourglass, hidden)
    return hourglass
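
# Usage sketch (added example, not part of the original gist): an hourglass
# preserves the channel count and spatial size of its input, provided the
# spatial size is divisible by 2 ** (len(hidden_sizes) - 1):
#
#   hg = build_hourglass([64, 96, 128])
#   hg(torch.randn(1, 64, 32, 32)).shape  # -> (1, 64, 32, 32)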


class Hourglass104(nn.Module):
    """Hourglass104 network.

    This module only returns the final and the middle feature maps,
    in that order, and not the predictions.

    Side note:
        Since this implementation uses Bottleneck blocks, it is actually
        Hourglass-138. As for which block type to use, I think that, in the
        end, it doesn't even matter: one can swap out the basic building block
        for an equivalent one and still have the same overall architecture.
        I named this thing 104 because that's how it is in the paper. Maybe
        the generalized version (in which blocks can be swapped) should be
        called something else.

    Args:
        hidden_sizes (List[int]):
            Input to `build_hourglass`. The first hidden size
            will be the number of channels of the output feature maps.
            Default: [256, 384, 384, 384, 512, 512].
            By changing the default value, you can change the
            computational requirements of the model and maybe save some FLOPS.
        attention_core (bool):
            Whether to use `HGLCoreAtn`. Default: `False`.
    """

    def __init__(
        self,
        hidden_sizes: List[int] = [256, 384, 384, 384, 512, 512],
        attention_core: bool = False,
    ):
        super().__init__()
        h0 = hidden_sizes[0]
        self.stem = nn.Sequential(
            *ConvBR(3, h0 // 2, 7, 2, padding=3),
            *ConvBR(h0 // 2, h0, 3, 2, padding=1),
        )

        # First hourglass
        self.hourglass_1 = nn.Sequential(
            build_hourglass(hidden_sizes, attention_core),
            ConvBR(h0, h0, 3, padding=1),
        )
        self.skip_1 = ConvBR(h0, h0, 1, relu=False)
        self.project_1 = ConvBR(h0, h0, 1, relu=False)

        # Second hourglass
        self.hourglass_2 = nn.Sequential(
            nn.ReLU(True),
            ConvBR(h0, h0, 3, padding=1),
            build_hourglass(hidden_sizes, attention_core),
            ConvBR(h0, h0, 3, padding=1),
        )

    def forward(self, x):
        # Prepare
        x = self.stem(x)
        outputs = []

        # First hourglass forward
        skip = self.skip_1(x)
        x = self.hourglass_1(x)
        outputs.insert(0, x)
        x = self.project_1(x) + skip

        # Second hourglass forward
        x = self.hourglass_2(x)
        outputs.insert(0, x)

        # Return two feature maps: [final, middle]
        return outputs
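

# Minimal smoke test (added, not part of the original gist). The input height
# and width are assumed to be divisible by 4 * 2 ** (len(hidden_sizes) - 1)
# (stem stride 4, then one downsampling per hourglass level) so the skip
# additions inside the hourglass stay shape-compatible.
if __name__ == "__main__":
    import torch

    model = Hourglass104()
    dummy = torch.randn(1, 3, 256, 256)
    final_map, middle_map = model(dummy)
    # Both feature maps have hidden_sizes[0] channels at 1/4 input resolution.
    print(final_map.shape, middle_map.shape)  # torch.Size([1, 256, 64, 64]) each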