@alper111
Last active April 10, 2024 02:21
PyTorch implementation of VGG perceptual loss
import torch
import torchvision


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
alper111 commented Feb 21, 2021

Hello everyone. I have added an optional gram matrix computation to cover Equation 4 in Johnson et al. 2016, "Perceptual Losses for Real-Time Style Transfer and Super-Resolution". The previous version only computed Equation 2 (i.e. the feature reconstruction loss).

For the previous version:

vgg = VGGPerceptualLoss()
vgg(img1, img2, feature_layers=[0, 1, 2, 3], style_layers=[])
# or just vgg(img1, img2)

For Johnson et al. 2016:

vgg = VGGPerceptualLoss()
vgg(img1, img2, feature_layers=[2], style_layers=[0, 1, 2, 3])

@MohitLamba94

@alper111 @MohitLamba94 Parameters are used for trainable tensors, for the tensors that need to stay constant register_buffer is preferred. Something like self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))

Yes, I think this is more sensible. Thank you for pointing it out.

@MohitLamba94

@alper111 @MohitLamba94 Parameters are used for trainable tensors, for the tensors that need to stay constant register_buffer is preferred. Something like self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))

Yes, you are correct. But functionally the author's code does not seem to be wrong. This is because this class, VGGPerceptualLoss, will not be part of the optimizer in a training setup, so mean and std will remain the same after backpropagation. For this case the author's solution and your modification are equivalent. But certainly, it would be good to code it the way you have suggested.
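For illustration, a minimal sketch of the difference (the module name below is made up, not from the gist): a buffer is saved in the state_dict and moved by .to(device), but it is not returned by parameters(), so an optimizer never updates it.

    import torch

    class Normalizer(torch.nn.Module):  # hypothetical example module
        def __init__(self):
            super().__init__()
            # constant tensor registered as a buffer: not trainable, but still
            # part of the module's state and moved along with .to(device)
            self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))

    m = Normalizer()
    print(list(m.parameters()))      # [] -> an optimizer would never touch "mean"
    print("mean" in m.state_dict())  # True -> it is still saved and loaded with the model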

@machineko

Hey, you should change

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):

to

    def forward(self, input, target, feature_layers=(0, 1, 2, 3), style_layers=()):

or

    def forward(self, input, target, feature_layers=(0, 1, 2, 3), style_layers=None):

because the default list is created once, when the function is defined, and the same list object is reused on every call.
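A minimal sketch of the pitfall being described (the function below is hypothetical, not from the gist). Note that it only bites if the default list is mutated, which the loss above never does, but tuples or None avoid the surprise entirely:

    def collect(item, bucket=[]):  # one list object is shared by every call
        bucket.append(item)
        return bucket

    print(collect(1))  # [1]
    print(collect(2))  # [1, 2]  <- the "empty" default still holds the previous item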


@alper111 (Author)

@alper111 @MohitLamba94 Parameters are used for trainable tensors, for the tensors that need to stay constant register_buffer is preferred. Something like self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))

Thank you @bobiblazeski for pointing this out. I have changed it.

@tobias-kirschstein

This line is completely wrong:

for bl in blocks:
    for p in bl:
        p.requires_grad = False

You are introducing a requires_grad attribute on each module instead of on the actual parameters, which does nothing.
See the fix of @brucemuller above:
https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#gistcomment-3347450
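For reference, a minimal sketch of the difference (the current version of the gist already uses the second form):

    # Wrong: iterating over an nn.Sequential yields its child modules, so this just
    # sets an unused attribute on each module object; the weights stay trainable.
    for bl in blocks:
        for p in bl:
            p.requires_grad = False

    # Right: iterate over the parameters themselves and freeze them.
    for bl in blocks:
        for p in bl.parameters():
            p.requires_grad = False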

Pixie8888 commented Aug 19, 2021

Hi, do you need to add "with torch.no_grad()" before computing vgg feature? I think it can reduce memory usage.
I use your code to compute perceptual loss. The training time is much slower and batch size is much smaller compared to training without perceptual loss. @alper111

@MohitLamba94

Hi, do you need to add "with torch.no_grad()" before computing vgg feature? I think it can reduce memory usage.
I use your code to compute perceptual loss. The training time is much slower and batch size is much smaller compared to training without perceptual loss. @alper111

Hi,
I think doing this would be a big blunder. My understanding of with torch.no_grad() is that it completely switches off the autograd mechanism. If this is true and it is used in the forward pass of the VGG perceptual loss, what are you computing the loss for? The purpose of computing a loss is to get gradients to update the model parameters. @alper111 any comments?


alper111 commented Sep 28, 2021

This line is completely wrong:

for bl in blocks:
    for p in bl:
        p.requires_grad = False

You are introducing a requires_grad attribute on each module instead of on the actual parameters, which does nothing. See the fix of @brucemuller above: https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#gistcomment-3347450

Yes, you are right. I somehow missed this one, thanks for pointing it out. Sorry for fixing it a bit late.


alper111 commented Sep 28, 2021

Hi, do you need to add "with torch.no_grad()" before computing vgg feature? I think it can reduce memory usage.
I use your code to compute perceptual loss. The training time is much slower and batch size is much smaller compared to training without perceptual loss. @alper111

Hi, I think doing this would be a big blunder. My understanding of with torch.no_grad() is that it completely switches off the autograd mechanism. If this is true and it is used in the forward pass of the VGG perceptual loss, what are you computing the loss for? The purpose of computing a loss is to get gradients to update the model parameters. @alper111 any comments?

If you use with torch.no_grad(), then you disallow any possible back-propagation from the perceptual loss. For that reason, I only disabled the gradient computation for the VGG parameters (and actually fixed a blunder, thanks @brucemuller and @tobias-kirschstein for pointing it out). Of course, you can enable the gradient computation for the VGG parameters for your specific application, if necessary.
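A minimal sketch of that distinction, using the VGGPerceptualLoss defined above (tensor shapes are made up for illustration): with only the VGG parameters frozen, the loss still back-propagates to the prediction, whereas wrapping the call in torch.no_grad() builds no graph at all.

    import torch

    vgg_loss = VGGPerceptualLoss()        # VGG parameters inside are frozen
    pred = torch.rand(1, 3, 64, 64, requires_grad=True)   # e.g. a generator output
    target = torch.rand(1, 3, 64, 64)

    loss = vgg_loss(pred, target)
    loss.backward()
    print(pred.grad is not None)          # True: gradients reach the prediction

    with torch.no_grad():                 # this would break training:
        loss = vgg_loss(pred, target)
    print(loss.requires_grad)             # False: no graph, nothing to backward() through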

@israrbacha

Hi, can we append all the required feature layers in one line, like block.append(vgg.features[4:23])? What is the reason for appending them in chunks, i.e. features[:4], [4:9], [9:16], ...?

@alper111 (Author)

I wanted to extract features from those specific blocks to calculate the perceptual loss, so I appended them in chunks. We could also append them in one line as you suggested, but then in the forward loop, if you want to get the activations from those layers (4, 9, 16, ...), you would need to slice that single block inside the loop with an if statement and so on. It depends on what you want to do, I guess.
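For completeness, a minimal sketch of the single-block alternative (variable names here are illustrative, not from the gist): keep one features[:23] slice and collect the activations at the indices where the original chunks end.

    import torch
    import torchvision

    vgg = torchvision.models.vgg16(pretrained=True).features[:23].eval()
    tap_indices = {3, 8, 15, 22}  # last layer of each original chunk: [:4], [4:9], [9:16], [16:23]

    def extract_feats(x):
        feats = []
        for i, layer in enumerate(vgg):
            x = layer(x)
            if i in tap_indices:
                feats.append(x)
        return feats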

@aegonwolf

Hi there, thanks so much for your implementation, it's really clean and easy to understand and I was able to implement it well in my project.
This is a really long shot, would you know what type of features these blocks contain? I.e. which are shapes and which are colors/style filters?
I think the first one captures shapes, which I figured out by experimentation; with the others it's not so clear.

@alper111 (Author)

Hi there, I am happy that it is useful for your project. Well, I am not sure if these blocks necessarily specialize in colors/style etc., but people think so based on experimentation. You can find more information and experiments about those layers in https://arxiv.org/abs/1603.08155. In short, they think that earlier layers of VGG-16 contain style, and layers toward the end contain the content (see Eq. 5 in the paper). Though, I don't know if specific channels/layers contain more specific information such as colors, lines, and so on.

@woolee98

Hi, very nice work. I have a naive question: in the four blocks.append(...) lines, what is the meaning of features[:4], [4:9], [9:16], [16:23]? Does that mean there are 24 features in total? Thanks.


alper111 commented Dec 19, 2021

Hi @woolee98. features contains the layers of the VGG network (maybe an unfortunate naming on my part). features[:4], features[4:9], ... merely correspond to different blocks of layers of the VGG network. These are the specific blocks of layers that are used in https://arxiv.org/abs/1603.08155 for style and content transfer. You can check Fig. 2 in that paper; that should make it clear.

By the way, although these slices cover 23 "pytorch layers" (indices 0 through 22), many of them are just ReLU activations or pooling layers. VGG-16 has 16 weight layers in total, 13 convolutional and 3 fully connected, which is where the name comes from.
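If it helps, the layer indices can be inspected directly (standard torchvision usage, nothing specific to this gist):

    import torchvision

    features = torchvision.models.vgg16().features  # weights not needed just to inspect
    for i, layer in enumerate(features[:23]):
        print(i, layer.__class__.__name__)
    # 0 Conv2d, 1 ReLU, 2 Conv2d, 3 ReLU, 4 MaxPool2d, 5 Conv2d, ...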

@zhengwjie

Hello @alper111, I am using your perceptual loss when training a model. My code and model run on the GPU, but your loss seems to be written for the CPU; I am wondering what modification I should make to use it with my model on the GPU.

@alper111 (Author)

Hi @zhengwjie. Since this is implemented as a torch.nn.Module, you can initialize the loss module and move it to the corresponding gpu:

vgg_loss = VGGPerceptualLoss()
vgg_loss.to("cuda:0")  # or cuda:1, cuda:2 ...

I haven't tried this at the moment, but it should work, because I was using this module to train a model on a GPU.

@sheyining

Thanks for your work. In the original paper (https://arxiv.org/abs/1603.08155), they used the l2 loss for the "Feature Reconstruction Loss" and the squared Frobenius norm for the "Style Reconstruction Loss", but you are using l1_loss for both loss computations. Could you please explain why you use l1_loss? Shouldn't they be fixed?

@JacopoMangiavacchi

Thanks for your work. I've just added the capacity to weight the layers and documented usage of this loss on a style transfer scenario: https://medium.com/@JMangia/optimize-a-face-to-cartoon-style-transfer-model-trained-quickly-on-small-style-dataset-and-50594126e792

@siarheidevel

One question:

gram_x = act_x @ act_x.permute(0, 2, 1)
gram_y = act_y @ act_y.permute(0, 2, 1)

Maybe you need to normalize the gram matrices by dividing by the number of elements:

b,c,h,w = x.shape
gram_x = act_x @ act_x.permute(0, 2, 1) / (c*h*w)
gram_y = act_y @ act_y.permute(0, 2, 1) / (c*h*w)

alex-vasilchenko-md commented Jun 1, 2022

Thanks for your nice implementation!

I refactored it a little bit while I was reviewing how it works:

https://gist.github.com/alex-vasilchenko-md/dc5155f96f73fc4f67afffcb74f635e0

@alex-vasilchenko-md

Hi @zhengwjie. Since this is implemented as a torch.nn.Module, you can initialize the loss module and move it to the corresponding gpu:

vgg_loss = VGGPerceptualLoss()
vgg_loss.to("cuda:0")  # or cuda:1, cuda:2 ...

I haven't tried this at the moment, but it should work, because I was using this module to train a model on a GPU.

It worked for me when I trained my model on a GPU. Without this I had an issue like this:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@alper111 (Author)

Thanks for your work. In the original paper (https://arxiv.org/abs/1603.08155), they used the l2 loss for the "Feature Reconstruction Loss" and the squared Frobenius norm for the "Style Reconstruction Loss", but you are using l1_loss for both loss computations. Could you please explain why you use l1_loss? Shouldn't they be fixed?

Thanks for the interest @sheyining. On my specific application, L1 was working better. Other than that, I have no specific motivation to choose L1 over L2.

Thanks for your work. I've just added the capacity to weight the layers and documented usage of this loss on a style transfer scenario: https://medium.com/@JMangia/optimize-a-face-to-cartoon-style-transfer-model-trained-quickly-on-small-style-dataset-and-50594126e792

Thanks for the interest. A good blog post!

One question:

gram_x = act_x @ act_x.permute(0, 2, 1)
gram_y = act_y @ act_y.permute(0, 2, 1)

Maybe you need to normalize the gram matrices by dividing by the number of elements:

b,c,h,w = x.shape
gram_x = act_x @ act_x.permute(0, 2, 1) / (c*h*w)
gram_y = act_y @ act_y.permute(0, 2, 1) / (c*h*w)

@siarheidevel Indeed, we can normalize them. Again, on my specific application, it was better not to normalize it.
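For readers who do want the paper's exact formulation together with the normalization suggested above, a minimal sketch of how one block's contribution could be computed (block_loss is a hypothetical helper, not the author's code; x and y are the activations of input and target at one block, and the constants only match the paper up to a scale factor):

    import torch

    def block_loss(x, y, use_feature=True, use_style=False):
        # Sketch: squared-L2 losses as in Johnson et al. 2016, with normalized Gram matrices.
        loss = 0.0
        if use_feature:
            loss = loss + torch.nn.functional.mse_loss(x, y)            # feature reconstruction (Eq. 2)
        if use_style:
            b, c, h, w = x.shape
            act_x = x.reshape(b, c, h * w)
            act_y = y.reshape(b, c, h * w)
            gram_x = act_x @ act_x.permute(0, 2, 1) / (c * h * w)       # normalized Gram matrix
            gram_y = act_y @ act_y.permute(0, 2, 1) / (c * h * w)
            loss = loss + torch.nn.functional.mse_loss(gram_x, gram_y)  # style reconstruction (Eq. 4)
        return loss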

uroojz commented Jan 4, 2023

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)

Hi there,
is it necessary to check this condition if we have a grayscale image, i.e. 1 channel? I am a beginner, so I got a little bit confused when implementing it on grayscale images.


alper111 commented Jan 4, 2023

Since VGG expects a 3-channel input, it is necessary to extend the grayscale image to three channels.
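A minimal sketch of what that branch does for a grayscale batch (shapes here are just an example):

    import torch

    gray = torch.rand(8, 1, 28, 28)   # e.g. a batch of MNIST-sized images
    rgb = gray.repeat(1, 3, 1, 1)     # copy the single channel into R, G and B
    print(rgb.shape)                  # torch.Size([8, 3, 28, 28])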

uroojz commented Jan 4, 2023

When I use this code on my problem, which has grayscale images (MNIST), it gives around 95 percent loss. How should I handle this? Can I share my code here? I am reconstructing the MNIST images using an autoencoder and want to use VGGPerceptualLoss.

uroojz commented Jan 5, 2023

Hi there,
I want to apply this loss function for image reconstruction using an autoencoder on the MNIST dataset. When I use this loss function for that particular task it gives me totally blurred images, but when I train without the perceptual loss I get clear reconstructions. Can anybody help me here? I want to apply the perceptual loss and get good results in this project.
(attached: blurred output image)

@cdalinghaus

My grayscale image data had no explicit color channel, so I've added a small check for that:

# Input is greyscale and of shape (batch, x, y) instead of (batch, 1, x, y)
# Add a color dimension
if len(input.shape) == 3:
    input = input.unsqueeze(1)
    target = target.unsqueeze(1)

Also, to remove the deprecation warning:

blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[:4].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[4:9].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[9:16].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[16:23].eval())

@chiehwangs

Very useful tool! I am very confused. When I use

blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())

the GPU memory usage rises sharply in the middle of training; it suddenly increases by about 7 GB!

But when I use

blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, ).features[:4].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, ).features[4:9].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, ).features[9:16].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, ).features[16:23].eval())

this problem does not occur?

That's weird, can someone tell me why?
