Skip to content

Instantly share code, notes, and snippets.

@kingsj0405
Forked from alper111/vgg_perceptual_loss.py
Last active February 3, 2021 05:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kingsj0405/9cfd2545ac4ffa54a791b1b55c25488b to your computer and use it in GitHub Desktop.
Save kingsj0405/9cfd2545ac4ffa54a791b1b55c25488b to your computer and use it in GitHub Desktop.
PyTorch VGG16 feature extractor (adapted from a VGG perceptual loss implementation)
# Forked from vgg perceptual loss
# https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49
import torch
import torchvision
class VGGFeatureExtractor(torch.nn.Module):
    """Frozen VGG16 feature extractor.

    The first 23 layers of a VGG16 ``features`` stack are split into four
    sequential blocks (layers 0-3, 4-8, 9-15 and 16-22) that are run one
    after another; only the output of the last block (the ReLU after the
    third 512-channel conv, commonly called ``relu4_3``) is returned.

    All VGG weights are frozen (``requires_grad=False``), but gradients still
    flow back to the input, so the module can be used inside a perceptual
    loss.
    """

    # Half-open index ranges into the VGG16 ``features`` Sequential, one per
    # block. The slices keep the original layer keys ("4", "9", ...).
    _BLOCK_SLICES = ((0, 4), (4, 9), (9, 16), (16, 23))

    def __init__(self, resize=True, features=None):
        """Build the extractor.

        Args:
            resize: If True, ``forward`` bilinearly resizes the normalized
                input to 224x224 before running the VGG blocks.
            features: Optional pre-built VGG16 ``features`` Sequential with at
                least 23 layers (e.g. an already-loaded network, for offline
                use). When None (the default), torchvision's pretrained VGG16
                is loaded once. The given modules' parameters are frozen in
                place.
        """
        super().__init__()
        if features is None:
            # Load the pretrained network once and slice it, rather than
            # constructing (and loading weights for) a fresh VGG16 per block.
            features = torchvision.models.vgg16(pretrained=True).features
        blocks = [features[start:stop].eval() for start, stop in self._BLOCK_SLICES]
        for block in blocks:
            # Iterate *parameters*, not submodules: assigning ``requires_grad``
            # on a Module object is a plain attribute set and freezes nothing.
            for param in block.parameters():
                param.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # Per-channel normalization constants (torchvision's standard ImageNet
        # mean/std). Registered as buffers so they follow ``.to()`` and stay
        # in ``state_dict`` under the same keys, but are not trainable
        # parameters returned by ``.parameters()``.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )
        self.resize = resize

    def forward(self, x):
        """Return the VGG16 features of ``x``.

        Args:
            x: Image tensor, presumably ``(N, C, H, W)`` with RGB values in
                ``[0, 1]`` (the range the normalization constants are meant
                for) -- confirm against callers. If dimension 1 is not 3, the
                tensor is repeated 3 times along it (intended for
                single-channel input).

        Returns:
            Output of the last block; for a 224x224 input (always the case
            when ``resize`` is True) this is ``(N, 512, 28, 28)``.
        """
        if x.shape[1] != 3:
            x = x.repeat(1, 3, 1, 1)
        x = (x - self.mean) / self.std
        if self.resize:
            x = self.transform(x, mode='bilinear', size=(224, 224), align_corners=False)
        for block in self.blocks:
            x = block(x)
        return x
if __name__ == '__main__':
    # Demo on one random RGB image in explicit NCHW layout. (The original
    # passed a bare 2-D tensor, which only worked because ``forward`` read its
    # width as the channel count and ``Tensor.repeat`` prepended a batch dim.)
    # The extractor resizes to 224x224 itself, so the spatial size here is
    # arbitrary.
    x = torch.randn((1, 3, 224, 224))
    feature_extractor = VGGFeatureExtractor()
    print(feature_extractor)
    with torch.no_grad():  # inference only; no autograd graph needed
        feature = feature_extractor(x)
    print(feature.shape)
# VGGFeatureExtractor(
# (blocks): ModuleList(
# (0): Sequential(
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (1): ReLU(inplace=True)
# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (3): ReLU(inplace=True)
# )
# (1): Sequential(
# (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (6): ReLU(inplace=True)
# (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (8): ReLU(inplace=True)
# )
# (2): Sequential(
# (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (11): ReLU(inplace=True)
# (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (13): ReLU(inplace=True)
# (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (15): ReLU(inplace=True)
# )
# (3): Sequential(
# (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (18): ReLU(inplace=True)
# (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (20): ReLU(inplace=True)
# (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (22): ReLU(inplace=True)
# )
# )
# )
# torch.Size([1, 512, 28, 28])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment