Skip to content

Instantly share code, notes, and snippets.

@fzimmermann89
Last active April 2, 2024 11:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fzimmermann89/3e7275df0ae15349f01137eaf116c054 to your computer and use it in GitHub Desktop.
Save fzimmermann89/3e7275df0ae15349f01137eaf116c054 to your computer and use it in GitHub Desktop.
Adjoint linear operator for torch.nn.functional.grid_sample
class AdjointGridSample(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
y: torch.Tensor,
grid: torch.Tensor,
xshape: Sequence[int],
interpolation_mode: Literal['bilinear', 'nearest', 'bicubic'] = 'bilinear',
padding_mode: Literal['zeros', 'border', 'reflection'] = 'zeros',
align_corners: bool = True,
) -> torch.Tensor:
match interpolation_mode:
case 'bilinear':
mode_enum = 0
case 'nearest':
mode_enum = 1
case 'bicubic':
mode_enum = 2
case _:
raise ValueError(f'Interpolation mode {interpolation_mode} not supported')
match padding_mode:
case 'zeros':
padding_mode_enum = 0
case 'border':
padding_mode_enum = 1
case 'reflection':
padding_mode_enum = 2
case _:
raise ValueError(f'Padding mode {padding_mode} not supported')
match dim := grid.shape[-1]:
case 3:
backward_2d_or_3d = torch.ops.aten.grid_sampler_3d_backward
case 2:
backward_2d_or_3d = torch.ops.aten.grid_sampler_2d_backward
case _:
raise ValueError(f'only 2d and 3d supported, not {dim}')
if y.shape[0] != grid.shape[0]:
raise ValueError(f'y and grid must have same batch size, got {y.shape=}, {grid.shape=}')
if xshape[1] != y.shape[1]:
raise ValueError(f'xshape and y must have same number of channels, got {xshape[1]} and {y.shape[1]}.')
if len(xshape) - 2 != dim:
raise ValueError(f'len(xshape) and dim must either both bei 2 or 3, got {len(xshape)} and {dim}')
# These are required in the backward
ctx.xshape = xshape # type: ignore[attr-defined]
ctx.interpolation_mode = mode_enum # type: ignore[attr-defined]
ctx.padding_mode = padding_mode_enum # type: ignore[attr-defined]
ctx.align_corners = align_corners # type: ignore[attr-defined]
ctx.backward_2d_or_3d = backward_2d_or_3d # type: ignore[attr-defined]
if grid.requires_grad:
# only if we need to calculate the gradient for grid we need y
ctx.save_for_backward(grid, y)
else:
ctx.save_for_backward(grid)
shape_dummy = torch.empty(1, dtype=y.dtype, device=y.device).broadcast_to(xshape)
x = backward_2d_or_3d(
y,
shape_dummy,
grid,
interpolation_mode=mode_enum,
padding_mode=padding_mode_enum,
align_corners=align_corners,
output_mask=[True, False],
)[0]
return x
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx, *grad_output: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None, None, None, None, None]:
"""Backward of the Adjoint Gridsample Operator."""
need_y_grad, need_grid_grad, *_ = ctx.needs_input_grad # type: ignore[attr-defined]
grid = ctx.saved_tensors[0] # type: ignore[attr-defined]
if need_y_grad:
grad_y = torch.grid_sampler(
grad_output[0],
grid,
ctx.interpolation_mode, # type: ignore[attr-defined]
ctx.padding_mode, # type: ignore[attr-defined]
ctx.align_corners, # type: ignore[attr-defined]
)
else:
grad_y = None
if need_grid_grad:
y = ctx.saved_tensors[1] # type: ignore[attr-defined]
grad_grid = ctx.backward_2d_or_3d( # type: ignore[attr-defined]
y,
grad_output[0],
grid,
interpolation_mode=ctx.interpolation_mode, # type: ignore[attr-defined]
padding_mode=ctx.padding_mode, # type: ignore[attr-defined]
align_corners=ctx.align_corners, # type: ignore[attr-defined]
output_mask=[False, True],
)[1]
else:
grad_grid = None
return grad_y, grad_grid, None, None, None, None
def adjoint_grid_sample(
    y: torch.Tensor,
    grid: torch.Tensor,
    xshape: Sequence[int],
    interpolation_mode: Literal["bilinear", "nearest", "bicubic"] = "bilinear",
    padding_mode: Literal["zeros", "border", "reflection"] = "zeros",
    align_corners: bool = True,
) -> torch.Tensor:
    """Adjoint of the linear operator x -> grid_sample(x, grid).

    Differentiable with respect to both ``y`` and ``grid``.

    Parameters
    ----------
    y
        tensor in the range of grid_sample(x, grid), including the batch and channel
        dimensions: shape (batch, channels, *spatial_out).
    grid
        sampling grid of shape (batch, *spatial_out, 2) for 2d or
        (batch, *spatial_out, 3) for 3d, i.e. (y.shape[0], *y.shape[2:], 2/3).
    xshape
        shape of the domain of grid_sample(x, grid), i.e. the full shape of x
        including batch and channel dimensions: (batch, channels, *spatial_in).
    interpolation_mode
        the kind of interpolation used
    padding_mode
        how to pad the input
    align_corners
        if True, the corner pixels of the input and output tensors are aligned,
        and thus preserving the values at those pixels

    Returns
    -------
        tensor of shape ``xshape`` with the dtype and device of ``y``

    Raises
    ------
    ValueError
        raised by AdjointGridSample for unsupported modes or inconsistent shapes
    """
    return AdjointGridSample.apply(y, grid, xshape, interpolation_mode, padding_mode, align_corners)
@fzimmermann89
Copy link
Author

The trick is knowing that the backward of grid_sample with respect to its input is a linear operator that does not depend on the values of the input; it only uses the input to determine the shape of the result (the range of the adjoint).
So we can just provide xshape and create a dummy tensor with the correct shape, which lets us use the adjoint in an iterative reconstruction.

@fzimmermann89
Copy link
Author

This should now also work with autograd!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment