-
-
Save tokudayo/e8f876b74d84310e9d9028db36ad4681 to your computer and use it in GitHub Desktop.
Model snapshot after "large kernels" step in ConvNeXt paper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional, Any, List | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
class StochasticDepth(nn.Module):
    """Randomly skip a wrapped module while training.

    In training mode each forward pass draws a Bernoulli sample with
    probability ``1 - survival_rate``; when it fires, the wrapped module is
    not executed and a scalar zero tensor is returned in its place. In eval
    mode, or when the sample does not fire, the wrapped module's output is
    returned as-is (no rescaling is applied).

    Args:
        module: The module to wrap.
        survival_rate: Probability of keeping the module on a training step.
    """

    def __init__(self, module: nn.Module, survival_rate: float = 1.) -> None:
        super().__init__()
        self.module = module
        self.survival_rate = survival_rate
        # Build the probability from floats so that an integer survival_rate
        # (e.g. ``1``) still yields a floating-point probability tensor.
        drop_prob = torch.tensor(1.0 - float(survival_rate))
        self._drop = torch.distributions.Bernoulli(drop_prob)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``module(x)``, or a 0-dim zero tensor if the module is dropped."""
        # Short-circuit on ``self.training`` so no sample is drawn in eval mode.
        if self.training and self._drop.sample():
            # A 0-dim zero with x's dtype/device honours the Tensor return
            # type and broadcasts against any residual shape, without
            # running the wrapped module.
            return x.new_zeros(())
        return self.module(x)

    def __repr__(self) -> str:
        """Return the wrapped module's repr followed by the survival rate."""
        return repr(self.module) + f", stodepth_survival_rate={self.survival_rate:.2f}"
def dwconv7x7(planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 7x7 depthwise convolution over ``planes`` channels.

    The layer uses ``groups=planes`` (one filter per channel) and a padding
    of 3.
    """
    depthwise = nn.Conv2d(
        in_channels=planes,
        out_channels=planes,
        kernel_size=7,
        stride=stride,
        padding=3,
        groups=planes,
        bias=False,
    )
    return depthwise
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 (pointwise) convolution.

    Maps ``in_planes`` input channels to ``out_planes`` output channels with
    the given stride.
    """
    pointwise = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
# Normalization layer class used when building the network.
norm = nn.BatchNorm2d


def relu() -> nn.ReLU:
    """Return a new in-place ReLU activation module."""
    return nn.ReLU(inplace=True)
class Block(nn.Module):
    """Residual block: 7x7 depthwise conv, 1x1 expansion, 1x1 reduction.

    The main path is added to the shortcut (identity, or ``projection`` when
    one is supplied) and the sum is passed through a ReLU.

    Args:
        inplanes: Number of input channels.
        width: Number of output channels of the main path.
        stride: Stride of the depthwise convolution.
        projection: Optional module applied to the input to form the shortcut.
        stodepth_survive: Survival rate; values below 1 wrap the main path in
            ``StochasticDepth``.
    """

    # Channel multiplier of the hidden 1x1 expansion layer.
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        width: int,
        stride: int = 1,
        projection: Optional[nn.Module] = None,
        stodepth_survive: float = 1.
    ) -> None:
        super().__init__()
        self.relu = relu()
        self.projection = projection
        hidden = width * self.expansion
        # Layers are created in forward order so parameter initialisation
        # order and Sequential indices stay stable.
        layers = [
            dwconv7x7(inplanes, stride),
            norm(inplanes),
            relu(),
            conv1x1(inplanes, hidden),
            norm(hidden),
            relu(),
            conv1x1(hidden, width),
            norm(width),
        ]
        branch = nn.Sequential(*layers)
        if stodepth_survive < 1.:
            self.main_path = StochasticDepth(branch, stodepth_survive)
        else:
            self.main_path = branch

    def forward(self, x: Tensor) -> Tensor:
        """Apply the main path, add the shortcut, and activate."""
        residual = self.main_path(x)
        if self.projection is None:
            shortcut = x
        else:
            shortcut = self.projection(x)
        return self.relu(residual + shortcut)
class ResNet(nn.Module):
    """Four-stage residual classifier built from ``Block`` modules.

    The network is a stride-4 convolutional stem, four stages with widths
    96/192/384/768, and a global-average-pool + linear classification head.

    Args:
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the final linear layer.
        stodepth_survive: Stochastic-depth survival rate passed to every
            block in every stage.
    """

    def __init__(self, layers: List[int], num_classes: int = 1000, stodepth_survive: float = 1.) -> None:
        super().__init__()
        widths = [96, 192, 384, 768]
        # Running channel count; updated by _make_stage as stages are built.
        self.inplanes = widths[0]
        self.stodepth = stodepth_survive
        # Downsampling stem downsamples input size by 4, e.g. 224 -> 56
        self.stem = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=4, stride=4, bias=False),
            norm(self.inplanes),
        )
        # Res1 -> Res4. No downsampling at the beginning of Res1.
        self.stages = nn.Sequential(
            *[
                self._make_stage(width, layers[i], stride=2 if i != 0 else 1)
                for i, width in enumerate(widths)
            ]
        )
        # Classification head
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(widths[-1], num_classes)
        )
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialise a single submodule; used as the callback for ``apply``."""
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            # Uniform in +/- 1/sqrt(out_features), zero bias.
            init_range = 1.0 / (m.out_features ** 0.5)
            nn.init.uniform_(m.weight, -init_range, init_range)
            nn.init.zeros_(m.bias)

    def _make_stage(self, width: int, num_blocks: int, stride: int = 1) -> nn.Sequential:
        """Build one stage of ``num_blocks`` blocks with ``width`` channels.

        The first block applies ``stride`` and, when the stride or channel
        count changes, a 1x1-conv + norm projection on the shortcut. The
        remaining blocks keep the resolution and channel count. As a side
        effect ``self.inplanes`` is set to ``width``.
        """
        blocks = []
        # Projection where needed
        projection = nn.Sequential(
            conv1x1(self.inplanes, width, stride),
            norm(width)
        ) if stride != 1 or self.inplanes != width else None
        blocks.append(
            Block(self.inplanes, width, stride=stride, projection=projection, stodepth_survive=self.stodepth)
        )
        # Remaining blocks of the stage
        self.inplanes = width
        for _ in range(1, num_blocks):
            # Forward the survival rate here too, so stochastic depth applies
            # to every block rather than only the first one of each stage.
            blocks.append(
                Block(self.inplanes, width, stride=1, projection=None, stodepth_survive=self.stodepth)
            )
        return nn.Sequential(*blocks)

    def forward(self, x: Tensor) -> Tensor:
        """Run the stem, the four stages, and the classification head."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.head(x)
        return x
def _resnet(layers: List[int], **kwargs: Any) -> ResNet:
    """Construct a ``ResNet`` with the given per-stage block counts.

    Extra keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    return ResNet(layers, **kwargs)
def resnet50(**kwargs: Any) -> ResNet:
    """Build the model with stage depths (3, 3, 9, 3).

    Keyword arguments are passed through to ``_resnet``.
    """
    stage_depths = [3, 3, 9, 3]
    return _resnet(stage_depths, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment