@alper111
Last active April 10, 2024 02:21
PyTorch implementation of VGG perceptual loss
import torch
import torchvision


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Slice the pretrained VGG16 feature extractor into four blocks
        # (ending at relu1_2, relu2_2, relu3_3, and relu4_3 respectively).
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        # Freeze the VGG weights; only the inputs should receive gradients.
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        # ImageNet normalization constants used by the pretrained VGG16.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        # VGG expects 3-channel input; tile single-channel images across RGB.
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                # Content loss: L1 distance between feature activations.
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Style loss: L1 distance between Gram matrices of activations.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
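
For reference, here is a minimal usage sketch; the tensor shapes, batch size, and value ranges below are illustrative assumptions rather than part of the gist:

import torch

# Assumes the VGGPerceptualLoss class defined above is in scope.
loss_fn = VGGPerceptualLoss(resize=True)

# Two illustrative batches of RGB images with values in [0, 1].
prediction = torch.rand(4, 3, 128, 128, requires_grad=True)
target = torch.rand(4, 3, 128, 128)

# Content (feature) loss from all four blocks, the default behavior.
loss = loss_fn(prediction, target)
loss.backward()

# Optionally add Gram-matrix style terms from the last two blocks.
loss = loss_fn(prediction, target, feature_layers=[0, 1, 2, 3], style_layers=[2, 3])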
@uroojz commented Jan 4, 2023

def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
    if input.shape[1] != 3:
        input = input.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
Hi there,
is it necessary to check this condition if we have a grayscale image, i.e. one channel? I am a beginner, so I got a bit confused when implementing this on grayscale images.

@alper111 (Author) commented Jan 4, 2023

Since VGG expects 3-channel input, it is necessary to extend the grayscale image to three channels.
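
A tiny sketch of what that extension looks like in isolation (the tensor names here are only illustrative):

import torch

# A grayscale batch of shape (batch, 1, H, W)...
gray = torch.rand(8, 1, 28, 28)
# ...is tiled across the channel dimension into (batch, 3, H, W),
# which is the input shape the pretrained VGG16 expects.
rgb = gray.repeat(1, 3, 1, 1)
print(rgb.shape)  # torch.Size([8, 3, 28, 28])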

@uroojz commented Jan 4, 2023

When I implement this code on my problem, which has grayscale images (MNIST), it gives around 95 percent loss. How do I handle this? Can I share my code here? I am reconstructing the MNIST images using an autoencoder and want to use VGGPerceptualLoss.

@uroojz commented Jan 5, 2023

Hi there,
I want to apply this loss function for image reconstruction with an autoencoder on the MNIST dataset. When I use this loss function for that task it gives me totally blurred images, but when I train without the perceptual loss I get clear reconstructed images. Can anybody help me in this regard? I want to apply perceptual loss and get good results in this project.
blur output

@cdalinghaus commented

My grayscale image data had no explicit color channel, so I've added a small check for that:

# Input is greyscale and of shape (batch, x, y) instead of (batch, 1, x, y)
# Add a color dimension
if len(input.shape) == 3:
    input = input.unsqueeze(1)
    target = target.unsqueeze(1)

Also, to remove the deprecation warning:

blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[:4].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[4:9].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[9:16].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[16:23].eval())
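
As a side note (not from the gist itself), the same blocks could be built from a single VGG16 instance instead of constructing the model four times; a sketch of that variant, assuming the same block boundaries:

import torch
import torchvision

# Build the pretrained feature extractor once and slice it into the
# same four blocks (ending at relu1_2, relu2_2, relu3_3, relu4_3).
features = torchvision.models.vgg16(
    weights=torchvision.models.VGG16_Weights.DEFAULT
).features.eval()
blocks = torch.nn.ModuleList([
    features[:4],
    features[4:9],
    features[9:16],
    features[16:23],
])
for p in blocks.parameters():
    p.requires_grad = False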

@chiehwangs commented

Very useful tool! I am very confused. When I use

blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())

the GPU memory usage rises sharply in the middle of training, suddenly increasing by about 7 GB!

But when I use

blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, ).features[:4].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, ).features[4:9].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, ).features[9:16].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, ).features[16:23].eval())

this problem does not occur.

That's weird, can someone tell me why?
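
One way to narrow this down (a generic diagnostic, not something from the thread) is to measure peak GPU memory for one forward/backward pass under each construction and compare; a minimal sketch assuming a CUDA device and the VGGPerceptualLoss class above:

import torch

def peak_memory_mib(loss_fn, device="cuda"):
    # Reset allocator statistics, run one forward/backward pass,
    # and report the peak memory allocated during it.
    torch.cuda.reset_peak_memory_stats(device)
    x = torch.rand(8, 3, 224, 224, device=device, requires_grad=True)
    y = torch.rand(8, 3, 224, 224, device=device)
    loss_fn(x, y).backward()
    return torch.cuda.max_memory_allocated(device) / 1024 ** 2

loss_fn = VGGPerceptualLoss().to("cuda")
print(f"peak memory: {peak_memory_mib(loss_fn):.1f} MiB")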
