ConvNeXt architecture

from typing import Any, List

import torch
import torch.nn as nn
from torch import Tensor

class LayerNorm(nn.LayerNorm):
    """LayerNorm over the channel dim of an NCHW tensor (channels-last internally)."""
    def __init__(self, num_features: int, eps: float = 1e-6, **kwargs: Any) -> None:
        super().__init__(num_features, eps=eps, **kwargs)

    def forward(self, x: Tensor) -> Tensor:
        # Permute to NHWC, normalize over the channel dim, permute back to NCHW.
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
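
# Shape sanity check for the channels-last LayerNorm above (a minimal sketch;
# the 2x64x56x56 input is an arbitrary illustrative size):
#   >>> ln = LayerNorm(64)
#   >>> ln(torch.randn(2, 64, 56, 56)).shape
#   torch.Size([2, 64, 56, 56])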

class StochasticDepth(nn.Module):
    """Randomly drop the wrapped module's output during training (stochastic depth)."""
    def __init__(self, module: nn.Module, survival_rate: float = 1.) -> None:
        super().__init__()
        self.module = module
        self.survival_rate = survival_rate
        # Bernoulli over the *drop* probability, 1 - survival_rate.
        self._drop = torch.distributions.Bernoulli(torch.tensor(1 - survival_rate))

    def forward(self, x: Tensor) -> Tensor:
        # Training: with probability 1 - survival_rate, skip the module for the
        # whole batch and return 0 (broadcast by the residual add in Block).
        # Evaluation: always run the module.
        return 0 if self.training and self._drop.sample() else self.module(x)

    def __repr__(self) -> str:
        return self.module.__repr__() + f", stodepth_survival_rate={self.survival_rate:.2f}"
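
# Behavior sketch (illustrative; nn.Identity() stands in for a real block):
# in eval mode the wrapped module always runs, while in train mode the whole
# branch is dropped for the entire batch with probability 1 - survival_rate.
#   >>> sd = StochasticDepth(nn.Identity(), survival_rate=0.8)
#   >>> _ = sd.eval()
#   >>> torch.equal(sd(torch.ones(2, 3)), torch.ones(2, 3))
#   True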

def dwconv7x7(planes: int, stride: int = 1) -> nn.Conv2d:
    """7x7 depthwise convolution (groups == channels)."""
    return nn.Conv2d(planes, planes, kernel_size=7, stride=stride, padding=3, groups=planes)

def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 pointwise convolution."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride)

# Factories for the normalization and activation layers used throughout.
norm = LayerNorm
gelu = nn.GELU
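
# The depthwise conv is what keeps the 7x7 kernel cheap: for 96 channels it
# has 96 * 7 * 7 weights + 96 biases = 4,800 parameters, versus ~452k for a
# dense 7x7 conv of the same width (a quick check, sizes illustrative):
#   >>> sum(p.numel() for p in dwconv7x7(96).parameters())
#   4800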

class Block(nn.Module):
    """ConvNeXt block: depthwise 7x7 -> LayerNorm -> 1x1 expand -> GELU ->
    1x1 project, wrapped in a residual connection."""
    expansion: int = 4

    def __init__(self, width: int, stodepth_survive: float = 1.) -> None:
        super().__init__()
        # Inverted bottleneck: expand to 4x the width, then project back.
        expanded = width * self.expansion
        main_path = nn.Sequential(
            dwconv7x7(width),
            norm(width),
            conv1x1(width, expanded),
            gelu(),
            conv1x1(expanded, width),
        )
        self.main_path = StochasticDepth(main_path, stodepth_survive) if stodepth_survive < 1. else main_path

    def forward(self, x: Tensor) -> Tensor:
        # Residual add around the (possibly stochastically dropped) main path.
        return x + self.main_path(x)
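
# The block is shape-preserving, so blocks stack freely within a stage
# (a minimal check; the 1x96x56x56 input mirrors stage 1 of ConvNeXt-T):
#   >>> blk = Block(96)
#   >>> blk(torch.randn(1, 96, 56, 56)).shape
#   torch.Size([1, 96, 56, 56])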

class ConvNext(nn.Module):
    def __init__(self, base_width: int, layers: List[int], num_classes: int = 1000, stodepth_survive: float = 1.) -> None:
        super().__init__()
        # Stage widths double at each stage, e.g. 96 -> 192 -> 384 -> 768 for -T.
        widths = [base_width * (2**i) for i in range(4)]
        self.inplanes = widths[0]
        self.stodepth = stodepth_survive
        # "Patchify" stem: a 4x4, stride-4 conv downsamples the input by 4, e.g. 224 -> 56.
        self.stem = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=4, stride=4),
            norm(self.inplanes),
        )
        # Stages 1 -> 4, with a downsampling layer between consecutive stages.
        for idx, (layer, width) in enumerate(zip(layers, widths)):
            self.add_module(
                f"stage{idx + 1}",
                nn.Sequential(*[Block(width, stodepth_survive) for _ in range(layer)])
            )
            if idx == 3:
                break
            # Spatial downsampling between stages: LayerNorm, then a 2x2, stride-2 conv.
            self.add_module(
                f"ds{idx + 1}",
                nn.Sequential(
                    norm(width),
                    nn.Conv2d(width, widths[idx + 1], kernel_size=2, stride=2),
                )
            )
        # Classification head: global average pool -> LayerNorm -> linear classifier.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            norm(widths[-1]),
            nn.Flatten(),
            nn.Linear(widths[-1], num_classes)
        )
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m) -> None:
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            init_range = 1.0 / (m.out_features ** 0.5)
            nn.init.uniform_(m.weight, -init_range, init_range)
            nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.stage1(x)
        x = self.ds1(x)
        x = self.stage2(x)
        x = self.ds2(x)
        x = self.stage3(x)
        x = self.ds3(x)
        x = self.stage4(x)
        x = self.head(x)
        return x
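
# Shape trace for a 3x224x224 input through convnext_t (widths 96/192/384/768):
# stem -> (N, 96, 56, 56), ds1 -> (N, 192, 28, 28), ds2 -> (N, 384, 14, 14),
# ds3 -> (N, 768, 7, 7), head -> (N, num_classes).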

def _convnext(base_width: int, layers: List[int], **kwargs: Any) -> ConvNext:
    return ConvNext(base_width, layers, **kwargs)

# Model variants: base width and per-stage depths.
def convnext_t(**kwargs: Any) -> ConvNext:
    return _convnext(96, [3, 3, 9, 3], **kwargs)

def convnext_s(**kwargs: Any) -> ConvNext:
    return _convnext(96, [3, 3, 27, 3], **kwargs)

def convnext_b(**kwargs: Any) -> ConvNext:
    return _convnext(128, [3, 3, 27, 3], **kwargs)

def convnext_l(**kwargs: Any) -> ConvNext:
    return _convnext(192, [3, 3, 27, 3], **kwargs)

def convnext_xl(**kwargs: Any) -> ConvNext:
    return _convnext(256, [3, 3, 27, 3], **kwargs)
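
# A minimal smoke test, assuming an ImageNet-style 3x224x224 input; the
# survival rate of 0.9 is an illustrative value, not the paper's per-variant
# schedule.
if __name__ == "__main__":
    model = convnext_t(stodepth_survive=0.9)
    model.eval()  # disables stochastic depth
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])
    # ConvNeXt-T should come out at roughly 28M parameters.
    print(f"{sum(p.numel() for p in model.parameters()):,} parameters")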