This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Resnet + SVM""" | |
import torch | |
from torch import nn | |
import torchvision.transforms as T | |
from torchvision import models | |
class SVM(nn.Module):
    """Multi-class SVM head built on a Gaussian (radial basis function) kernel.

    Every input vector is compared against a set of learnable centres; the
    resulting kernel activations are passed to a linear layer that emits one
    score per class.

    Source: https://github.com/JeremyLinux/PyTorch-Radial-Basis-Function-Layer
    """

    def __init__(self, in_features, num_centers, num_classes):
        """Create the learnable centres, kernel widths and linear classifier.

        Args:
            in_features: size of the last dimension of the inputs
            num_centers: number of RBF centres (kernel activations per input)
            num_classes: number of scores produced by the linear classifier
        """
        super().__init__()
        centre_shape = (num_centers, in_features)
        self.centres = nn.Parameter(torch.randn(*centre_shape))
        # Holds -log(sigma) per centre; zeros mean a scale factor of exp(0) = 1
        self.neg_log_sigmas = nn.Parameter(torch.zeros(num_centers))
        self.classifier = nn.Linear(
            in_features=num_centers,
            out_features=num_classes,
        )

    def forward(self, inputs):
        """Score `inputs` (ideally zero mean, unit variance) for every class.

        Args:
            inputs: tensor whose last dimension has size `in_features`

        Returns:
            Tensor with the same leading dimensions as `inputs` and a last
            dimension of size `num_classes`
        """
        # A new axis before the feature axis lets each input broadcast
        # against all centres at once
        offsets = inputs.unsqueeze(-2) - self.centres
        distances = offsets.norm(p=2, dim=-1)
        scaled = torch.exp(self.neg_log_sigmas) * distances
        activations = torch.exp(-scaled.pow(2))
        return self.classifier(activations)
def fuse_normalize_conv(norm, conv):
    """Fuse a `torchvision.transforms.Normalize` into a `torch.nn.Conv2d`.

    After the call, `conv(x)` computes what `conv((x - mean) / std)` computed
    before: the weights are divided by the per-channel std and the mean term
    is folded into the bias. `conv` is modified in place; a bias parameter is
    created when the convolution has none.

    Note: the result is exact only where the kernel does not overlap padding.
    A zero-padded convolution pads the raw input with zeros, whereas the
    unfused pipeline pads the already-normalized input.

    Args:
        norm: object exposing `mean` and `std` (e.g. `T.Normalize`); each may
            hold one value per input channel or a single value shared by all
            input channels
        conv: the convolution to update in place (grouped convolutions are
            supported)

    Raises:
        ValueError: if `mean` or `std` has neither one value nor one value per
            input channel. Nothing is modified in that case.
    """
    weight = conv.weight.detach()  # shares storage with the parameter
    out_channels, group_in = weight.shape[:2]
    groups = conv.groups
    in_channels = group_in * groups

    def per_filter(stat, name):
        """Lay `stat` out as (out_channels, group_in), aligned with `weight`."""
        stat = torch.as_tensor(stat, dtype=weight.dtype, device=weight.device)
        stat = stat.reshape(-1)
        if stat.numel() not in (1, in_channels):
            raise ValueError(
                f'{name} has {stat.numel()} values but the convolution '
                f'expects 1 or {in_channels} (its number of input channels)'
            )
        # Filters of group g only see input channels of group g
        stat = stat.expand(in_channels).reshape(groups, 1, group_in)
        stat = stat.expand(groups, out_channels // groups, group_in)
        return stat.reshape(out_channels, group_in)

    # Convert and validate both statistics before touching the weights so a
    # bad argument cannot leave the convolution half-updated
    std = per_filter(norm.std, 'std')
    mean = per_filter(norm.mean, 'mean')

    weight /= std[:, :, None, None]
    # Each filter contributes -sum(w / std * mean) as a constant offset
    bias = -(weight.sum([-2, -1]) * mean).sum(-1)
    if conv.bias is None:
        conv.bias = nn.Parameter(bias)
    else:
        conv.bias.detach().add_(bias)
def get_resnet(
    num_layers=18,
    svm_centers=0,
    num_classes=1000,
    pretrained=False,
    normalize=False,
    grayscale=False,
    **kwargs,
):
    """Get a modified ResNet model (supports SVM classifier head)

    Args:
        num_layers: integer that should be in {18, 34, 50, 101, 152}
        svm_centers: use SVM if greater than zero (e.g., 4096)
        num_classes: number of output labels for the classifier head
        pretrained: whether to use ImageNet pretrained weights
        normalize: whether the input is expected to be in the range [0, 1];
            only has an effect together with `pretrained`, in which case the
            ImageNet mean/std normalization is fused into the first layers
        grayscale: whether the input has only one channel
        kwargs: keyword arguments to pass to `torchvision.models.resnet{}`

    Returns:
        The modified resnet classifier
    """
    builder = getattr(models, f'resnet{num_layers}')
    # Keyword form so the flag cannot be bound to a different positional slot
    model = builder(pretrained=pretrained, **kwargs)

    if normalize and pretrained:
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        fuse_normalize_conv(T.Normalize(mean, std), model.conv1)
        # Keep conv1 bias-free: the bias produced by the fusion is removed
        # and compensated by shifting bn1's running mean by the same amount
        bias, model.conv1.bias = model.conv1.bias.detach(), None
        model.bn1.running_mean -= bias

    # Must run after the fusion above so the summed weights already include
    # the per-channel std scaling
    if grayscale:
        conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # A gray image replicated over RGB gives the same response as a
            # single channel convolved with the channel-summed kernel
            weight = model.conv1.weight.detach().sum(1, keepdim=True)
            conv.weight.detach().copy_(weight)
        model.conv1 = conv

    # Read the feature width from the existing head instead of assuming 2048:
    # the width differs between ResNet variants (512 for resnet18/34)
    in_features = model.fc.in_features
    if svm_centers > 0:
        model.fc = SVM(in_features, svm_centers, num_classes)
    elif num_classes != 1000:
        model.fc = nn.Linear(in_features, num_classes)
        nn.init.normal_(model.fc.weight, mean=0.0, std=0.01)
        nn.init.zeros_(model.fc.bias)
    return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment