"""Vectorial TV loss using higher accuracy order finite difference operators.""" | |
import torch | |
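
# Forward finite difference coefficients for the first derivative, indexed by
# accuracy order. The order-k stencil combines k + 1 samples.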
FINITE_DIFFERENCE_COEFFS = {
    1: torch.tensor([-1, 1]),
    2: torch.tensor([-3 / 2, 2, -1 / 2]),
    3: torch.tensor([-11 / 6, 3, -3 / 2, 1 / 3]),
    4: torch.tensor([-25 / 12, 4, -3, 4 / 3, -1 / 4]),
    5: torch.tensor([-137 / 60, 5, -5, 10 / 3, -5 / 4, 1 / 5]),
    6: torch.tensor([-49 / 20, 6, -15 / 2, 20 / 3, -15 / 4, 6 / 5, -1 / 6]),
}


def tv_loss(x, order=1, beta=1.0, eps=1e-8, reduction="sum"):
"""Vectorial TV loss using higher accuracy order finite difference operators. | |
Total variation loss is the sum of the magnitudes of the spatial first derivatives | |
of the image, evaluated at each pixel. The first derivatives are computed using | |
finite difference operators. This routine supports forward differences of order 1 | |
to 6. The bottom and right boundaries are treated by reducing the accuracy order of | |
the stencil to the maximum possible at that point. | |
Total variation loss is from "Nonlinear total variation based noise removal | |
algorithms", Rudin et al (1992), and the vectorial (color) variant used here is from | |
"Color TV: total variation methods for restoration of vector-valued images", | |
Blomgren et al (1998). The "beta" parameter (exponent) is from "Understanding Deep | |
Image Representations by Inverting Them", Mahendran et al (2014). | |
Args: | |
x: Input tensor of shape (N, C, H, W). | |
order: Order of accuracy of the finite difference operator. Must be an integer | |
between 1 and 6. | |
beta: Exponent of the loss function. 1.0 is the standard TV loss. 2.0 is an | |
L2 variant of it that is often wrongly called "TV loss" in machine learning | |
code. | |
eps: Small constant to avoid NaN gradients. Set it to 1e-5 if you are using | |
fp16 outside of autocast. | |
reduction: Type of reduction to apply to the loss. Can be "sum", "mean", or | |
"none". "none" returns a tensor of shape (N, H, W) with the loss for each | |
pixel. | |
Returns: | |
The TV loss or losses. | |
""" | |
    n, c, h, w = x.shape
    kernel = torch.zeros([order, order + 1])
    for i in range(order):
        kernel[i, : i + 2] = FINITE_DIFFERENCE_COEFFS[i + 1]
    kernel = kernel.to(x)
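
    # Expand the stencil bank into depthwise conv2d weights of shape
    # (C * order, 1, 1, order + 1); with groups=C, each input channel produces
    # `order` derivative estimates along the width. ky is the transpose, for height.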
    kx = torch.tile(kernel[:, None, None, :], (c, 1, 1, 1))
    ky = kx.transpose(2, 3)
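
    # Zero-pad only the right (bottom) edge so the convolutions keep the spatial
    # size, then split the (C * order) channel axis back into channel and order axes.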
    x_for_dx = torch.nn.functional.pad(x, (0, order, 0, 0))
    x_for_dy = torch.nn.functional.pad(x, (0, 0, 0, order))
    dx = torch.nn.functional.conv2d(x_for_dx, kx, groups=c)
    dy = torch.nn.functional.conv2d(x_for_dy, ky, groups=c)
    dx = dx.reshape((n, c, order, h, -1))
    dy = dy.reshape((n, c, order, -1, w))
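
    # Build 0/1 masks that pick, per column (row), the highest accuracy order whose
    # stencil still fits inside the image: the full order in the interior, then
    # progressively lower orders approaching the right (bottom) edge. The final
    # column (row) selects nothing, so its derivative is zero.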
    select_x = x.new_zeros([order, w])
    select_y = x.new_zeros([order, h])
    for i in range(order - 1):
        select_x[i, w - i - 2] = 1
        select_y[i, h - i - 2] = 1
    select_x[order - 1, : w - order] = 1
    select_y[order - 1, : h - order] = 1
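
    # Contract the order axis against the masks, keeping one derivative estimate
    # (or zero, at the far edge) per pixel.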
    dx = torch.einsum("ncohw,ow->nchw", dx, select_x)
    dy = torch.einsum("ncohw,oh->nchw", dy, select_y)
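
    # Vectorial TV: sum the squared derivatives over channels, then apply the beta
    # exponent. Subtracting eps ** (beta / 2) makes a constant image score exactly 0.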
    sq_norms = torch.sum(dx**2 + dy**2, dim=1)
    losses = torch.pow(sq_norms + eps, beta / 2) - eps ** (beta / 2)
if reduction == "sum": | |
return torch.sum(losses) | |
elif reduction == "mean": | |
return torch.mean(losses) | |
elif reduction == "none": | |
return losses | |
else: | |
raise ValueError("Unknown reduction type.") |
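

# A minimal usage sketch: a random batch, a higher accuracy order L2 variant, and
# a constant image (which should score exactly zero).
if __name__ == "__main__":
    x = torch.rand(2, 3, 64, 64)
    print(tv_loss(x))  # standard vectorial TV loss, order-1 stencil
    print(tv_loss(x, order=3, beta=2.0))  # higher accuracy order, L2 variant
    print(tv_loss(torch.ones(1, 3, 8, 8)))  # constant image -> tensor(0.)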