""" ConvNeXt | |
Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf | |
Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below | |
Model defs atto, femto, pico, nano and _ols / _hnf variants are timm specific. | |
Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman | |
Adapted to work with 32x32 resolution images by Vikash Sehwag | |
Requirements: timm | |
""" | |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the MIT license.

from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import named_apply, build_model_with_cfg, checkpoint_seq
from timm.models.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, \
    create_conv2d, get_act_layer, make_divisible, to_ntuple
from timm.models.registry import register_model
# LayerNorm is not available in the timm commit this was written against, so it is
# defined here (copied from a later timm version).
class LayerNorm(nn.LayerNorm):
    """ LayerNorm w/ fast norm option """

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
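
# Example (illustrative sketch, not part of the original gist): this LayerNorm
# normalizes over the trailing (channel) dim, so it expects channels-last input:
#   ln = LayerNorm(96)
#   y = ln(torch.randn(2, 8, 8, 96))  # normalized over the last dim of size 96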

__all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0', 'classifier': 'head.fc',
        **kwargs
    }


# NOTE: these cfgs keep the 224x224 ImageNet defaults from _cfg(); the *32 variants
# below are intended for 32x32 inputs and ship with no pretrained URLs.
default_cfgs = {
    'convnext_tiny32': _cfg(),
    'convnext_small32': _cfg(),
    'convnext_base32': _cfg(),
    'convnext_large32': _cfg(),
    'convnext_xlarge32': _cfg(),
}


class ConvNeXtBlock(nn.Module):
    """ ConvNeXt Block

    There are two equivalent implementations:
      (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
      (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back

    Unlike the official impl, this one allows a choice of (1) or (2). The 1x1 conv path can be faster with an
    appropriate choice of LayerNorm impl, but as model size increases the tradeoffs appear to change and
    nn.Linear becomes the better choice. This was observed with PyTorch 1.10 on a 3090 GPU; it could change
    over time and with different hardware.

    Args:
        in_chs (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(
            self,
            in_chs,
            out_chs=None,
            kernel_size=7,
            stride=1,
            dilation=1,
            mlp_ratio=4,
            conv_mlp=False,
            conv_bias=True,
            ls_init_value=1e-6,
            act_layer='gelu',
            norm_layer=None,
            drop_path=0.,
    ):
        super().__init__()
        out_chs = out_chs or in_chs
        act_layer = get_act_layer(act_layer)
        if not norm_layer:
            norm_layer = LayerNorm2d if conv_mlp else LayerNorm
        mlp_layer = ConvMlp if conv_mlp else Mlp
        self.use_conv_mlp = conv_mlp
        self.conv_dw = create_conv2d(
            in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
        self.norm = norm_layer(out_chs)
        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.conv_dw(x)
        if self.use_conv_mlp:
            x = self.norm(x)
            x = self.mlp(x)
        else:
            x = x.permute(0, 2, 3, 1)
            x = self.norm(x)
            x = self.mlp(x)
            x = x.permute(0, 3, 1, 2)
        if self.gamma is not None:
            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
        x = self.drop_path(x) + shortcut
        return x
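
# Example (illustrative sketch): with the default stride=1 and 'same'-style
# depthwise padding, a block preserves the spatial resolution of its input:
#   blk = ConvNeXtBlock(in_chs=96, kernel_size=7)
#   out = blk(torch.randn(2, 96, 8, 8))  # -> torch.Size([2, 96, 8, 8])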


class ConvNeXtStage(nn.Module):

    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size=7,
            stride=2,
            depth=2,
            dilation=(1, 1),
            drop_path_rates=None,
            ls_init_value=1.0,
            conv_mlp=False,
            conv_bias=True,
            act_layer='gelu',
            norm_layer=None,
            norm_layer_cl=None,
    ):
        super().__init__()
        self.grad_checkpointing = False

        if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
            ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
            pad = 'same' if dilation[1] > 1 else 0  # same padding needed if dilation used
            self.downsample = nn.Sequential(
                norm_layer(in_chs),
                create_conv2d(
                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
                    dilation=dilation[0], padding=pad, bias=conv_bias),
            )
            in_chs = out_chs
        else:
            self.downsample = nn.Identity()

        drop_path_rates = drop_path_rates or [0.] * depth
        stage_blocks = []
        for i in range(depth):
            stage_blocks.append(ConvNeXtBlock(
                in_chs=in_chs,
                out_chs=out_chs,
                kernel_size=kernel_size,
                dilation=dilation[1],
                drop_path=drop_path_rates[i],
                ls_init_value=ls_init_value,
                conv_mlp=conv_mlp,
                conv_bias=conv_bias,
                act_layer=act_layer,
                norm_layer=norm_layer if conv_mlp else norm_layer_cl,
            ))
            in_chs = out_chs
        self.blocks = nn.Sequential(*stage_blocks)

    def forward(self, x):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
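
# Example (illustrative sketch): a stage with stride=2 halves the spatial
# resolution in its downsample layer, then runs `depth` residual blocks:
#   stage = ConvNeXtStage(96, 192, stride=2, depth=2,
#                         norm_layer=LayerNorm2d, norm_layer_cl=LayerNorm)
#   out = stage(torch.randn(2, 96, 16, 16))  # -> torch.Size([2, 192, 8, 8])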


class ConvNeXt(nn.Module):
    r""" ConvNeXt

    A PyTorch impl of: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_rate (float): Head dropout rate.
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(
            self,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            output_stride=32,
            depths=(3, 3, 9, 3),
            dims=(96, 192, 384, 768),
            kernel_sizes=7,
            ls_init_value=1e-6,
            stem_type='patch',
            patch_size=4,
            head_init_scale=1.,
            head_norm_first=False,
            conv_mlp=False,
            conv_bias=True,
            act_layer='gelu',
            norm_layer=None,
            drop_rate=0.,
            drop_path_rate=0.,
    ):
        super().__init__()
        assert output_stride in (8, 16, 32)
        kernel_sizes = to_ntuple(4)(kernel_sizes)
        if norm_layer is None:
            norm_layer = LayerNorm2d
            norm_layer_cl = norm_layer if conv_mlp else LayerNorm
        else:
            assert conv_mlp, \
                'If a norm_layer is specified, conv MLP must be used so all norm layers expect rank-4, channels-first input'
            norm_layer_cl = norm_layer

        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.feature_info = []

        assert stem_type in ('patch', 'overlap', 'overlap_tiered', 'fixed_res')
        if stem_type == 'patch':
            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
            self.stem = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
                norm_layer(dims[0]),
            )
            stem_stride = patch_size
        elif stem_type == 'fixed_res':
            # Stride-1 stem: no downsampling, so small (e.g. 32x32) inputs keep their resolution.
            # ('tiered' can never match 'fixed_res', so mid_chs is always dims[0].)
            mid_chs = dims[0]
            self.stem = nn.Sequential(
                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=1, padding=1, bias=conv_bias),
                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=1, padding=1, bias=conv_bias),
                norm_layer(dims[0]),
            )
            # NOTE: the convs above do not downsample; the nominal stride of 4 is kept so the
            # downstream stage stride schedule matches the original ConvNeXt.
            stem_stride = 4
        else:
            mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
            self.stem = nn.Sequential(
                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
                norm_layer(dims[0]),
            )
            stem_stride = 4
        self.stages = nn.Sequential()
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        stages = []
        prev_chs = dims[0]
        curr_stride = stem_stride
        dilation = 1
        # 4 feature resolution stages, each consisting of multiple residual blocks
        for i in range(4):
            stride = 2 if curr_stride == 2 or i > 0 else 1
            if curr_stride >= output_stride and stride > 1:
                dilation *= stride
                stride = 1
            curr_stride *= stride
            first_dilation = 1 if dilation in (1, 2) else 2
            out_chs = dims[i]
            stages.append(ConvNeXtStage(
                prev_chs,
                out_chs,
                kernel_size=kernel_sizes[i],
                stride=stride,
                dilation=(first_dilation, dilation),
                depth=depths[i],
                drop_path_rates=dp_rates[i],
                ls_init_value=ls_init_value,
                conv_mlp=conv_mlp,
                conv_bias=conv_bias,
                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
            ))
            prev_chs = out_chs
            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)
        self.num_features = prev_chs

        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
        # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
        self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
        self.head = nn.Sequential(OrderedDict([
            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()),
        ]))

        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^norm_pre', (99999,)),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes=0, global_pool=None):
        if global_pool is not None:
            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm_pre(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
        x = self.head.global_pool(x)
        x = self.head.norm(x)
        x = self.head.flatten(x)
        x = self.head.drop(x)
        return x if pre_logits else self.head.fc(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
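
# Example (illustrative sketch): pooled pre-logits features can be extracted
# instead of class logits via forward_head(..., pre_logits=True):
#   m = convnext_tiny32(num_classes=10)
#   feats = m.forward_head(m.forward_features(torch.randn(2, 3, 32, 32)), pre_logits=True)
#   feats.shape  # -> torch.Size([2, 768]), i.e. (batch, num_features)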


def _init_weights(module, name=None, head_init_scale=1.0):
    if isinstance(module, nn.Conv2d):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        nn.init.zeros_(module.bias)
        if name and 'head.' in name:
            module.weight.data.mul_(head_init_scale)
            module.bias.data.mul_(head_init_scale)


def checkpoint_filter_fn(state_dict, model):
    """ Remap FB checkpoints -> timm """
    if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
        return state_dict  # non-FB checkpoint
    if 'model' in state_dict:
        state_dict = state_dict['model']
    out_dict = {}
    import re
    for k, v in state_dict.items():
        k = k.replace('downsample_layers.0.', 'stem.')
        k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
        k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
        k = k.replace('dwconv', 'conv_dw')
        k = k.replace('pwconv', 'mlp.fc')
        k = k.replace('head.', 'head.fc.')
        if k.startswith('norm.'):
            k = k.replace('norm', 'head.norm')
        if v.ndim == 2 and 'head' not in k:
            model_shape = model.state_dict()[k].shape
            v = v.reshape(model_shape)
        out_dict[k] = v
    return out_dict
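
# Example key remaps performed by the substitutions above (FB -> timm naming):
#   'downsample_layers.0.0.weight' -> 'stem.0.weight'
#   'stages.0.0.dwconv.weight'     -> 'stages.0.blocks.0.conv_dw.weight'
#   'stages.0.0.pwconv1.weight'    -> 'stages.0.blocks.0.mlp.fc1.weight'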


def _create_convnext(variant, pretrained=False, **kwargs):
    model = build_model_with_cfg(
        ConvNeXt, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
        **kwargs)
    return model


@register_model
def convnext_tiny32(pretrained=False, **kwargs):
    # experimental tiny variant with a stem that doesn't downsample (for 32x32 inputs)
    model_args = dict(
        depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), stem_type='fixed_res', **kwargs)
    model = _create_convnext('convnext_tiny32', pretrained=pretrained, kernel_sizes=3, **model_args)
    return model


@register_model
def convnext_small32(pretrained=False, **kwargs):
    # experimental small variant with a stem that doesn't downsample (for 32x32 inputs)
    model_args = dict(
        depths=(3, 3, 27, 3), dims=(96, 192, 384, 768), stem_type='fixed_res', **kwargs)
    model = _create_convnext('convnext_small32', pretrained=pretrained, kernel_sizes=3, **model_args)
    return model


@register_model
def convnext_base32(pretrained=False, **kwargs):
    # experimental base variant with a stem that doesn't downsample (for 32x32 inputs)
    model_args = dict(
        depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024), stem_type='fixed_res', **kwargs)
    model = _create_convnext('convnext_base32', pretrained=pretrained, kernel_sizes=3, **model_args)
    return model


@register_model
def convnext_large32(pretrained=False, **kwargs):
    # experimental large variant with a stem that doesn't downsample (for 32x32 inputs)
    model_args = dict(
        depths=(3, 3, 27, 3), dims=(192, 384, 768, 1536), stem_type='fixed_res', **kwargs)
    model = _create_convnext('convnext_large32', pretrained=pretrained, kernel_sizes=3, **model_args)
    return model


@register_model
def convnext_xlarge32(pretrained=False, **kwargs):
    # experimental xlarge variant with a stem that doesn't downsample (for 32x32 inputs)
    model_args = dict(
        depths=(3, 3, 27, 3), dims=(256, 512, 1024, 2048), stem_type='fixed_res', **kwargs)
    model = _create_convnext('convnext_xlarge32', pretrained=pretrained, kernel_sizes=3, **model_args)
    return model
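

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original gist): build the tiny
    # 32x32 variant with a 10-way head and run a CIFAR-sized batch through it.
    model = convnext_tiny32(num_classes=10)
    x = torch.randn(2, 3, 32, 32)
    logits = model(x)
    print(logits.shape)  # expected: torch.Size([2, 10])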