"""Vectorial TV loss using higher accuracy order finite difference operators."""
import torch
# Forward finite-difference stencils for the first derivative on a unit grid,
# keyed by order of accuracy. The entry for order k holds k + 1 weights, applied
# to the sample at the current position followed by its next k neighbors.
FINITE_DIFFERENCE_COEFFS = {
    1: torch.tensor([-1, 1]),
    2: torch.tensor([-3 / 2, 2, -1 / 2]),
    3: torch.tensor([-11 / 6, 3, -3 / 2, 1 / 3]),
    4: torch.tensor([-25 / 12, 4, -3, 4 / 3, -1 / 4]),
    5: torch.tensor([-137 / 60, 5, -5, 10 / 3, -5 / 4, 1 / 5]),
    6: torch.tensor([-49 / 20, 6, -15 / 2, 20 / 3, -15 / 4, 6 / 5, -1 / 6]),
}
def tv_loss(x, order=1, beta=1.0, eps=1e-8, reduction="sum"):
    """Vectorial TV loss using higher accuracy order finite difference operators.

    Total variation loss is the sum of the magnitudes of the spatial first derivatives
    of the image, evaluated at each pixel. The first derivatives are computed using
    finite difference operators. This routine supports forward differences of order 1
    to 6. The bottom and right boundaries are treated by reducing the accuracy order of
    the stencil to the maximum possible at that point.

    Total variation loss is from "Nonlinear total variation based noise removal
    algorithms", Rudin et al (1992), and the vectorial (color) variant used here is from
    "Color TV: total variation methods for restoration of vector-valued images",
    Blomgren et al (1998). The "beta" parameter (exponent) is from "Understanding Deep
    Image Representations by Inverting Them", Mahendran et al (2014).

    Args:
        x: Input tensor of shape (N, C, H, W).
        order: Order of accuracy of the finite difference operator. Must be an integer
            between 1 and 6.
        beta: Exponent of the loss function. 1.0 is the standard TV loss. 2.0 is an
            L2 variant of it that is often wrongly called "TV loss" in machine learning.
        eps: Small constant to avoid NaN gradients. Set it to 1e-5 if you are using
            fp16 outside of autocast.
        reduction: Type of reduction to apply to the loss. Can be "sum", "mean", or
            "none". "none" returns a tensor of shape (N, H, W) with the loss for each
            pixel.

    Returns:
        The TV loss or losses.

    Raises:
        ValueError: If `order` is outside the supported range or `reduction` is not
            one of "sum", "mean", or "none".
    """
    # Reject bad arguments before doing any tensor work.
    if not 1 <= order <= max(FINITE_DIFFERENCE_COEFFS):
        raise ValueError(
            f"order must be one of {sorted(FINITE_DIFFERENCE_COEFFS)}, got {order!r}."
        )
    if reduction not in ("sum", "mean", "none"):
        raise ValueError("Unknown reduction type.")

    n, c, h, w = x.shape

    # Kernel bank: row i holds the order-(i + 1) stencil, zero-padded on the right
    # so that every row has the width of the highest-order stencil.
    kernel = torch.zeros([order, order + 1])
    for i in range(order):
        kernel[i, : i + 2] = FINITE_DIFFERENCE_COEFFS[i + 1]
    # conv2d needs the weights on the same device and in the same dtype as x.
    kernel = kernel.to(device=x.device, dtype=x.dtype)

    # One copy of the bank per channel; with groups=c each channel is filtered
    # independently by all `order` stencils. ky is the vertical counterpart.
    kx = torch.tile(kernel[:, None, None, :], (c, 1, 1, 1))
    ky = kx.transpose(2, 3)

    # Zero-pad on the right / bottom only, so the output keeps the (H, W) size.
    x_for_dx = torch.nn.functional.pad(x, (0, order, 0, 0))
    x_for_dy = torch.nn.functional.pad(x, (0, 0, 0, order))
    dx = torch.nn.functional.conv2d(x_for_dx, kx, groups=c)
    dy = torch.nn.functional.conv2d(x_for_dy, ky, groups=c)
    dx = dx.reshape((n, c, order, h, -1))
    dy = dy.reshape((n, c, order, -1, w))

    # Selection masks: for each position keep only the highest-order stencil whose
    # footprint stays inside the image. The last column / row selects nothing, so
    # its derivative is zero.
    select_x = x.new_zeros([order, w])
    select_y = x.new_zeros([order, h])
    for i in range(order - 1):
        # The order-(i + 1) stencil is the best fit exactly i + 1 samples before
        # the edge. Skip it when the image is too small to contain that position;
        # a negative index would otherwise wrap around to the opposite side.
        if w - i - 2 >= 0:
            select_x[i, w - i - 2] = 1
        if h - i - 2 >= 0:
            select_y[i, h - i - 2] = 1
    # Everywhere far enough from the edge uses the full requested order. Clamp the
    # bound so that a negative value is not read as "all but the last k".
    select_x[order - 1, : max(w - order, 0)] = 1
    select_y[order - 1, : max(h - order, 0)] = 1

    dx = torch.einsum("ncohw,ow->nchw", dx, select_x)
    dy = torch.einsum("ncohw,oh->nchw", dy, select_y)

    # Vectorial TV: squared gradient magnitude summed over channels per pixel.
    sq_norms = torch.sum(dx**2 + dy**2, dim=1)
    # Subtracting eps ** (beta / 2) makes a zero gradient contribute zero loss.
    losses = torch.pow(sq_norms + eps, beta / 2) - eps ** (beta / 2)

    if reduction == "sum":
        return torch.sum(losses)
    if reduction == "mean":
        return torch.mean(losses)
    return losses
