Skip to content

Instantly share code, notes, and snippets.

@c0nn3r
Last active November 4, 2017 16:33
Show Gist options
  • Save c0nn3r/d0fcb7edb57c405cb707afeca28a8329 to your computer and use it in GitHub Desktop.
Save c0nn3r/d0fcb7edb57c405cb707afeca28a8329 to your computer and use it in GitHub Desktop.
import torch.utils.model_zoo as model_zoo
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
# Download locations of the official torchvision ImageNet checkpoints,
# keyed by architecture name; used by the *_features constructors below.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)
class BasicBlockFeatures(BasicBlock):
    '''
    A BasicBlock that also returns the raw output of its second conv layer.

    ``forward`` returns ``(out, conv2_rep)``: ``out`` is the usual block
    output (after the residual add and final ReLU), and ``conv2_rep`` is the
    output of ``conv2`` taken *before* ``bn2``. A tuple input, as emitted by
    a preceding *Features block, is unwrapped to its first element so these
    blocks can be chained in a sequential stage.
    '''

    def forward(self, x):
        inp = x[0] if isinstance(x, tuple) else x

        hidden = self.relu(self.bn1(self.conv1(inp)))
        conv2_rep = self.conv2(hidden)
        out = self.bn2(conv2_rep)

        # The shortcut is computed after the main branch, as in the original,
        # so module calls (and any forward hooks) happen in the same order.
        shortcut = inp if self.downsample is None else self.downsample(inp)
        out += shortcut
        return self.relu(out), conv2_rep
class BottleneckFeatures(Bottleneck):
    '''
    A Bottleneck that also returns the raw output of its last conv layer.

    ``forward`` returns ``(out, conv3_rep)``: ``out`` is the usual block
    output (after the residual add and final ReLU), and ``conv3_rep`` is the
    output of ``conv3`` taken *before* ``bn3``. A tuple input, as emitted by
    a preceding *Features block, is unwrapped to its first element so these
    blocks can be chained in a sequential stage.
    '''

    def forward(self, x):
        inp = x[0] if isinstance(x, tuple) else x

        hidden = self.relu(self.bn1(self.conv1(inp)))
        hidden = self.relu(self.bn2(self.conv2(hidden)))
        conv3_rep = self.conv3(hidden)
        out = self.bn3(conv3_rep)

        # The shortcut is computed after the main branch, as in the original,
        # so module calls (and any forward hooks) happen in the same order.
        shortcut = inp if self.downsample is None else self.downsample(inp)
        out += shortcut
        return self.relu(out), conv3_rep
class ResNetFeatures(ResNet):
    '''
    A ResNet whose forward pass returns intermediate features instead of
    classification logits.

    Each of ``layer1``..``layer4`` must yield an ``(output, representation)``
    pair, as stages built from the *Features blocks do. ``forward`` returns
    the four representations ``(c2, c3, c4, c5)``, one per stage, in order.
    The inherited ``avgpool`` and ``fc`` modules are not used by ``forward``.
    '''

    def forward(self, x):
        # Stem: conv -> BN -> ReLU -> max-pool.
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        reps = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out, rep = stage(out)
            reps.append(rep)
        return tuple(reps)
def _build_resnet_features(block, layers, arch, pretrained, **kwargs):
    """Build a ResNetFeatures extractor and optionally load ImageNet weights.

    Args:
        block: residual block class (BasicBlockFeatures or BottleneckFeatures).
        layers (list of int): number of blocks in each of the four stages.
        arch (str): key into ``model_urls`` naming the pretrained checkpoint.
        pretrained (bool): if True, load the checkpoint at ``model_urls[arch]``.
        **kwargs: forwarded unchanged to the ResNet constructor.

    Returns:
        ResNetFeatures: the constructed (and possibly pre-trained) model.
    """
    model = ResNetFeatures(block, layers, **kwargs)
    if pretrained:
        # NOTE(review): load_state_dict is strict by default, so kwargs that
        # change parameter shapes (e.g. num_classes) presumably make this
        # fail -- confirm if you rely on that combination.
        model.load_state_dict(model_zoo.load_url(model_urls[arch]))
    return model


def resnet18_features(pretrained=False, **kwargs):
    """Construct a ResNet-18 feature extractor.

    Args:
        pretrained (bool): If True, load weights pre-trained on ImageNet.
        **kwargs: forwarded to the ResNet constructor.

    Returns:
        ResNetFeatures: a model whose forward returns ``(c2, c3, c4, c5)``,
        one pre-BN conv representation per residual stage, instead of logits.
    """
    return _build_resnet_features(
        BasicBlockFeatures, [2, 2, 2, 2], 'resnet18', pretrained, **kwargs)


def resnet34_features(pretrained=False, **kwargs):
    """Construct a ResNet-34 feature extractor.

    Args:
        pretrained (bool): If True, load weights pre-trained on ImageNet.
        **kwargs: forwarded to the ResNet constructor.

    Returns:
        ResNetFeatures: a model whose forward returns ``(c2, c3, c4, c5)``,
        one pre-BN conv representation per residual stage, instead of logits.
    """
    return _build_resnet_features(
        BasicBlockFeatures, [3, 4, 6, 3], 'resnet34', pretrained, **kwargs)


def resnet50_features(pretrained=False, **kwargs):
    """Construct a ResNet-50 feature extractor.

    Args:
        pretrained (bool): If True, load weights pre-trained on ImageNet.
        **kwargs: forwarded to the ResNet constructor.

    Returns:
        ResNetFeatures: a model whose forward returns ``(c2, c3, c4, c5)``,
        one pre-BN conv representation per residual stage, instead of logits.
    """
    return _build_resnet_features(
        BottleneckFeatures, [3, 4, 6, 3], 'resnet50', pretrained, **kwargs)


def resnet101_features(pretrained=False, **kwargs):
    """Construct a ResNet-101 feature extractor.

    Args:
        pretrained (bool): If True, load weights pre-trained on ImageNet.
        **kwargs: forwarded to the ResNet constructor.

    Returns:
        ResNetFeatures: a model whose forward returns ``(c2, c3, c4, c5)``,
        one pre-BN conv representation per residual stage, instead of logits.
    """
    return _build_resnet_features(
        BottleneckFeatures, [3, 4, 23, 3], 'resnet101', pretrained, **kwargs)


def resnet152_features(pretrained=False, **kwargs):
    """Construct a ResNet-152 feature extractor.

    Args:
        pretrained (bool): If True, load weights pre-trained on ImageNet.
        **kwargs: forwarded to the ResNet constructor.

    Returns:
        ResNetFeatures: a model whose forward returns ``(c2, c3, c4, c5)``,
        one pre-BN conv representation per residual stage, instead of logits.
    """
    return _build_resnet_features(
        BottleneckFeatures, [3, 8, 36, 3], 'resnet152', pretrained, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment