Flexible DenseNet for PyTorch
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
# Recent torchvision releases removed torchvision.models.utils;
# fall back to the torch.hub helper of the same name.
try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.hub import load_state_dict_from_url

__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
def _bn_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        # Concatenate all preceding feature maps along the channel axis,
        # then apply the bottleneck BN-ReLU-1x1-conv.
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function
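# Wrapping the concat + bottleneck in a closure lets torch.utils.checkpoint
# recompute it during the backward pass instead of storing the concatenated
# tensor, which is what makes memory_efficient=True cheaper on memory but
# slower (see _DenseLayer.forward below).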
class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate,
                 memory_efficient=False, three_d=False):
        super(_DenseLayer, self).__init__()
        # Pick 2D or 3D primitives once so the rest of the layer is
        # dimension-agnostic.
        conv_layer = nn.Conv3d if three_d else nn.Conv2d
        batchnorm_layer = nn.BatchNorm3d if three_d else nn.BatchNorm2d
        self.dropout_layer = F.dropout3d if three_d else F.dropout2d
        self.add_module('norm1', batchnorm_layer(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', conv_layer(num_input_features, bn_size * growth_rate,
                                            kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', batchnorm_layer(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', conv_layer(bn_size * growth_rate, growth_rate,
                                            kernel_size=3, stride=1, padding=1,
                                            bias=False))
        self.drop_rate = drop_rate
        self.memory_efficient = memory_efficient

    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            # Trade compute for memory: recompute the bottleneck in backward.
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = self.dropout_layer(new_features, p=self.drop_rate,
                                              training=self.training)
        return new_features
class _DenseBlock(nn.Module):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate,
                 drop_rate, memory_efficient=False, three_d=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
                three_d=three_d
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.named_children():
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)
class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features, fdc=False,
                 max_pool=False, three_d=False):
        super(_Transition, self).__init__()
        # With fully dense connectivity (fdc) the 1x1 compression conv is
        # skipped, so the transition only pools.
        if not fdc:
            conv_layer = nn.Conv3d if three_d else nn.Conv2d
            batchnorm_layer = nn.BatchNorm3d if three_d else nn.BatchNorm2d
            self.add_module('norm', batchnorm_layer(num_input_features))
            self.add_module('relu', nn.ReLU(inplace=True))
            self.add_module('conv', conv_layer(num_input_features, num_output_features,
                                               kernel_size=1, stride=1, bias=False))
        if max_pool:
            maxpool_layer = nn.MaxPool3d if three_d else nn.MaxPool2d
            self.add_module('pool', maxpool_layer(kernel_size=2, stride=2))
        else:
            avgpool_layer = nn.AvgPool3d if three_d else nn.AvgPool2d
            self.add_module('pool', avgpool_layer(kernel_size=2, stride=2))
class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of ints) - how many layers in each pooling block
        input_channels (int) - 1 for grayscale, 3 for RGB (default: 3)
        num_init_features (int) - the number of filters to learn in the first convolution layer
        init_conv_kernel (int) - kernel size of the first conv (default: 7)
        init_conv_stride (int) - stride of the first conv (default: 2)
        init_pool (bool) - whether to pool after the first conv (default: True)
        max_pool (bool) - use max pooling in transition layers (default: False, i.e. avg pooling)
        bn_size (int) - multiplicative factor for the number of bottleneck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
            but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
        fdc (bool) - fully dense connectivity (no 1x1 conv in transition layers)
        exp_grow (bool) - doubles the growth rate with each block (e.g. growth_rate=32 for block 0, 64 for block 1, 128 for block 2, ...)
        three_d (bool) - uses 3D convolutions; assumes (N, C, D, H, W) input
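
    Example (a hypothetical single-channel 3D configuration; the numbers are
    illustrative only, any valid combination of the arguments above works the
    same way)::

        >>> model = DenseNet(growth_rate=16, block_config=(4, 8, 12),
        ...                  num_init_features=32, input_channels=1,
        ...                  num_classes=2, fdc=True, exp_grow=True, three_d=True)
        >>> logits = model(torch.randn(1, 1, 32, 64, 64))  # (N, C, D, H, W)
        >>> logits.shape
        torch.Size([1, 2])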
""" | |
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000,
                 memory_efficient=False, input_channels=3, init_conv_kernel=7,
                 init_conv_stride=2, init_pool=True, max_pool=False, fdc=False,
                 exp_grow=False, three_d=False):
        super(DenseNet, self).__init__()
        conv_layer = nn.Conv3d if three_d else nn.Conv2d
        maxpool_layer = nn.MaxPool3d if three_d else nn.MaxPool2d
        batchnorm_layer = nn.BatchNorm3d if three_d else nn.BatchNorm2d
        self.globalpool_layer = F.adaptive_avg_pool3d if three_d else F.adaptive_avg_pool2d

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', conv_layer(input_channels, num_init_features,
                                 kernel_size=init_conv_kernel, stride=init_conv_stride,
                                 padding=init_conv_kernel // 2, bias=False)),
            ('norm0', batchnorm_layer(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
        ]))
        if init_pool:
            self.features.add_module('pool0', maxpool_layer(kernel_size=3, stride=2, padding=1))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=(growth_rate * (2 ** i)) if exp_grow else growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
                three_d=three_d
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = (num_features + num_layers * growth_rate * (2 ** i)) if exp_grow else (num_features + num_layers * growth_rate)
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2,
                                    fdc=fdc,
                                    max_pool=max_pool,
                                    three_d=three_d)
                self.features.add_module('transition%d' % (i + 1), trans)
                if not fdc:
                    num_features = num_features // 2

        # Final batch norm ('norm5' keeps the name used by the pretrained checkpoints)
        self.features.add_module('norm5', batchnorm_layer(num_features))

        # Classifier as a 1x1 conv so the model can also emit spatial class maps
        self.classifier = conv_layer(num_features, num_classes, 1)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv3d)):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
    def forward(self, x, heatmap=False):
        features = self.features(x)
        if not heatmap:
            out = F.relu(features, inplace=True)
            out = self.globalpool_layer(out, 1)
            out = self.classifier(out).flatten(1)
            return out
        else:
            # Heatmap mode: apply the 1x1-conv classifier before pooling so the
            # second return value is a per-location class activation map.
            out = F.relu(features, inplace=True)
            out = self.classifier(out)
            return self.globalpool_layer(out, 1).flatten(1), out
def _load_state_dict(model, model_url, progress):
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    # The checkpoints store the classifier as an nn.Linear; add two trailing
    # singleton dims so the weights fit the 1x1-conv classifier used here.
    state_dict['classifier.weight'].unsqueeze_(-1).unsqueeze_(-1)
    model.load_state_dict(state_dict)
def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
              **kwargs):
    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
    if pretrained:
        _load_state_dict(model, model_urls[arch], progress)
    return model
def densenet121(pretrained=False, progress=True, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
            but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
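
    Example (downloads the ImageNet checkpoint on first use)::

        >>> model = densenet121(pretrained=True)
        >>> model(torch.randn(1, 3, 224, 224)).shape
        torch.Size([1, 1000])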
""" | |
return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, | |
**kwargs) | |
def densenet161(pretrained=False, progress=True, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
            but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
                     **kwargs)
def densenet169(pretrained=False, progress=True, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
            but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
                     **kwargs)
def densenet201(pretrained=False, progress=True, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
            but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
                     **kwargs)
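

# Minimal smoke-test sketch: the hyperparameters and shapes below are
# arbitrary illustrations of the flexible options (grayscale input, max-pool
# transitions, heatmap output), not values taken from the gist itself.
if __name__ == '__main__':
    model = DenseNet(growth_rate=8, block_config=(2, 4, 4), num_init_features=16,
                     input_channels=1, num_classes=3, max_pool=True)
    x = torch.randn(2, 1, 64, 64)
    logits = model(x)
    print(logits.shape)  # torch.Size([2, 3])
    pooled, cam = model(x, heatmap=True)
    print(pooled.shape, cam.shape)  # torch.Size([2, 3]) and a (2, 3, h, w) class map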