@alper111
Last active April 10, 2024 02:21
PyTorch implementation of VGG perceptual loss
import torch
import torchvision


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Slice the pretrained VGG16 feature extractor into four blocks
        # (up to relu1_2, relu2_2, relu3_3, relu4_3) and freeze their weights.
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        # ImageNet normalization statistics, registered as buffers so they
        # move with the module across devices.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        # VGG expects 3-channel input; repeat single-channel images.
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                # Feature reconstruction loss on this block's activations.
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Style reconstruction loss on the Gram matrices.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
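
As a quick usage illustration (not part of the gist itself; the tensors below are random placeholders):

import torch

vgg_loss = VGGPerceptualLoss()

# Placeholder batches of RGB images in [0, 1], shape (N, C, H, W).
pred = torch.rand(4, 3, 128, 128)
gt = torch.rand(4, 3, 128, 128)

# Feature (perceptual) loss over all four VGG blocks (the default).
l_feat = vgg_loss(pred, gt)

# Same, plus a style term from the Gram matrices of the last two blocks.
l_full = vgg_loss(pred, gt, feature_layers=[0, 1, 2, 3], style_layers=[2, 3])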
@alper111
Author

Hi @zhengwjie. Since this is implemented as a torch.nn.Module, you can initialize the loss module and move it to the corresponding GPU:

vgg_loss = VGGPerceptualLoss()
vgg_loss.to("cuda:0")  # or cuda:1, cuda:2 ...

I haven't tried this just now, but it should work, since I was using this module to train a model on the GPU.
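
For concreteness, a minimal sketch of how it could sit in a GPU training loop; the tiny conv model and random batches below are placeholders for a real network and DataLoader:

import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
vgg_loss = VGGPerceptualLoss().to(device)

# Placeholder model and data: a 1x1 conv network and random image batches.
model = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for _ in range(10):  # stand-in for iterating over a DataLoader
    inp = torch.rand(4, 3, 128, 128, device=device)
    tgt = torch.rand(4, 3, 128, 128, device=device)
    out = model(inp)
    loss = vgg_loss(out, tgt)  # gradients flow through the frozen VGG into `model`
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()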

@sheyining

Thanks for your work. In the original paper (https://arxiv.org/abs/1603.08155), they used L2 loss for the "Feature Reconstruction Loss" and the squared Frobenius norm for the "Style Reconstruction Loss", but you use l1_loss for both computations. Could you please explain why you use l1_loss? Shouldn't they be fixed?

@JacopoMangiavacchi

Thanks for your work. I've just added the ability to weight the layers and documented the usage of this loss in a style transfer scenario: https://medium.com/@JMangia/optimize-a-face-to-cartoon-style-transfer-model-trained-quickly-on-small-style-dataset-and-50594126e792
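
For readers who stay within the gist, a minimal sketch of one way to weight the feature layers; the WeightedVGGPerceptualLoss wrapper below is hypothetical and not part of the gist or the blog post:

import torch

class WeightedVGGPerceptualLoss(VGGPerceptualLoss):
    """Hypothetical wrapper that scales each block's feature loss separately."""
    def __init__(self, feature_weights=(1.0, 1.0, 1.0, 1.0), resize=True):
        super().__init__(resize=resize)
        self.feature_weights = feature_weights

    def forward(self, input, target):
        # Compute each block's feature loss on its own and scale it.
        # (Simple but wasteful: the VGG blocks are re-run once per weight.)
        loss = 0.0
        for i, w in enumerate(self.feature_weights):
            loss += w * super().forward(input, target, feature_layers=[i], style_layers=[])
        return loss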

@siarheidevel

One question:

gram_x = act_x @ act_x.permute(0, 2, 1)
gram_y = act_y @ act_y.permute(0, 2, 1)

Maybe you need to normalize the Gram matrices by dividing by the number of elements:

b, c, h, w = x.shape
gram_x = act_x @ act_x.permute(0, 2, 1) / (c * h * w)
gram_y = act_y @ act_y.permute(0, 2, 1) / (c * h * w)

@alex-vasilchenko-md

alex-vasilchenko-md commented Jun 1, 2022

Thanks for your nice implementation!

I refactored it a little bit while I was reviewing how it works:

https://gist.github.com/alex-vasilchenko-md/dc5155f96f73fc4f67afffcb74f635e0

@alex-vasilchenko-md

Hi @zhengwjie. Since this is implemented as a torch.nn.Module, you can initialize the loss module and move it to the corresponding GPU:

vgg_loss = VGGPerceptualLoss()
vgg_loss.to("cuda:0")  # or cuda:1, cuda:2 ...

I haven't tried this just now, but it should work, since I was using this module to train a model on the GPU.

It worked for me when I trained my model on the GPU. Without this, I had an issue like this:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@alper111
Author

Thanks for your work. In the original paper (https://arxiv.org/abs/1603.08155), they used L2 loss for the "Feature Reconstruction Loss" and the squared Frobenius norm for the "Style Reconstruction Loss", but you use l1_loss for both computations. Could you please explain why you use l1_loss? Shouldn't they be fixed?

Thanks for the interest @sheyining. On my specific application, L1 was working better. Other than that, I have no specific motivation to choose L1 over L2.
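
If you want to match the paper exactly, the two l1_loss calls in forward can be swapped for squared-error terms; a rough sketch of such replacements (the helper functions here are just for illustration, not part of the gist):

import torch

def feature_reconstruction_loss(x, y):
    # Mean squared error over the activations (the paper's squared L2
    # distance, up to the 1/(C*H*W) normalization).
    return torch.nn.functional.mse_loss(x, y)

def style_reconstruction_loss(x, y):
    # Squared Frobenius norm of the difference of the Gram matrices,
    # averaged over the batch.
    act_x = x.reshape(x.shape[0], x.shape[1], -1)
    act_y = y.reshape(y.shape[0], y.shape[1], -1)
    gram_x = act_x @ act_x.permute(0, 2, 1)
    gram_y = act_y @ act_y.permute(0, 2, 1)
    return (gram_x - gram_y).pow(2).sum(dim=(1, 2)).mean()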

Thanks for your work. I've just added the ability to weight the layers and documented the usage of this loss in a style transfer scenario: https://medium.com/@JMangia/optimize-a-face-to-cartoon-style-transfer-model-trained-quickly-on-small-style-dataset-and-50594126e792

Thanks for the interest. A good blog post!

One question:

gram_x = act_x @ act_x.permute(0, 2, 1)
gram_y = act_y @ act_y.permute(0, 2, 1)

Maybe you need to normalize the Gram matrices by dividing by the number of elements:

b, c, h, w = x.shape
gram_x = act_x @ act_x.permute(0, 2, 1) / (c * h * w)
gram_y = act_y @ act_y.permute(0, 2, 1) / (c * h * w)

@siarheidevel Indeed, we can normalize them. Again, for my specific application, it was better not to normalize them.
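
If you want both behaviours, one option is to factor the Gram computation into a small helper with an optional normalization flag; a sketch (this gram_matrix helper is not part of the gist above):

import torch

def gram_matrix(feat, normalize=False):
    # feat: (N, C, H, W) activations from one of the VGG blocks.
    n, c, h, w = feat.shape
    act = feat.reshape(n, c, -1)
    gram = act @ act.permute(0, 2, 1)
    if normalize:
        # Divide by the number of elements, as suggested above.
        gram = gram / (c * h * w)
    return gram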

@uroojz

uroojz commented Jan 4, 2023

def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
    if input.shape[1] != 3:
        input = input.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)

Hi there,
is it necessary to check this condition if we have a grayscale image, i.e. 1 channel? I am a beginner, so I got a bit confused when implementing it on grayscale images.

@alper111
Author

alper111 commented Jan 4, 2023

Since VGG expects 3-channel input, it is necessary to extend the grayscale image to three channels.
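
For example, a single-channel MNIST-style batch can be passed in directly, since the channel repeat happens inside forward; a small sketch with random placeholder tensors:

import torch

vgg_loss = VGGPerceptualLoss()

# Placeholder grayscale batch: 8 MNIST-sized images with an explicit channel dim.
recon = torch.rand(8, 1, 28, 28)   # e.g. autoencoder output in [0, 1]
target = torch.rand(8, 1, 28, 28)  # corresponding ground-truth images

# forward() repeats the single channel to three before feeding VGG.
loss = vgg_loss(recon, target)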

@uroojz

uroojz commented Jan 4, 2023

When I use this code for my problem, which has grayscale images (MNIST), it gives around 95 percent loss. How do I handle this? Can I share my code here? I am reconstructing the MNIST images with an autoencoder and want to use VGGPerceptualLoss.

@uroojz

uroojz commented Jan 5, 2023

Hi there,
I want to apply this loss function for image reconstruction with an autoencoder on the MNIST dataset. When I use this loss function for that particular task it gives me totally blurred images, but when I train without the perceptual loss I get clear reconstructed images. Can anybody help me in this regard, as I want to apply the perceptual loss and get a good result in this project?
[attached image: blurred output]

@cdalinghaus

My grayscale image data had no explicit color channel, so I've added a small check for that:

# Input is greyscale and of shape (batch, x, y) instead of (batch, 1, x, y)
# Add a color dimension
if len(input.shape) == 3:
    input = input.unsqueeze(1)
    target = target.unsqueeze(1)

Also, to remove the deprecation warning:

blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[:4].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[4:9].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[9:16].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[16:23].eval())

@chiehwangs

Very useful tool! I am very confused. When I use

blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())

the GPU memory usage rises sharply in the middle of training, suddenly increasing by about 7 GB!

But when I use

blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[:4].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[4:9].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[9:16].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[16:23].eval())

the problem does not appear.

That's weird, can someone tell me why?
