Skip to content

Instantly share code, notes, and snippets.

@xmodar
Created February 10, 2022 20:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xmodar/0249afbefdc3e4b55ca780de8c72d5b9 to your computer and use it in GitHub Desktop.
Save xmodar/0249afbefdc3e4b55ca780de8c72d5b9 to your computer and use it in GitHub Desktop.
"""Resnet + SVM"""
import torch
from torch import nn
import torchvision.transforms as T
from torchvision import models
class SVM(nn.Module):
    """Multi-Class SVM with Gaussian Kernel (Radial Basis Function)

    Computes Gaussian radial-basis features
    ``exp(-(||x - c_k|| / sigma_k) ** 2)`` against ``num_centers`` learnable
    centres ``c_k`` with learnable widths ``sigma_k``, then maps them to class
    scores with a linear layer. No max-margin loss is applied here; the head is
    trained with whatever loss the caller uses.

    Source: https://github.com/JeremyLinux/PyTorch-Radial-Basis-Function-Layer
    """

    def __init__(self, in_features, num_centers, num_classes):
        super().__init__()
        # (num_centers, in_features) learnable centres, initialised ~ N(0, 1).
        self.centres = nn.Parameter(torch.randn(num_centers, in_features))
        # Stores -log(sigma) per centre, so .exp() yields 1 / sigma;
        # zeros at init means sigma = 1 for every centre.
        self.neg_log_sigmas = nn.Parameter(torch.zeros(num_centers))
        self.classifier = nn.Linear(num_centers, num_classes)

    def forward(self, inputs):
        """Return class scores for ``inputs`` of shape (..., in_features).

        The input is recommended to have zero mean and unit variance.
        The output has shape (..., num_classes).
        """
        # Expand ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2 so that only a
        # (..., num_centers) tensor is materialized. The naive broadcast
        # difference is (..., num_centers, in_features), i.e. ~32 MiB per
        # sample in float32 for 4096 centres x 2048 features.
        sq_dists = (
            inputs.pow(2).sum(-1, keepdim=True)
            - 2 * (inputs @ self.centres.t())
            + self.centres.pow(2).sum(-1)
        ).clamp_min(0)  # rounding can push near-zero distances below 0
        # (||x - c|| / sigma)^2 == ||x - c||^2 * exp(2 * neg_log_sigma)
        log_radial_bases = -sq_dists * (2 * self.neg_log_sigmas).exp()
        return self.classifier(log_radial_bases.exp())
def fuse_normalize_conv(norm, conv):
    """Fuse a `torchvision.transforms.Normalize` into a `torch.nn.Conv2d`

    Folds ``(x - mean) / std`` into ``conv`` in place, so that afterwards
    ``conv(x)`` matches ``original_conv(norm(x))``.

    Args:
        norm: object exposing ``mean`` and ``std`` (e.g. a
            ``torchvision.transforms.Normalize``); each may be a scalar or a
            sequence with one entry per input channel of ``conv``.
        conv: the convolution to modify in place. Its weight is divided by
            ``std`` per input channel, and its bias is created (or
            incremented) to absorb the mean shift.

    Note:
        With zero padding the fusion is exact only away from the borders: the
        original pipeline pads the *normalized* image with zeros, whereas the
        fused conv pads the *raw* image with zeros. Grouped convolutions
        (``groups > 1``) are not supported.
    """
    with torch.no_grad():
        weight = conv.weight
        factory_kwargs = dict(dtype=weight.dtype, device=weight.device)
        # view(-1, 1, 1) broadcasts over (in_channels, kH, kW) for both
        # per-channel sequences and scalar statistics.
        std = torch.as_tensor(norm.std, **factory_kwargs).view(-1, 1, 1)
        mean = torch.as_tensor(norm.mean, **factory_kwargs).view(-1, 1, 1)
        # conv(W, (x - m) / s) == conv(W / s, x) - sum_{i,h,w}(W / s * m)
        weight /= std
        bias = -(weight * mean).sum((-3, -2, -1))
        if conv.bias is None:
            conv.bias = nn.Parameter(bias)
        else:
            conv.bias += bias
def get_resnet(
    num_layers=18,
    svm_centers=0,
    num_classes=1000,
    pretrained=False,
    normalize=False,
    grayscale=False,
    **kwargs,
):
    """Get a modified ResNet model (supports SVM classifier head)

    Args:
        num_layers: integer that should be in {18, 34, 50, 101, 152}
        svm_centers: use an `SVM` head with this many centers if greater than
            zero (e.g., 4096); otherwise keep a linear head
        num_classes: number of output labels for the classifier head
        pretrained: whether to use ImageNet pretrained weights
        normalize: whether the input is expected to be in the range [0, 1];
            when set together with `pretrained`, the ImageNet mean/std
            normalization is folded into `conv1` and `bn1`
        grayscale: whether the input has only one channel
        kwargs: keyword arguments to pass to `torchvision.models.resnet{}`

    Returns:
        The modified resnet classifier
    """
    resnet_fn = getattr(models, f'resnet{num_layers}')
    model = resnet_fn(pretrained=pretrained, **kwargs)

    if normalize and pretrained:
        # ImageNet statistics the pretrained torchvision weights expect.
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        fuse_normalize_conv(T.Normalize(mean, std), model.conv1)
        # torchvision's ResNet applies bn1 right after conv1, so the fused
        # bias b can be absorbed there: BN(y + b) with running mean m equals
        # BN(y) with running mean m - b. This keeps conv1 bias-free.
        bias, model.conv1.bias = model.conv1.bias.detach(), None
        model.bn1.running_mean -= bias

    if grayscale:
        old_conv = model.conv1
        conv = nn.Conv2d(
            1,
            old_conv.out_channels,
            old_conv.kernel_size,
            old_conv.stride,
            old_conv.padding,
            bias=False,
        )
        if pretrained:
            # A gray image is an RGB image with identical channels, so summing
            # the RGB filters over the channel axis gives the same response.
            with torch.no_grad():
                conv.weight.copy_(old_conv.weight.sum(1, keepdim=True))
        model.conv1 = conv

    # Width of the pooled features: 512 for resnet18/34, 2048 for resnet50+.
    in_features = model.fc.in_features
    if svm_centers > 0:
        model.fc = SVM(in_features, svm_centers, num_classes)
    elif num_classes != 1000:
        model.fc = nn.Linear(in_features, num_classes)
        nn.init.normal_(model.fc.weight, mean=0.0, std=0.01)
        nn.init.zeros_(model.fc.bias)
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment