Skip to content

Instantly share code, notes, and snippets.

@crcrpar
Created September 14, 2017 02:23
Show Gist options
  • Star 9 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save crcrpar/a5d46738ffff08fc12138a5f270db426 to your computer and use it in GitHub Desktop.
Save crcrpar/a5d46738ffff08fc12138a5f270db426 to your computer and use it in GitHub Desktop.
[PyTorch] pre-trained VGG16 for perceptual loss. e.g. Style Transfer
"""Modified VGG16 to compute perceptual loss.
This class is mostly copied from pytorch/examples.
See fast_neural_style in https://github.com/pytorch/examples.
"""
import torch
from torchvision import models
class VGG_OUTPUT(object):
    """Plain container for the four values returned by ``VGG16.forward``.

    Attributes are named after the constructor arguments and hold them
    unchanged: ``relu1_2``, ``relu2_2``, ``relu3_3`` and ``relu4_3``.
    """

    def __init__(self, relu1_2, relu2_2, relu3_3, relu4_3):
        """Store the four given values as same-named attributes.

        Args:
            relu1_2: Value exposed as ``self.relu1_2``.
            relu2_2: Value exposed as ``self.relu2_2``.
            relu3_3: Value exposed as ``self.relu3_3``.
            relu4_3: Value exposed as ``self.relu4_3``.
        """
        # Assign explicitly instead of `self.__dict__ = locals()`: locals()
        # also contains `self`, which would create a bogus `self` attribute
        # and a reference cycle that keeps the stored values alive until the
        # cyclic garbage collector runs.
        self.relu1_2 = relu1_2
        self.relu2_2 = relu2_2
        self.relu3_3 = relu3_3
        self.relu4_3 = relu4_3
class VGG16(torch.nn.Module):
    """Pre-trained VGG16 feature extractor exposing four intermediate outputs.

    The ``features`` module of torchvision's pre-trained VGG16 is split into
    four consecutive ``torch.nn.Sequential`` stages covering layer indices
    0-3, 4-8, 9-15 and 16-22. Each layer is registered under its original
    index (as a string), and ``forward`` returns the output of every stage.
    """

    def __init__(self, requires_grad=False):
        """Build the four stages from the pre-trained network.

        Args:
            requires_grad: When falsy (the default), ``requires_grad`` is set
                to ``False`` on every parameter of this module.
        """
        super(VGG16, self).__init__()
        pretrained_layers = models.vgg16(pretrained=True).features

        # Create all four containers up front so they are registered on the
        # module in slice1..slice4 order before any layer is attached.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()

        # (stage, first layer index, one-past-last layer index)
        stage_bounds = (
            (self.slice1, 0, 4),
            (self.slice2, 4, 9),
            (self.slice3, 9, 16),
            (self.slice4, 16, 23),
        )
        for stage, start, stop in stage_bounds:
            for index in range(start, stop):
                stage.add_module(str(index), pretrained_layers[index])

        if requires_grad:
            return
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the four stages in sequence.

        Args:
            X: Input passed to the first stage.

        Returns:
            A ``VGG_OUTPUT`` whose ``relu1_2``, ``relu2_2``, ``relu3_3`` and
            ``relu4_3`` attributes hold the outputs of stages 1 to 4.
        """
        relu1_2 = self.slice1(X)
        relu2_2 = self.slice2(relu1_2)
        relu3_3 = self.slice3(relu2_2)
        relu4_3 = self.slice4(relu3_3)
        return VGG_OUTPUT(relu1_2, relu2_2, relu3_3, relu4_3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment