Skip to content

Instantly share code, notes, and snippets.

@rosinality
Created February 7, 2020 13:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rosinality/b190bf2f3428e614f4f1f401a5f668a3 to your computer and use it in GitHub Desktop.
Save rosinality/b190bf2f3428e614f4f1f401a5f668a3 to your computer and use it in GitHub Desktop.
Perceptual loss implementation sample
import torch
from torch import nn
from torchvision.models import vgg16, vgg16_bn, vgg19, vgg19_bn
class PerceptualLoss(nn.Module):
def __init__(self, arch, indices, weights, normalize=True, min_max=(-1, 1)):
super().__init__()
vgg = (
{'vgg16': vgg16, 'vgg16_bn': vgg16_bn, 'vgg19': vgg19, 'vgg19_bn': vgg19_bn}
.get(arch)(pretrained=True)
.features
)
for p in vgg.parameters():
p.requires_grad = False
self.slices = nn.ModuleList()
for i, j in zip([-1] + indices, indices + [None]):
if j is None:
break
self.slices.append(vgg[slice(i + 1, j + 1)])
self.loss = nn.L1Loss()
self.weights = weights
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
val_range = min_max[1] - min_max[0]
mean = mean * (val_range) + min_max[0]
std = std * val_range
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.normalize = normalize
def forward(self, input, target):
if self.normalize:
input = (input - self.mean) / self.std
target = (target - self.mean) / self.std
feat1 = []
feat2 = []
out = input
for layer in self.slices:
out = layer(out)
feat1.append(out)
out = target
for layer in self.slices:
out = layer(out)
feat2.append(out)
loss = 0
for w, f1, f2 in zip(self.weights, feat1, feat2):
loss += w * self.loss(f1, f2.detach())
return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment