Created
December 21, 2020 20:45
-
-
Save xvdp/5e07af559aee0c6a53fab0ae2482ef60 to your computer and use it in GitHub Desktop.
Implementation of 2d Perceptual Loss utilizing base losses from torch.nn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Perceptual Loss toy implementation based on https://arxiv.org/pdf/1603.08155.pdf
""" | |
import os.path as osp | |
import torch | |
from torch import Tensor | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import models | |
# ``from <this module> import *`` exports only the loss class.
__all__ = ['PerceptualLoss2D']
# Default layers (names as in ``model.named_modules()``) whose activations are compared
# for each supported torchvision model.
# - VGG: with torchvision's layout (conv[, bn], relu per conv; maxpool between blocks)
#   the indices are the ReLU outputs closing blocks 1-4
#   (relu1_2, relu2_2, relu3_3 / relu3_4, relu4_3 / relu4_4).
# - ResNet: the last residual block of layer1 .. layer4.
OUTPUTS = {
    "vgg16": ["features.3", "features.8", "features.15", "features.22"],
    # features.22 is relu3_3; features.21 (used previously) is the BatchNorm2d before it.
    "vgg16_bn": ["features.5", "features.12", "features.22", "features.32"],
    "vgg19": ["features.3", "features.8", "features.17", "features.26"],
    "vgg19_bn": ["features.5", "features.12", "features.25", "features.38"],
    "resnet18": ["layer1.1", "layer2.1", "layer3.1", "layer4.1"],
    "resnet34": ["layer1.2", "layer2.3", "layer3.5", "layer4.2"],
    "resnet50": ["layer1.2", "layer2.3", "layer3.5", "layer4.2"],
    "resnet101": ["layer1.2", "layer2.3", "layer3.22", "layer4.2"],
    "resnet152": ["layer1.2", "layer2.7", "layer3.35", "layer4.2"],
}
class PerceptualLoss2D(nn.Module):
    """
    Return a list of per-layer losses between the features of ``input`` and ``target``.

    Variation on https://arxiv.org/pdf/1603.08155.pdf: both images go through a
    torchvision network in a single batch. Forward hooks capture the activations of
    the requested layers, and ``base_loss`` compares input and target activations
    layer by layer.

    All arguments are optional:
        model_name (str ["vgg16_bn"]): torchvision model over which losses are computed.
            If ``model_name`` has no entry in ``OUTPUTS``, ``outputs`` is required.
        reduction (str ["mean"]): reduction passed to ``base_loss``.
        outputs (tuple, list [None]): names (as in ``named_modules()``) of the layers to
            compare; defaults to ``OUTPUTS[model_name]``.
        checkpoint (str ["pretrained"]): path to a state dict to load into the model.
            If it is None or not an existing file, torchvision pretrained weights are used.
        size (tuple [(224, 224)]): size ``input`` is bilinearly resized to; None = no resize.
            ``target`` is always resized to ``input``'s spatial size when they differ.
        mean (tuple [(0.485, 0.456, 0.406)]): per-channel mean subtracted from both
            images; None = no mean subtraction.
        std (tuple [(0.229, 0.224, 0.225)]): per-channel std both images are divided by;
            None = no division.
        base_loss (str ["MSELoss"]): one of "L1Loss", "MSELoss", "KLDivLoss",
            "PoissonNLLLoss", "HingeEmbeddingLoss", "MultiLabelSoftMarginLoss",
            "SmoothL1Loss".
        debug (bool [False]): if truthy when forward runs, ``self.debug`` is replaced by a
            list of the captured activations (detached copies on CPU).
        freeze (bool [True]): if True, the loss network's weights get
            ``requires_grad=False``; gradients still flow back to ``input``.

    The loss network is kept in eval mode even when ``.train()`` is called on this
    module, so its batch-norm layers are not switched to batch statistics.

    Losses are returned in the order in which the hooked layers run during the
    forward pass, which is not necessarily the order of ``outputs``.

    Example
        >>> base_loss = "SmoothL1Loss"
        >>> PL = PerceptualLoss2D(base_loss=base_loss, debug=True).cuda()
        >>> with torch.no_grad():
        ...     losses = PL(input_cuda_tensor, target_cuda_tensor)
        >>> for l in PL.debug:  # activations of the last forward call
        ...     print(l.shape)
    """
    # pylint: disable = no-member
    def __init__(self, model_name="vgg16_bn", reduction: str = 'mean', outputs: list = None,
                 checkpoint: str = "pretrained", size: tuple = (224, 224),
                 mean: tuple = (0.485, 0.456, 0.406), std: tuple = (0.229, 0.224, 0.225),
                 base_loss: str = "MSELoss", debug: bool = False,
                 freeze: bool = True) -> None:
        super().__init__()
        # Debug switch; forward() replaces it with the captured activations while truthy.
        self.debug = debug
        # Buffers rather than Parameters have these properties:
        # - they follow .to()/.cuda();
        # - they keep the same state_dict keys ("mean", "std");
        # - they are not returned by .parameters(), so an optimizer never updates them.
        self.register_buffer("mean", self._as_channel_tensor(mean))
        self.register_buffer("std", self._as_channel_tensor(std))
        self.size = size
        self.loss_fn = None
        self._set_loss(base_loss, reduction)

        if outputs is None:
            assert model_name in OUTPUTS, "only %s models configured"%str(list(OUTPUTS.keys()))
            outputs = OUTPUTS[model_name]

        self.model = self._load_model(model_name, checkpoint)
        self.model.eval()
        if freeze:
            # The loss network is a fixed feature extractor; gradients still reach `input`.
            self.model.requires_grad_(False)

        self.latent = Latents()
        self.handles = []
        modules = dict(self.model.named_modules())  # built once, not per requested layer
        for name in outputs:
            assert name in modules, "Requested invalid layer %s of %s"%(name, str(list(modules)))
            self.handles.append(modules[name].register_forward_hook(self.latent))

    @staticmethod
    def _as_channel_tensor(values):
        """Return ``values`` as a (1, C, 1, 1) floating tensor, or None if ``values`` is None."""
        if values is None:
            return None
        return torch.as_tensor(values, dtype=torch.get_default_dtype()).reshape(1, -1, 1, 1)

    @staticmethod
    def _load_model(model_name: str, checkpoint) -> nn.Module:
        """Build torchvision ``model_name`` with pretrained weights or weights from ``checkpoint``.

        Weights come from ``checkpoint`` if it is an existing file. Otherwise
        torchvision's pretrained weights are used.
        """
        from_file = checkpoint is not None and osp.isfile(checkpoint)
        model = models.__dict__[model_name](pretrained=not from_file)
        if from_file:
            # The model is built on CPU; map_location lets a GPU-saved checkpoint load
            # on CPU-only machines. NOTE: torch.load unpickles - only load trusted files.
            state_dict = torch.load(checkpoint, map_location="cpu")
            if "checkpoint" in state_dict:
                state_dict = state_dict["checkpoint"]
            model.load_state_dict(state_dict)
        return model

    def train(self, mode: bool = True):
        """Set train/eval mode on this module while keeping the loss network in eval mode.

        Without this, a parent ``.train()`` call would switch the loss network's
        batch-norm layers (e.g. in the default vgg16_bn) to train mode.
        """
        super().train(mode)
        self.model.eval()
        return self

    def _set_loss(self, base_loss: str, reduction: str = "mean") -> None:
        """Set ``self.loss_fn`` to ``torch.nn.<base_loss>(reduction=reduction)``."""
        # Loss classes in torch.nn that accept input and target of the same type and size.
        _losses = ["L1Loss", "MSELoss", "KLDivLoss", "PoissonNLLLoss",
                   "HingeEmbeddingLoss", "MultiLabelSoftMarginLoss", "SmoothL1Loss"]
        assert base_loss in _losses, "%s not found in %s"%(base_loss, str(_losses))
        self.loss_fn = nn.__dict__[base_loss](reduction=reduction)

    def forward(self, input: Tensor, target: Tensor) -> list:
        """Return one ``self.loss_fn`` value per hooked layer for ``input`` vs ``target``.

        Expects 4-D (N, C, H, W) batches. Bilinear interpolation and the (1, C, 1, 1)
        mean/std tensors assume that layout, and input and target must share N.
        """
        # Mean and std are applied independently; either may be None.
        x, y = input, target
        if self.mean is not None:
            x = x.sub(self.mean)
            y = y.sub(self.mean)
        if self.std is not None:
            x = x.div(self.std)
            y = y.div(self.std)
        # Resize input if requested, then resize target to match input if needed.
        if self.size is not None:
            x = F.interpolate(x, size=self.size, mode="bilinear", align_corners=False)
        if x.shape[-2:] != y.shape[-2:]:
            y = F.interpolate(y, size=x.shape[-2:], mode="bilinear", align_corners=False)

        batch = len(x)
        # Drop activations left over from an interrupted call or a direct self.model(...) call.
        self.latent.clear()
        try:
            # Single pass over input and target together; hooks collect the activations.
            self.model(torch.cat([x, y], dim=0))
            feats = self.latent.outputs
            losses = [self.loss_fn(feat[:batch], feat[batch:]) for feat in feats]
            if self.debug:
                self.debug = [feat.detach().cpu().clone() for feat in feats]
        finally:
            self.latent.clear()
        return losses
class Latents:
    """Forward-hook callable that records each module output it receives.

    Pass an instance to ``register_forward_hook``. Outputs are appended to
    ``self.outputs`` in call order until ``clear()`` is called.
    """

    def __init__(self):
        self.clear()

    def __call__(self, module, module_in, module_out):
        # Only the output is kept; module and module_in are required by the hook signature.
        collected = self.outputs
        collected.append(module_out)

    def clear(self):
        # Rebind to a fresh list instead of emptying in place; lists obtained earlier stay intact.
        self.outputs = []
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment