@crowsonkb
Last active May 22, 2023 20:12
"""Vectorial TV loss using higher accuracy order finite difference operators."""
import torch
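
# Forward finite difference coefficients for the first derivative at accuracy
# orders 1-6. The order-k stencil uses k + 1 samples; e.g. order 1 approximates
# f'(x) as f(x + 1) - f(x), and order 2 as -3/2 f(x) + 2 f(x + 1) - 1/2 f(x + 2).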
FINITE_DIFFERENCE_COEFFS = {
    1: torch.tensor([-1, 1]),
    2: torch.tensor([-3 / 2, 2, -1 / 2]),
    3: torch.tensor([-11 / 6, 3, -3 / 2, 1 / 3]),
    4: torch.tensor([-25 / 12, 4, -3, 4 / 3, -1 / 4]),
    5: torch.tensor([-137 / 60, 5, -5, 10 / 3, -5 / 4, 1 / 5]),
    6: torch.tensor([-49 / 20, 6, -15 / 2, 20 / 3, -15 / 4, 6 / 5, -1 / 6]),
}


def tv_loss(x, order=1, beta=1.0, eps=1e-8, reduction="sum"):
    """Vectorial TV loss using higher accuracy order finite difference operators.

    Total variation loss is the sum of the magnitudes of the spatial first derivatives
    of the image, evaluated at each pixel. The first derivatives are computed using
    finite difference operators. This routine supports forward differences of order 1
    to 6. The bottom and right boundaries are treated by reducing the accuracy order of
    the stencil to the maximum possible at that point.

    Total variation loss is from "Nonlinear total variation based noise removal
    algorithms", Rudin et al. (1992), and the vectorial (color) variant used here is
    from "Color TV: total variation methods for restoration of vector-valued images",
    Blomgren et al. (1998). The "beta" parameter (exponent) is from "Understanding Deep
    Image Representations by Inverting Them", Mahendran et al. (2014).

    Args:
        x: Input tensor of shape (N, C, H, W).
        order: Order of accuracy of the finite difference operator. Must be an integer
            between 1 and 6.
        beta: Exponent of the loss function. 1.0 is the standard TV loss. 2.0 is an
            L2 variant of it that is often wrongly called "TV loss" in machine learning
            code.
        eps: Small constant to avoid NaN gradients. Set it to 1e-5 if you are using
            fp16 outside of autocast.
        reduction: Type of reduction to apply to the loss. Can be "sum", "mean", or
            "none". "none" returns a tensor of shape (N, H, W) with the loss for each
            pixel.

    Returns:
        The TV loss or losses.
    """
    n, c, h, w = x.shape

    # Stack the order-1 through order-`order` forward difference stencils into a
    # single (order, order + 1) kernel, padding shorter stencils with zeros.
    kernel = torch.zeros([order, order + 1])
    for i in range(order):
        kernel[i, : i + 2] = FINITE_DIFFERENCE_COEFFS[i + 1]
    kernel = kernel.to(x)

    # Grouped (per-channel) convolution weights: kx differentiates along width,
    # ky along height.
    kx = torch.tile(kernel[:, None, None, :], (c, 1, 1, 1))
    ky = kx.transpose(2, 3)

    # Zero-pad on the right/bottom so every pixel gets `order` candidate derivatives,
    # one per stencil.
    x_for_dx = torch.nn.functional.pad(x, (0, order, 0, 0))
    x_for_dy = torch.nn.functional.pad(x, (0, 0, 0, order))
    dx = torch.nn.functional.conv2d(x_for_dx, kx, groups=c)
    dy = torch.nn.functional.conv2d(x_for_dy, ky, groups=c)
    dx = dx.reshape((n, c, order, h, -1))
    dy = dy.reshape((n, c, order, -1, w))

    # Selection matrices pick, per column/row, the highest accuracy order stencil that
    # does not run past the boundary; no stencil fits at the very last column/row, so
    # the derivative there is zero.
    select_x = x.new_zeros([order, w])
    select_y = x.new_zeros([order, h])
    for i in range(order - 1):
        select_x[i, w - i - 2] = 1
        select_y[i, h - i - 2] = 1
    select_x[order - 1, : w - order] = 1
    select_y[order - 1, : h - order] = 1
    dx = torch.einsum("ncohw,ow->nchw", dx, select_x)
    dy = torch.einsum("ncohw,oh->nchw", dy, select_y)

    # Vectorial (color) TV: the gradient magnitude is taken jointly over channels.
    sq_norms = torch.sum(dx**2 + dy**2, dim=1)
    losses = torch.pow(sq_norms + eps, beta / 2) - eps ** (beta / 2)
    if reduction == "sum":
        return torch.sum(losses)
    elif reduction == "mean":
        return torch.mean(losses)
    elif reduction == "none":
        return losses
    else:
        raise ValueError("Unknown reduction type.")