Skip to content

Instantly share code, notes, and snippets.

@tokudayo
Last active April 4, 2022 18:21
Show Gist options
  • Save tokudayo/e8f876b74d84310e9d9028db36ad4681 to your computer and use it in GitHub Desktop.
Save tokudayo/e8f876b74d84310e9d9028db36ad4681 to your computer and use it in GitHub Desktop.
Model snapshot after "large kernels" step in ConvNeXt paper
from typing import Optional, Any, List
import torch
import torch.nn as nn
from torch import Tensor
class StochasticDepth(nn.Module):
    """Randomly skip a (residual) branch during training -- stochastic depth.

    In training mode a single random draw per forward pass decides, for the
    whole batch, whether the wrapped module is kept (probability
    ``survival_rate``) or dropped. A kept output is divided by
    ``survival_rate`` ("inverted" scaling) so its expected value during
    training equals the unscaled output produced in eval mode, where the
    module is always applied. Without this scaling the branch would be
    systematically stronger at inference than it was on average in training.

    When the branch is dropped, a 0-dim zero tensor is returned. It broadcasts
    against whatever shortcut tensor it is added to, so the wrapper works even
    when the wrapped module changes the output shape (e.g. a strided branch).

    Args:
        module: The branch to wrap.
        survival_rate: Probability of keeping the branch, in ``(0, 1]``.

    Raises:
        ValueError: If ``survival_rate`` is outside ``(0, 1]``.
    """

    def __init__(self, module: nn.Module, survival_rate: float = 1.) -> None:
        super().__init__()
        if not 0. < survival_rate <= 1.:
            raise ValueError(f"survival_rate must be in (0, 1], got {survival_rate}")
        self.module = module
        self.survival_rate = survival_rate

    def forward(self, x: Tensor) -> Tensor:
        if not self.training or self.survival_rate >= 1.:
            return self.module(x)
        # One draw per call: the keep/drop decision is shared by the whole batch.
        if torch.rand(()).item() >= self.survival_rate:
            return x.new_zeros(())
        return self.module(x) / self.survival_rate

    def __repr__(self) -> str:
        return self.module.__repr__() + f", stodepth_survival_rate={self.survival_rate:.2f}"
def dwconv7x7(planes: int, stride: int = 1) -> nn.Conv2d:
    """Return a bias-free 7x7 depthwise convolution over ``planes`` channels.

    ``groups == planes`` gives each channel its own filter, and a padding of 3
    keeps the spatial size unchanged when ``stride`` is 1.
    """
    kernel = 7
    return nn.Conv2d(
        in_channels=planes,
        out_channels=planes,
        kernel_size=kernel,
        stride=stride,
        padding=kernel // 2,
        groups=planes,
        bias=False,
    )
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Return a bias-free pointwise (1x1) convolution.

    Maps ``in_planes`` channels to ``out_planes`` channels; with ``stride > 1``
    it also subsamples spatially.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=(1, 1),
        stride=stride,
        bias=False,
    )
# Normalization layer used throughout the model; called as norm(num_channels).
norm = nn.BatchNorm2d


def relu() -> nn.ReLU:
    """Return a fresh in-place ReLU module.

    A new instance is created on every call, so each use site registers its
    own submodule. Defined with ``def`` rather than an assigned lambda so the
    factory has a real name in tracebacks and reprs (PEP 8, E731).
    """
    return nn.ReLU(inplace=True)
class Block(nn.Module):
    """Inverted-bottleneck residual block whose first layer is a 7x7 depthwise conv.

    Main path, in order: 7x7 depthwise conv over ``inplanes`` channels
    (carrying ``stride``) -> norm -> ReLU -> 1x1 conv expanding to
    ``width * expansion`` channels -> norm -> ReLU -> 1x1 conv down to
    ``width`` channels -> norm. The shortcut is the raw input, or
    ``projection(x)`` when a projection is given. The sum of the two passes
    through a final ReLU.

    Args:
        inplanes: Number of input channels.
        width: Number of output channels.
        stride: Stride of the depthwise convolution.
        projection: Optional shortcut module. Without it the shortcut is the
            raw input, so shapes only line up when ``stride == 1`` and
            ``inplanes == width``.
        stodepth_survive: Stochastic-depth survival rate. The main path is
            wrapped in ``StochasticDepth`` only when this is below 1.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        width: int,
        stride: int = 1,
        projection: Optional[nn.Module] = None,
        stodepth_survive: float = 1.,
    ) -> None:
        super().__init__()
        # Keep this construction/registration order (relu, projection,
        # main_path): it fixes the submodule and state_dict order as well as
        # the order in which layer parameters are randomly initialized.
        self.relu = relu()
        self.projection = projection

        hidden = self.expansion * width
        layers = [
            dwconv7x7(inplanes, stride),
            norm(inplanes),
            relu(),
            conv1x1(inplanes, hidden),
            norm(hidden),
            relu(),
            conv1x1(hidden, width),
            norm(width),
        ]
        branch: nn.Module = nn.Sequential(*layers)
        if stodepth_survive < 1.:
            branch = StochasticDepth(branch, stodepth_survive)
        self.main_path = branch

    def forward(self, x: Tensor) -> Tensor:
        residual = self.main_path(x)
        if self.projection is None:
            shortcut = x
        else:
            shortcut = self.projection(x)
        return self.relu(residual + shortcut)
class ResNet(nn.Module):
    """ConvNeXt-style ResNet snapshot taken after the "large kernels" step.

    Layout:
    - a non-overlapping 4x4, stride-4 convolution stem followed by norm;
    - one stage of ``Block``s per entry of ``layers``;
    - global average pooling and a linear classifier.

    The first stage keeps the stem's resolution, and every later stage starts
    with a stride-2 block.

    Args:
        layers: Number of blocks in each stage (each entry must be >= 1).
        num_classes: Output features of the final linear layer.
        stodepth_survive: Stochastic-depth survival rate applied to the main
            path of every block. 1.0 disables stochastic depth.
        widths: Output channels of each stage. Defaults to
            ``[96, 192, 384, 768]`` and must have the same length as ``layers``.

    Raises:
        ValueError: If ``layers`` is empty, its length differs from
            ``widths``, or any stage has fewer than one block.
    """

    def __init__(
        self,
        layers: List[int],
        num_classes: int = 1000,
        stodepth_survive: float = 1.,
        widths: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        widths = [96, 192, 384, 768] if widths is None else list(widths)
        # Previously extra entries were silently ignored and missing ones raised
        # an opaque IndexError; reject mismatches up front instead.
        if not layers or len(layers) != len(widths):
            raise ValueError(
                f"layers ({len(layers)} stages) must be non-empty and match "
                f"widths ({len(widths)} stages)"
            )
        if any(num_blocks < 1 for num_blocks in layers):
            raise ValueError(f"every stage needs at least one block, got {list(layers)}")

        self.inplanes = widths[0]
        self.stodepth = stodepth_survive
        # Downsampling stem downsamples input size by 4, e.g. 224 -> 56
        self.stem = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=4, stride=4, bias=False),
            norm(self.inplanes),
        )
        # One stage per entry of `layers`; no downsampling at the start of the first one.
        self.stages = nn.Sequential(*[
            self._make_stage(width, num_blocks, stride=1 if i == 0 else 2)
            for i, (width, num_blocks) in enumerate(zip(widths, layers))
        ])
        # Classification head
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(widths[-1], num_classes),
        )
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize one submodule in place (used via ``self.apply``).

        Conv weights get Kaiming-normal (fan_out); BatchNorm gets weight=1,
        bias=0; Linear weights get U(-1/sqrt(out_features), +1/sqrt(out_features))
        and zero bias.
        """
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            # All convs here are bias-free, but guard in case one ever isn't.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            init_range = 1.0 / (m.out_features ** 0.5)
            nn.init.uniform_(m.weight, -init_range, init_range)
            nn.init.zeros_(m.bias)

    def _make_stage(self, width: int, num_blocks: int, stride: int = 1) -> nn.Sequential:
        """Build one stage of ``num_blocks`` blocks that outputs ``width`` channels.

        Only the first block may change resolution or channel count, so only it
        gets a projection shortcut (when ``stride != 1`` or the channel count
        changes). Advances ``self.inplanes`` to ``width`` for the next stage.
        """
        # Projection where needed
        needs_projection = stride != 1 or self.inplanes != width
        projection = nn.Sequential(
            conv1x1(self.inplanes, width, stride),
            norm(width),
        ) if needs_projection else None
        blocks = [
            Block(self.inplanes, width, stride=stride, projection=projection,
                  stodepth_survive=self.stodepth)
        ]
        self.inplanes = width
        # Remaining blocks of the stage. Stochastic depth must apply to these too;
        # previously only the first block of each stage received the survival rate.
        blocks.extend(
            Block(width, width, stride=1, projection=None, stodepth_survive=self.stodepth)
            for _ in range(1, num_blocks)
        )
        return nn.Sequential(*blocks)

    def forward(self, x: Tensor) -> Tensor:
        """Run stem -> stages -> head and return the classifier outputs."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.head(x)
        return x
def _resnet(layers: List[int], **kwargs: Any) -> ResNet:
    """Construct a ``ResNet`` with the given per-stage block counts.

    Any extra keyword arguments are forwarded unchanged to ``ResNet``.
    """
    return ResNet(layers, **kwargs)
def resnet50(**kwargs: Any) -> ResNet:
    """Build the model with stage depths ``[3, 3, 9, 3]``.

    NOTE: despite the name, these depths (a 1:1:3:1 stage ratio, presumably
    following ConvNeXt) differ from classic ResNet-50's ``[3, 4, 6, 3]``. The
    name is kept so existing callers keep working. Keyword arguments are
    forwarded to ``ResNet`` through ``_resnet``.
    """
    stage_depths = [3, 3, 9, 3]
    return _resnet(stage_depths, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment