@xvdp
Created December 21, 2020 20:45
Implementation of a 2D Perceptual Loss utilizing base losses from torch.nn
"""
Perceptual Loss Toy Implementation from https://arxiv.org/pdf/1603.08155.pdf
"""
import os.path as osp
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
__all__ = ["PerceptualLoss2D"]

OUTPUTS = {
    "vgg16":     ["features.3", "features.8", "features.15", "features.22"],
    "vgg16_bn":  ["features.5", "features.12", "features.21", "features.32"],
    "vgg19":     ["features.3", "features.8", "features.17", "features.26"],
    "vgg19_bn":  ["features.5", "features.12", "features.25", "features.38"],
    "resnet18":  ["layer1.1", "layer2.1", "layer3.1", "layer4.1"],
    "resnet34":  ["layer1.2", "layer2.3", "layer3.5", "layer4.2"],
    "resnet50":  ["layer1.2", "layer2.3", "layer3.5", "layer4.2"],
    "resnet101": ["layer1.2", "layer2.3", "layer3.22", "layer4.2"],
}
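# Layer names for models not listed above can be inspected with named_modules();
# a quick illustrative check (not part of the original gist) would be:
#   >>> from torchvision import models
#   >>> print([name for name, _ in models.resnet152(pretrained=False).named_modules()])
# Any of the printed names may then be passed to PerceptualLoss2D via the `outputs` argument.
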
class PerceptualLoss2D(nn.Module):
    """
    Returns a list of losses between input and target over the selected layers;
    variations on https://arxiv.org/pdf/1603.08155.pdf

    All arguments are optional:
        model_name (str ["vgg16_bn"]) pretrained model over which losses are computed;
            if the model has no default in OUTPUTS, the `outputs` arg is required
        reduction (str ["mean"]) the reduction to apply to the base_loss
        outputs (tuple, list) layer names to compute the loss over;
            defaults defined in OUTPUTS[model_name]
        checkpoint (str ["pretrained"]) path to a checkpoint, loaded if the file exists
        size (tuple (224, 224)) if None, no resizing
        mean (tuple (0.485, 0.456, 0.406)) if None, image is not mean centered
        std (tuple (0.229, 0.224, 0.225)) if None, image is not normalized by std
        base_loss (str ["MSELoss"]) in ["L1Loss", "MSELoss", "KLDivLoss", "PoissonNLLLoss",
            "HingeEmbeddingLoss", "MultiLabelSoftMarginLoss", "SmoothL1Loss"]
        debug (bool [False]) if True, latent outputs are copied to self.debug

    Example
        >>> base_loss = "SmoothL1Loss"
        >>> PL = PerceptualLoss2D(base_loss=base_loss, debug=True).cuda()
        >>> with torch.no_grad():
        >>>     losses = PL(input_cuda_tensor, target_cuda_tensor)
        >>> for l in PL.debug:  # if debug is True, outputs are copied to self.debug
        >>>     print(l.shape)
    """
    # pylint: disable = no-member
    def __init__(self, model_name="vgg16_bn", reduction: str = 'mean', outputs: list = None,
                 checkpoint: str = "pretrained", size: tuple = (224, 224),
                 mean: tuple = (0.485, 0.456, 0.406), std: tuple = (0.229, 0.224, 0.225),
                 base_loss: str = "MSELoss", debug: bool = False) -> None:
        super(PerceptualLoss2D, self).__init__()
        self.debug = debug

        # normalization constants stored as (1, C, 1, 1) parameters
        self.mean = None
        self.std = None
        if mean is not None:
            self.mean = nn.Parameter(torch.as_tensor(mean).reshape(1, len(mean), 1, 1))
            self.register_parameter(name="mean", param=self.mean)
        if std is not None:
            self.std = nn.Parameter(torch.as_tensor(std).reshape(1, len(std), 1, 1))
            self.register_parameter(name="std", param=self.std)

        self.size = size
        self.loss_fn = None
        self._set_loss(base_loss, reduction)

        if outputs is None:
            assert model_name in OUTPUTS, "only %s models configured"%str(list(OUTPUTS.keys()))
            outputs = OUTPUTS[model_name]

        # model: torchvision pretrained weights unless a checkpoint file is given
        pretrained = not osp.isfile(checkpoint)
        self.model = models.__dict__[model_name](pretrained=pretrained)
        if osp.isfile(checkpoint):
            state_dict = torch.load(checkpoint)  # , map_location=device)
            if "checkpoint" in state_dict.keys():
                state_dict = state_dict["checkpoint"]
            self.model.load_state_dict(state_dict)
        self.model.eval()

        # register forward hooks that collect activations of the requested layers
        self.latent = Latents()
        self.handles = []
        _valid_outputs = [o[0] for o in self.model.named_modules()]
        for _out in outputs:
            assert _out in _valid_outputs, "Requested invalid layer %s of %s"%(_out, str(_valid_outputs))
            handle = dict(self.model.named_modules())[_out].register_forward_hook(self.latent)
            self.handles.append(handle)

    def _set_loss(self, base_loss: str, reduction: str = "mean") -> None:
        # _losses = [l for l in nn.__dict__ if "Loss" in l and l[0].isupper() and not
        #            l.startswith("__") and callable(nn.__dict__[l])]
        # Loss functions in nn that accept the same type and size of input and target
        _losses = ["L1Loss", "MSELoss", "KLDivLoss", "PoissonNLLLoss",
                   "HingeEmbeddingLoss", "MultiLabelSoftMarginLoss", "SmoothL1Loss"]
        assert base_loss in _losses, "%s not found in %s"%(base_loss, str(_losses))
        self.loss_fn = nn.__dict__[base_loss](reduction=reduction)

    def forward(self, input: Tensor, target: Tensor) -> list:
        _x, _y = input, target
        # Mean center / normalize if requested
        if self.mean is not None and self.std is not None:
            _x = _x.sub(self.mean).div(self.std)
            _y = _y.sub(self.mean).div(self.std)
        # Resize if requested
        if self.size is not None:
            _x = F.interpolate(_x, size=self.size, mode="bilinear", align_corners=False)
        # Resize target if not equal to input
        if _x.shape[-2:] != _y.shape[-2:]:
            _y = F.interpolate(_y, size=_x.shape[-2:], mode="bilinear", align_corners=False)
        # Run model on input and target together; the hooks collect layer activations
        _out = self.model(torch.cat([_x, _y], dim=0))
        losses = [self.loss_fn(l[:len(_x)], l[len(_x):]) for l in self.latent.outputs]
        if self.debug:
            self.debug = [l.cpu().clone().detach() for l in self.latent.outputs]
        self.latent.clear()
        return losses  # sum(losses)

class Latents:
    """forward-hook callable that accumulates module outputs"""
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []
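
# A minimal usage sketch (not part of the original gist): random input/target tensors
# run through the default vgg16_bn configuration. Shapes and the SmoothL1Loss choice
# are illustrative only.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    perceptual = PerceptualLoss2D(base_loss="SmoothL1Loss", debug=True).to(device)
    x = torch.rand(2, 3, 256, 256, device=device)  # stand-in "input" batch
    y = torch.rand(2, 3, 256, 256, device=device)  # stand-in "target" batch
    with torch.no_grad():
        layer_losses = perceptual(x, y)
    print([float(l) for l in layer_losses])        # one loss per hooked layer
    for latent in perceptual.debug:                # debug=True keeps the activations
        print(latent.shape)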