"""
Created December 21, 2020 20:45

Implementation of 2D perceptual loss utilizing base losses from torch.nn.
Perceptual loss toy implementation (the source it was adapted from is not named in this copy).
"""
import os.path as osp
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
__all__ = ["PerceptualLoss2D"]
# Default layers to hook for each supported torchvision model, keyed by model name.
# Each value lists module names as reported by ``model.named_modules()``;
# PerceptualLoss2D falls back to these when no ``outputs`` argument is given.
OUTPUTS = {
    "vgg16": ["features.3", "features.8", "features.15", "features.22"],
    "vgg16_bn": ["features.5", "features.12", "features.21", "features.32"],
    "vgg19": ["features.3", "features.8", "features.17", "features.26"],
    "vgg19_bn": ["features.5", "features.12", "features.25", "features.38"],
    "resnet18": ["layer1.1", "layer2.1", "layer3.1", "layer4.1"],
    "resnet34": ["layer1.2", "layer2.3", "layer3.5", "layer4.2"],
    "resnet50": ["layer1.2", "layer2.3", "layer3.5", "layer4.2"],
    "resnet101": ["layer1.2", "layer2.3", "layer3.22", "layer4.2"],
}
class PerceptualLoss2D(nn.Module):
    """Perceptual loss computed over intermediate layers of a torchvision model.

    Returns a list of losses between input and target, one per hooked layer.

    All arguments are optional.

    Args
        model_name  (str [vgg16_bn]) pretrained model over which losses are computed
            if no default in OUTPUTS[model_name], outputs arg is required
        reduction   (str [mean]) reduction applied by the base loss
        outputs     (tuple, list) layer names to compute the loss over
            default defined in OUTPUTS[model_name]
        checkpoint  (str ["pretrained"]) path to a checkpoint file, if required;
            when it is not an existing file, torchvision pretrained weights are used
        size        (tuple (224, 224)) input is resized to this size; if None, no resizing.
            The target is always resized to the input's spatial size if they differ
        mean        (tuple (0.485, 0.456, 0.406))
        std         (tuple (0.229, 0.224, 0.225))
            images are normalized only when both mean and std are not None
        base_loss   (str [MSELoss]) in ["L1Loss", "MSELoss", "KLDivLoss", "PoissonNLLLoss",
            "HingeEmbeddingLoss", "MultiLabelSoftMarginLoss", "SmoothL1Loss"]
        debug       (bool [False]) if True, latent outputs are copied to self.debug[]

    Example
        >>> base_loss="SmoothL1Loss"
        >>> PL = an.PerceptualLoss2D(base_loss=base_loss, debug=True).cuda()
        >>> with torch.no_grad():
        >>>     losses = PL(input_cuda_tensor, target_cuda_tensor)
        >>> for l in PL.debug: # if debug True: outputs are copied to debug
        >>>     print(l.shape)
    """
    # pylint: disable = no-member
    def __init__(self, model_name="vgg16_bn", reduction: str = 'mean', outputs: list = None,
                 checkpoint: str = "pretrained", size: tuple = (224, 224),
                 mean: tuple = (0.485, 0.456, 0.406), std: tuple = (0.229, 0.224, 0.225),
                 base_loss: str = "MSELoss", debug: bool = False) -> None:
        super().__init__()
        self.debug = debug

        # Normalization constants, shaped (1, C, 1, 1) to broadcast over NCHW batches.
        # Assigning an nn.Parameter to a module attribute registers it, so no
        # explicit register_parameter call is needed.
        self.mean = None
        self.std = None
        if mean is not None:
            self.mean = nn.Parameter(torch.as_tensor(mean).reshape(1, len(mean), 1, 1))
        if std is not None:
            self.std = nn.Parameter(torch.as_tensor(std).reshape(1, len(std), 1, 1))

        self.size = size

        self.loss_fn = None
        self._set_loss(base_loss, reduction)

        if outputs is None:
            assert model_name in OUTPUTS, "only %s models configured"%str(list(OUTPUTS.keys()))
            outputs = OUTPUTS[model_name]

        # model: use torchvision pretrained weights unless a checkpoint file is supplied
        has_checkpoint = checkpoint is not None and osp.isfile(checkpoint)
        self.model = models.__dict__[model_name](pretrained=not has_checkpoint)
        if has_checkpoint:
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(checkpoint, map_location="cpu")
            if "checkpoint" in state_dict:
                state_dict = state_dict["checkpoint"]
            self.model.load_state_dict(state_dict)

        # Forward hooks append the output of every requested layer to self.latent.outputs
        self.latent = Latents()
        self.handles = []
        _modules = dict(self.model.named_modules())
        for _out in outputs:
            assert _out in _modules, "Requested invalid layer %s of %s"%(_out, str(list(_modules)))
            # keep the handle so the hook can be removed later
            self.handles.append(_modules[_out].register_forward_hook(self.latent))

    def _set_loss(self, base_loss: str, reduction: str = "mean") -> None:
        """Instantiate self.loss_fn from torch.nn by class name.

        Args
            base_loss   (str) name of a loss class in torch.nn; restricted to losses
                that accept input and target of the same type and size
            reduction   (str [mean]) passed to the loss constructor
        """
        _losses = ["L1Loss", "MSELoss", "KLDivLoss", "PoissonNLLLoss",
                   "HingeEmbeddingLoss", "MultiLabelSoftMarginLoss", "SmoothL1Loss"]
        assert base_loss in _losses, "%s not found in %s"%(base_loss, str(_losses))
        self.loss_fn = nn.__dict__[base_loss](reduction=reduction)

    def forward(self, input: Tensor, target: Tensor) -> list:
        """Return the list of base losses between input and target, one per hooked layer.

        Input and target are concatenated along the batch dimension and run through
        the model in a single pass; each hooked activation is then split back into
        its input and target halves and compared with self.loss_fn.
        """
        # Drop activations captured by any previous pass so losses do not accumulate
        self.latent.clear()

        _x, _y = input, target

        # Mean center if requested
        if self.mean is not None and self.std is not None:
            _x = _x.sub(self.mean).div(self.std)
            _y = _y.sub(self.mean).div(self.std)

        # Resize if requested
        if self.size is not None:
            _x = F.interpolate(_x, size=self.size, mode="bilinear", align_corners=False)

        # Resize target if not equal to input, so both can share one batch
        if _x.shape[-2:] != _y.shape[-2:]:
            _y = F.interpolate(_y, size=_x.shape[-2:], mode="bilinear", align_corners=False)

        # Run model on input and target together; results are collected by the hooks
        self.model(torch.cat([_x, _y], dim=0))

        _split = len(_x)
        losses = [self.loss_fn(l[:_split], l[_split:]) for l in self.latent.outputs]

        if self.debug:
            # self.debug doubles as flag and storage: it is replaced by the CPU copies
            self.debug = [l.cpu().clone().detach() for l in self.latent.outputs]

        return losses #sum(losses)
class Latents:
    """Forward-hook callable that collects the outputs of the modules it is attached to.

    An instance is passed to ``register_forward_hook``; every hooked module's output
    is appended to ``self.outputs`` in call order.
    """
    def __init__(self):
        # module outputs captured since the last clear()
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        """Record module_out; returns None so the module's output is left unchanged."""
        self.outputs.append(module_out)

    def clear(self):
        """Discard all captured outputs."""
        self.outputs = []
