-
-
Save tokudayo/c59277ae0d9637d2d325da2524476947 to your computer and use it in GitHub Desktop.
ConvNeXt architecture
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, List | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
class LayerNorm(nn.LayerNorm):
    """Layer normalization over the channel axis of a channels-first 4-D tensor.

    ``nn.LayerNorm`` normalizes over the trailing dimension(s), so the input is
    permuted with ``(0, 2, 3, 1)`` before normalization and permuted back with
    ``(0, 3, 1, 2)`` afterwards.
    """

    def __init__(self, num_features: int, eps: float = 1e-6, **kwargs: Any) -> None:
        """Create the norm over ``num_features`` channels (default ``eps`` is 1e-6)."""
        super().__init__(num_features, eps=eps, **kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` over dimension 1 and return it in the original layout."""
        # Move dimension 1 to the end so the parent class normalizes over it.
        channels_last = x.permute(0, 2, 3, 1)
        normalized = super().forward(channels_last)
        # Restore the layout the caller passed in.
        return normalized.permute(0, 3, 1, 2)
class StochasticDepth(nn.Module):
    """Randomly drop a wrapped module during training (stochastic depth).

    In training mode the wrapped module is skipped for the whole batch with
    probability ``1 - survival_rate``. When it is kept, its output is divided
    by ``survival_rate`` so that the expected training output matches the
    evaluation output, where the module always runs unscaled.

    Args:
        module: The module to run or drop.
        survival_rate: Probability in ``[0, 1]`` that the module is kept during
            a training-mode forward pass.

    Raises:
        ValueError: If ``survival_rate`` lies outside ``[0, 1]``.
    """

    def __init__(self, module: nn.Module, survival_rate: float = 1.) -> None:
        super().__init__()
        if not 0. <= survival_rate <= 1.:
            raise ValueError(
                f"survival_rate must be in [0, 1], got {survival_rate}"
            )
        self.module = module
        self.survival_rate = survival_rate
        # Distribution over the *drop* event: a sample of 1 means "skip the module".
        self._drop = torch.distributions.Bernoulli(torch.tensor(1 - survival_rate))

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped module, or return a zero tensor if it is dropped."""
        if not self.training:
            return self.module(x)
        if self._drop.sample():
            # A 0-dim zero tensor (same dtype/device as ``x``) broadcasts in a
            # residual sum exactly like the literal ``0`` would, while keeping
            # the declared ``Tensor`` return type.
            return x.new_zeros(())
        out = self.module(x)
        if self.survival_rate < 1.:
            # Rescale kept outputs so E[train output] equals the eval output.
            # survival_rate is > 0 here, otherwise the module is always dropped.
            out = out / self.survival_rate
        return out

    def __repr__(self) -> str:
        return self.module.__repr__() + f", stodepth_survival_rate={self.survival_rate:.2f}"
def dwconv7x7(planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a 7x7 depthwise convolution with ``planes`` input and output channels.

    ``groups`` equals ``planes`` so each channel is convolved on its own, and a
    padding of 3 is applied on each side.
    """
    kernel = 7
    return nn.Conv2d(
        in_channels=planes,
        out_channels=planes,
        kernel_size=kernel,
        stride=stride,
        padding=kernel // 2,
        groups=planes,
    )
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a 1x1 (pointwise) convolution mapping ``in_planes`` to ``out_planes`` channels."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
    )
# Normalization layer used throughout the network.
norm = LayerNorm


def gelu() -> nn.GELU:
    """Return a new GELU activation module."""
    return nn.GELU()
class Block(nn.Module):
    """Residual block: 7x7 depthwise conv, norm, 1x1 expand, GELU, 1x1 project.

    The two pointwise convolutions widen the channels by ``expansion`` and then
    bring them back to ``width``. When ``stodepth_survive`` is below 1 the
    residual branch is wrapped in ``StochasticDepth``.
    """

    expansion: int = 4

    def __init__(self, width: int, stodepth_survive: float = 1.) -> None:
        super().__init__()
        hidden = self.expansion * width
        branch = nn.Sequential(
            dwconv7x7(width),
            norm(width),
            conv1x1(width, hidden),
            gelu(),
            conv1x1(hidden, width),
        )
        # Only wrap the branch when it can actually be dropped.
        if stodepth_survive < 1.:
            branch = StochasticDepth(branch, stodepth_survive)
        self.main_path = branch

    def forward(self, x: Tensor) -> Tensor:
        """Return the input plus the output of the residual branch."""
        residual = self.main_path(x)
        return x + residual
class ConvNext(nn.Module):
    """ConvNeXt-style image classifier with a stem, four stages and a linear head.

    Stage ``i`` (0-based) has width ``base_width * 2**i`` and ``layers[i]``
    residual blocks. Between consecutive stages a norm followed by a 2x2,
    stride-2 convolution halves the spatial size and doubles the width.

    Args:
        base_width: Channel width of the first stage.
        layers: Number of blocks in each of the four stages.
        num_classes: Size of the classification output.
        stodepth_survive: Survival rate handed to every block.

    Raises:
        ValueError: If ``layers`` does not contain exactly four entries.
    """

    # ``forward`` hard-codes stage1..stage4, so the stage count is fixed.
    _NUM_STAGES = 4

    def __init__(self, base_width: int, layers: List[int], num_classes: int = 1000, stodepth_survive: float = 1.) -> None:
        super().__init__()
        layers = list(layers)
        if len(layers) != self._NUM_STAGES:
            # Fail here instead of with an AttributeError on a missing stage
            # the first time ``forward`` runs.
            raise ValueError(
                f"layers must contain exactly {self._NUM_STAGES} stage depths, "
                f"got {len(layers)}"
            )

        widths = [base_width * (2 ** i) for i in range(self._NUM_STAGES)]
        self.inplanes = widths[0]
        self.stodepth = stodepth_survive

        # Downsampling stem downsamples input size by 4, e.g. 224 -> 56
        self.stem = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=4, stride=4),
            norm(self.inplanes),
        )

        # Stage 1 -> 4 and intermediate downsampling layers
        last = self._NUM_STAGES - 1
        for idx, (depth, width) in enumerate(zip(layers, widths)):
            self.add_module(
                f"stage{idx + 1}",
                nn.Sequential(*[Block(width, stodepth_survive) for _ in range(depth)]),
            )
            if idx == last:
                # No downsampling after the final stage.
                break
            self.add_module(
                f"ds{idx + 1}",
                nn.Sequential(
                    norm(width),
                    nn.Conv2d(width, widths[idx + 1], kernel_size=2, stride=2),
                ),
            )

        # Classification head
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            norm(widths[-1]),
            nn.Flatten(),
            nn.Linear(widths[-1], num_classes),
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m) -> None:
        """Initialize one submodule; called for every submodule via ``apply``."""
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            # Affine parameters are None when elementwise_affine/bias is disabled.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            init_range = 1.0 / (m.out_features ** 0.5)
            nn.init.uniform_(m.weight, -init_range, init_range)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Run stem, the four stages with downsampling in between, then the head."""
        x = self.stem(x)
        x = self.stage1(x)
        x = self.ds1(x)
        x = self.stage2(x)
        x = self.ds2(x)
        x = self.stage3(x)
        x = self.ds3(x)
        x = self.stage4(x)
        x = self.head(x)
        return x
def _convnext(base_width: int, layers: List[int], **kwargs: Any) -> ConvNext:
    """Construct a ``ConvNext`` with the given base width and per-stage block counts.

    Extra keyword arguments are forwarded to the ``ConvNext`` constructor.
    """
    return ConvNext(base_width, layers, **kwargs)
def convnext_t(**kwargs: Any) -> ConvNext:
    """Build the "t" variant: base width 96, stage depths (3, 3, 9, 3)."""
    depths = [3, 3, 9, 3]
    return _convnext(96, depths, **kwargs)
def convnext_s(**kwargs: Any) -> ConvNext:
    """Build the "s" variant: base width 96, stage depths (3, 3, 27, 3)."""
    depths = [3, 3, 27, 3]
    return _convnext(96, depths, **kwargs)
def convnext_b(**kwargs: Any) -> ConvNext:
    """Build the "b" variant: base width 128, stage depths (3, 3, 27, 3)."""
    depths = [3, 3, 27, 3]
    return _convnext(128, depths, **kwargs)
def convnext_l(**kwargs: Any) -> ConvNext:
    """Build the "l" variant: base width 192, stage depths (3, 3, 27, 3)."""
    depths = [3, 3, 27, 3]
    return _convnext(192, depths, **kwargs)
def convnext_xl(**kwargs: Any) -> ConvNext:
    """Build the "xl" variant: base width 256, stage depths (3, 3, 27, 3)."""
    depths = [3, 3, 27, 3]
    return _convnext(256, depths, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment