Last active
January 6, 2023 23:02
-
-
Save VSehwag/33db3092c4d82867ddff6e3682de26e4 to your computer and use it in GitHub Desktop.
We provide a blazing-fast implementation of bicubic resampling for PyTorch-XLA (which only supports nearest/bilinear resampling)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
A standalone PyTorch implementation for fast and efficient bicubic resampling.
Well suited for PyTorch-XLA, which doesn't support PyTorch-native bicubic resampling,
i.e., only nearest/bilinear are lowered for XLA. As of now, bicubic resampling is very
slow on TPUs with PyTorch XLA.
This implementation dramatically reduces this overhead and (almost) makes it as fast as
on a GPU.
## Hacked by: Vikash Sehwag
## Original author: Sanghyun Son
## Last update: Jan 6th, 2022 (EST)
Dependency: torch
Example::
    >>> import torch
    >>> import core
    >>> x = torch.arange(16).float().view(1, 1, 4, 4)
    >>> y = core.imresize(x, size=(3, 3))
    >>> print(y)
    tensor([[[[ 0.7506,  2.1004,  3.4503],
              [ 6.1505,  7.5000,  8.8499],
              [11.5497, 12.8996, 14.2494]]]])
'''
import math | |
import typing | |
import time | |
from tqdm import tqdm | |
import torch | |
import torch.nn.functional as F | |
from torch.nn import functional as F | |
# Public API of this module: only the top-level resizing entry point.
__all__ = ["imresize"]

# Short aliases used in the helpers' type annotations.
_I = typing.Union[int, None]  # an int that may be absent (e.g. no batch/channel axis)
_D = typing.Union[torch.dtype, None]  # a dtype that may be absent (no cast-back needed)
def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
    """Evaluate the cubic convolution kernel elementwise.

    The kernel is piecewise in ``|x|``:

    * ``(a + 2)|x|^3 - (a + 3)|x|^2 + 1``            for ``|x| <= 1``
    * ``a|x|^3 - 5a|x|^2 + 8a|x| - 4a``               for ``1 < |x| <= 2``
    * ``0``                                           otherwise

    Args:
        x: Signed distances between a sample position and each tap.
        a: Free shape parameter of the kernel (``-0.5`` by default).

    Returns:
        Kernel weights with the same shape and dtype as ``x``.
    """
    dist = x.abs()
    dist_sq = dist * dist
    dist_cu = dist * dist_sq

    # Region masks, converted to x's dtype so they can be multiplied in.
    inner = dist.le(1).to(dtype=x.dtype)
    outer = ((dist > 1) & (dist <= 2)).to(dtype=x.dtype)

    near = ((a + 2) * dist_cu - (a + 3) * dist_sq + 1) * inner
    far = ((a * dist_cu) - (5 * a * dist_sq) + (8 * a * dist) - (4 * a)) * outer
    return near + far
def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
    """Evaluate a truncated, unnormalised Gaussian kernel elementwise.

    Values with ``|x| > 3 * sigma + 1`` are zeroed. The weights are not
    normalised here; the caller is expected to normalise them afterwards.

    Args:
        x: Signed distances between a sample position and each tap.
        sigma: Standard deviation of the Gaussian.

    Returns:
        Kernel weights with the same shape and dtype as ``x``.
    """
    inside = (x.abs() <= 3 * sigma + 1).to(dtype=x.dtype)
    bell = torch.exp(-x.pow(2) / (2 * sigma**2))
    return bell * inside
def reflect_padding(x: torch.Tensor, dim: int, pad_pre: int,
                    pad_post: int) -> torch.Tensor:
    """
    Apply reflect padding to the given Tensor.
    Note that it is slightly different from the PyTorch functional.pad,
    where boundary elements are used only once.
    Instead, we follow the MATLAB implementation
    which uses boundary elements twice.
    For example,
    [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
    while our implementation yields [a, a, b, c, d, d].

    The result is gathered with a single ``index_select`` using indices
    generated on ``x``'s device. It does not copy one padded row/column at a
    time. Pads longer than the axis continue the mirror pattern periodically,
    e.g. [a, b] padded by 3 on each side gives [b, b, a, a, b, b, a, a].

    Args:
        x: 4-D tensor; its last two axes are treated as (H, W).
        dim: 2 or -2 pads the H axis; any other value pads the W axis.
        pad_pre: Number of elements to add before the first element.
        pad_post: Number of elements to add after the last element.

    Returns:
        A new tensor, padded along the selected axis.

    Raises:
        ValueError: If a pad is negative, or padding is requested on an empty axis.
    """
    _, _, h, w = x.size()
    axis = 2 if dim in (2, -2) else 3
    length = h if axis == 2 else w
    if pad_pre < 0 or pad_post < 0:
        raise ValueError('padding sizes must be non-negative, got ({}, {})'.format(
            pad_pre, pad_post))
    if length == 0 and (pad_pre > 0 or pad_post > 0):
        raise ValueError('cannot reflect-pad an empty dimension')

    # The symmetric extension repeats with period 2 * length:
    # ... c b a | a b c | c b a | a b c ...
    period = max(2 * length, 1)
    pos = torch.arange(-pad_pre, length + pad_post, device=x.device)
    pos = torch.remainder(pos, period)  # non-negative, same sign as period
    idx = torch.where(pos < length, pos, period - 1 - pos)
    return x.index_select(axis, idx)
def padding(x: torch.Tensor,
            dim: int,
            pad_pre: int,
            pad_post: int,
            padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
    """Pad ``x`` along ``dim`` with the requested scheme.

    Args:
        x: Tensor to pad.
        dim: Axis to pad (forwarded to ``reflect_padding``).
        pad_pre: Elements to add before the axis start.
        pad_post: Elements to add after the axis end.
        padding_type: ``None`` returns ``x`` untouched; ``'reflect'`` uses
            ``reflect_padding``.

    Raises:
        ValueError: For any other ``padding_type``.
    """
    if padding_type is None:
        return x
    if padding_type != 'reflect':
        raise ValueError(f'{padding_type} padding is not supported!')
    return reflect_padding(x, dim, pad_pre, pad_post)
def get_padding(base: torch.Tensor, kernel_size: int,
                x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
    """Work out how much padding the sampling windows need.

    Each entry of ``base`` is the first input index touched by a window of
    ``kernel_size`` consecutive taps. Windows may start before index 0 or end
    past ``x_size - 1``; this computes the padding that keeps every window
    inside the padded axis.

    Args:
        base: First-tap index per output sample (any numeric dtype; cast to long).
        kernel_size: Number of taps per window.
        x_size: Length of the unpadded axis.

    Returns:
        ``(pad_pre, pad_post, base)``. ``pad_pre`` and ``pad_post`` are Python
        ints. ``base`` is a long tensor shifted by ``pad_pre`` so it indexes the
        padded axis. The caller's tensor is never modified in place.
    """
    base = base.long()
    first = int(base.min().item())
    last = int(base.max().item()) + kernel_size - 1
    pad_pre = max(-first, 0)
    pad_post = max(last - x_size + 1, 0)
    if pad_pre > 0:
        # Out-of-place: ``.long()`` returns the same object for a long input,
        # so an in-place add would silently mutate the caller's tensor.
        base = base + pad_pre
    return pad_pre, pad_post, base
def get_weight(dist: torch.Tensor,
               kernel_size: int,
               kernel: str = 'cubic',
               sigma: float = 2.0,
               antialiasing_factor: float = 1) -> torch.Tensor:
    """Build normalised resampling weights for every output sample.

    Row ``k`` holds the kernel evaluated at ``(dist - k) * antialiasing_factor``.
    Each column is then divided by its sum, so the ``kernel_size`` taps of one
    output sample add up to 1.

    Args:
        dist: 1-D tensor; distance from each output position to its first tap.
        kernel_size: Number of taps per output sample.
        kernel: ``'cubic'`` or ``'gaussian'``.
        sigma: Gaussian width (only used for ``'gaussian'``).
        antialiasing_factor: Multiplier on the tap distances. Values < 1
            stretch the kernel's receptive field (downsampling).

    Returns:
        Tensor of shape ``(kernel_size, len(dist))``.

    Raises:
        ValueError: For an unknown ``kernel``.
    """
    tap_ids = torch.arange(kernel_size, dtype=dist.dtype, device=dist.device)
    # Broadcast (1, N) - (K, 1) -> (K, N): row k is dist - k.
    taps = dist.reshape(1, len(dist)) - tap_ids.reshape(kernel_size, 1)
    # Expand (downsampling) / Shrink (upsampling) the receptive field.
    taps *= antialiasing_factor

    if kernel == 'cubic':
        weight = cubic_contribution(taps)
    elif kernel == 'gaussian':
        weight = gaussian_contribution(taps, sigma=sigma)
    else:
        raise ValueError('{} kernel is not supported!'.format(kernel))

    return weight / weight.sum(dim=0, keepdim=True)
def unfold1d_hacked(x, kernel):
    """Equivalent of ``F.unfold(x, kernel)`` for 1-D kernels, built from lowered ops.

    Super specific to bicubic with kernel shape of (n, 1) or (1, n).
    Only works for default padding-0, dilation-1, stride-1.

    Logic: Let's say we have a [6, 1] kernel and the input shape is
    [N, C, 128, 32]. Our first implementation simply computed the 123 window
    positions and looped over them. This version is faster because it only
    loops over the kernel size, i.e., 6. At offset i we slice a contiguous
    block of x starting at row i and split it into chunks of kernel size.
    E.g., at index=0 we slice a 126-long block and split it into 21 chunks.
    At index=5 we can only fit 20 chunks, so we slice a 120-long block
    starting from the 5th element. The rest is working out which output
    columns these chunks occupy.

    Offsets at which no window starts are skipped. This happens when there
    are fewer output rows than kernel taps.

    Note: All indices, e.g., torch.arange, must be generated on the current
    device itself. Don't generate them on cpu, as it will slow it by 1.5-2x.

    Args:
        x: 4-D tensor (N, C, H, W).
        kernel: Length-2 sequence, either (k, 1) or (1, k).

    Returns:
        Tensor of shape (N, C * k, L) with ``x``'s dtype and device, where L is
        the number of window positions. The layout matches ``F.unfold``.

    Raises:
        ValueError: If the kernel is not 2-D with a unit side, or the input is
            smaller than the kernel.
    """
    if len(kernel) != 2:
        raise ValueError("only 2-d kernels supported")
    if kernel[0] != 1 and kernel[1] != 1:
        raise ValueError("kernel has to be [1, N] or [N, 1]")

    # The core logic is written for an [N, 1] kernel; swap H and W for [1, N].
    transposed = kernel[0] == 1
    if transposed:
        kernel = list(reversed(kernel))
        x = x.permute(0, 1, 3, 2)

    n, c, h, w = x.shape
    ks = kernel[0]
    n1, n2 = h - (ks - 1), w - (kernel[1] - 1)
    if n1 < 1:
        raise ValueError(
            "input size {} is smaller than kernel size {}".format(h, ks))

    # Same dtype/device as x: a default-dtype buffer would silently downcast
    # (or reject) non-float32 inputs.
    output = x.new_zeros(n, c * ks, n1 * n2)
    cols = torch.arange(n2, device=x.device)
    for i in range(ks):
        # Windows starting at rows i, i + ks, i + 2 * ks, ... (all < n1).
        count = math.ceil((n1 - i) / ks)
        if count <= 0:
            # Fewer output rows than kernel taps: no window starts at row i.
            continue
        vals = x[:, :, i:i + count * ks, :].reshape(
            n, c, count, ks, n2).permute(0, 1, 3, 2, 4).reshape(
                n, c * ks, count * n2)
        rows = i + ks * torch.arange(count, device=x.device)
        # Flattened output position of (row, col) is row * n2 + col.
        idx = (rows.unsqueeze(1) * n2 + cols).reshape(-1)
        output[:, :, idx] = vals

    if transposed:
        # Reorder positions from (W-major) back to F.unfold's (H-major) order.
        order = torch.arange(n1 * n2, device=x.device).reshape(n1, n2).t().reshape(-1)
        output = output[:, :, order]
    return output
def reshape_tensor(x: torch.Tensor, dim: int,
                   kernel_size: int) -> torch.Tensor:
    """Stack every ``kernel_size``-long window along ``dim`` into the channel axis.

    Args:
        x: 4-D tensor (B, C, H, W).
        dim: 2 or -2 windows the H axis; anything else windows the W axis.
        kernel_size: Window length.

    Returns:
        Tensor of shape (B, C * kernel_size, H', W'), where the windowed axis
        has shrunk by ``kernel_size - 1``.
    """
    height, width = x.size(-2), x.size(-1)
    shrink = kernel_size - 1
    if dim in (2, -2):
        window = (kernel_size, 1)
        out_hw = (height - shrink, width)
    else:
        window = (1, kernel_size)
        out_hw = (height, width - shrink)
    columns = unfold1d_hacked(x, window)
    return columns.view(columns.size(0), -1, *out_hw)
def reshape_input(
        x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, _I, _I]:
    """Flatten a 2-, 3- or 4-D image tensor to shape (B * C, 1, H, W).

    Args:
        x: Tensor of shape (H, W), (C, H, W) or (B, C, H, W).

    Returns:
        ``(x, b, c, h, w)``. ``b`` and/or ``c`` are ``None`` when the input had
        no such axis, so ``reshape_output`` can restore the original rank.

    Raises:
        ValueError: For any other number of dimensions.
    """
    if x.dim() == 4:
        b, c, h, w = x.size()
    elif x.dim() == 3:
        c, h, w = x.size()
        b = None
    elif x.dim() == 2:
        h, w = x.size()
        b = c = None
    else:
        raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))
    # reshape (not view): non-contiguous inputs such as permuted HWC -> CHW
    # images cannot be viewed; reshape still returns a view when possible.
    x = x.reshape(-1, 1, h, w)
    return x, b, c, h, w
def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
    """Restore the rank that ``reshape_input`` flattened away.

    Args:
        x: Tensor whose last two axes are the (resized) H and W.
        b: Original batch size, or ``None`` if the input had no batch axis.
        c: Original channel count, or ``None`` if the input had no channel axis.

    Returns:
        A view of ``x`` with shape (b, c, H, W), (c, H, W) or (H, W).
    """
    rh, rw = x.size(-2), x.size(-1)
    if b is not None:
        target = (b, c, rh, rw)
    elif c is not None:
        target = (c, rh, rw)
    else:
        target = (rh, rw)
    return x.view(*target)
def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
    """Convert ``x`` to float32 for the resampling arithmetic.

    Args:
        x: Input tensor of any dtype.

    Returns:
        ``(x, dtype)``. ``dtype`` is the original dtype when a cast was made,
        for ``cast_output`` to restore. It is ``None`` when ``x`` was already
        float32.
    """
    # The previous check (``dtype != float32 or dtype != float64``) was always
    # true, so every dtype, float64 included, was routed through float32.
    # NOTE(review): that effective behaviour is kept explicit here. Exempting
    # float64 would require every downstream buffer to follow x.dtype;
    # confirm before changing.
    if x.dtype == torch.float32:
        return x, None
    return x.float(), x.dtype
def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
    """Cast the resampled float tensor back to the caller's dtype.

    For non-floating targets the values are rounded first. For integer
    targets they are also clamped to the dtype's representable range, so
    out-of-range results saturate instead of wrapping around.

    Args:
        x: Resampled floating-point tensor.
        dtype: Target dtype, or ``None`` to return ``x`` unchanged.

    Returns:
        ``x`` converted to ``dtype``.
    """
    if dtype is None:
        return x
    if not dtype.is_floating_point:
        x = x.round()
        # To prevent over/underflow when converting types. This covers every
        # integer dtype (uint8 -> [0, 255], int8 -> [-128, 127], ...), not just uint8.
        if dtype is not torch.bool:
            try:
                info = torch.iinfo(dtype)
            except TypeError:
                info = None  # e.g. complex targets: no integer range to clamp to
            if info is not None:
                x = x.clamp(info.min, info.max)
    return x.to(dtype=dtype)
def resize_1d(x: torch.Tensor,
              dim: int,
              size: typing.Optional[int],
              scale: typing.Optional[float],
              kernel: str = 'cubic',
              sigma: float = 2.0,
              padding_type: str = 'reflect',
              antialiasing: bool = True) -> torch.Tensor:
    """
    Resample ``x`` along a single spatial axis with a separable kernel.

    Args:
        x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
        dim (int): Axis to resize; 2 or -2 selects H, anything else W.
        size (int): Number of output samples along ``dim``.
        scale (float): Output/input length ratio along ``dim``. Exactly 1
            returns ``x`` unchanged.
        kernel (str): ``'cubic'`` (4 taps) or ``'gaussian'`` (floor(6 * sigma) taps).
        sigma (float): Gaussian width; ignored for ``'cubic'``.
        padding_type (str): Border handling, forwarded to ``padding``.
        antialiasing (bool): When downsampling (scale < 1), widen the kernel
            by 1 / scale.

    Return:
        torch.Tensor of shape (B x C, 1, H', W'); the resized axis has ``size`` entries.
    """
    # Identity case
    if scale == 1:
        return x

    support = 4 if kernel == 'cubic' else math.floor(6 * sigma)
    shrink_kernel = antialiasing and (scale < 1)
    antialiasing_factor = scale if shrink_kernel else 1
    if shrink_kernel:
        support = math.ceil(support / antialiasing_factor)
    # One extra tap of margin on each side of the support.
    kernel_size = support + 2

    # Weights and indices depend only on the input/output shapes,
    # so no gradient is tracked for them.
    with torch.no_grad():
        centers = torch.linspace(
            0,
            size - 1,
            steps=size,
            dtype=x.dtype,
            device=x.device,
        )
        # Centre of each output sample expressed in input coordinates.
        centers = (centers + 0.5) / scale - 0.5
        first_tap = centers.floor() - (kernel_size // 2) + 1
        offsets = centers - first_tap
        weight = get_weight(
            offsets,
            kernel_size,
            kernel=kernel,
            sigma=sigma,
            antialiasing_factor=antialiasing_factor,
        )
        pad_pre, pad_post, first_tap = get_padding(first_tap, kernel_size,
                                                   x.size(dim))

    # Padding and gathering run outside no_grad so gradients reach x.
    x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
    windows = reshape_tensor(x_pad, dim, kernel_size)

    # Pick each output sample's window, then apply its weights.
    if dim in (2, -2):
        sample = windows[..., first_tap, :]
        weight = weight.view(1, kernel_size, sample.size(2), 1)
    else:
        sample = windows[..., first_tap]
        weight = weight.view(1, kernel_size, 1, sample.size(3))

    return (sample * weight).sum(dim=1, keepdim=True)
def downsampling_2d(x: torch.Tensor,
                    k: torch.Tensor,
                    scale: int,
                    padding_type: str = 'reflect') -> torch.Tensor:
    """Filter each channel of ``x`` with the 2-D kernel ``k``, then subsample.

    Args:
        x: 4-D tensor (B, C, H, W).
        k: Kernel whose last two dims are (kH, kW). It is cast to ``x``'s dtype
            and device.
        scale: Integer stride of the convolution (the downsampling factor).
        padding_type: Border handling, forwarded to ``padding``.

    Returns:
        The strided, per-channel filtered tensor.
    """
    c = x.size(1)
    k_h, k_w = k.size(-2), k.size(-1)
    # One copy of the kernel per channel, applied depthwise (groups=c). This
    # is the same per-channel filtering as a dense (c, c) weight masked by the
    # identity, but without materialising c * c kernels.
    weight = k.to(dtype=x.dtype, device=x.device).reshape(1, 1, k_h, k_w)
    weight = weight.repeat(c, 1, 1, 1)

    pad_h = (k_h - scale) // 2
    pad_w = (k_w - scale) // 2
    x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
    x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
    return F.conv2d(x, weight, padding=0, stride=scale, groups=c)
def imresize(x: torch.Tensor,
             scale: typing.Optional[float] = None,
             size: typing.Optional[typing.Tuple[int, int]] = None,
             kernel: typing.Union[str, torch.Tensor] = 'cubic',
             sigma: float = 2,
             rotation_degree: float = 0,
             padding_type: str = 'reflect',
             antialiasing: bool = False) -> torch.Tensor:
    """
    Resize an image tensor.

    Args:
        x (torch.Tensor): Tensor of shape (H, W), (C, H, W) or (B, C, H, W).
        scale (float): Uniform scale factor; the output is
            (ceil(H * scale), ceil(W * scale)).
        size (tuple(int, int)): Explicit output (H, W). Give exactly one of
            ``scale`` and ``size``.
        kernel (str or torch.Tensor, default='cubic'): ``'cubic'`` or
            ``'gaussian'`` for separable resampling, or a 2-D tensor used as a
            downsampling filter with an integer stride.
        sigma (float, default=2): Width of the Gaussian kernel.
        rotation_degree (float, default=0): Accepted for API compatibility;
            currently unused.
        padding_type (str, default='reflect'): Border handling.
        antialiasing (bool, default=False): Widen the kernel when downsampling.
    Return:
        torch.Tensor: The resized tensor, with the same rank and dtype as ``x``.

    Raises:
        ValueError: If neither or both of ``scale``/``size`` are given.
        TypeError: If ``kernel`` is neither a str nor a torch.Tensor.
    """
    if scale is None and size is None:
        raise ValueError('One of scale or sizes must be specified!')
    if scale is not None and size is not None:
        raise ValueError('Please specify scale or sizes to avoid conflict!')
    if not isinstance(kernel, (str, torch.Tensor)):
        raise TypeError('kernel must be a str or a torch.Tensor, got {}'.format(
            type(kernel).__name__))

    x, b, c, h, w = reshape_input(x)
    if size is None:
        # Determine output size (assigned to ``size``, which is read below).
        size = (math.ceil(h * scale), math.ceil(w * scale))
        scales = (scale, scale)
    else:
        scales = (size[0] / h, size[1] / w)

    x, dtype = cast_input(x)
    if isinstance(kernel, str):
        # Shared keyword arguments across dimensions
        kwargs = {
            'kernel': kernel,
            'sigma': sigma,
            'padding_type': padding_type,
            'antialiasing': antialiasing,
        }
        # Core resizing module: height first, then width.
        x = resize_1d(x, -2, size=size[0], scale=scales[0], **kwargs)
        x = resize_1d(x, -1, size=size[1], scale=scales[1], **kwargs)
    else:
        # Integer stride: int(1 / scale) when a scale was given, otherwise the
        # integer ratio between input and output height.
        stride = int(1 / scale) if scale is not None else h // size[0]
        x = downsampling_2d(x, kernel, scale=stride, padding_type=padding_type)

    x = reshape_output(x, b, c)
    x = cast_output(x, dtype)
    return x
Following is an even faster way to implement unfold, but unfortunately it uses torch.as_strided(...), which is not lowered on XLA. Thus, although this implementation is faster on GPU, it ends up being much slower on XLA (TPUs). If XLA ever supports torch.as_strided(...), we can simply plug this into the code above. Check out https://jott.live/markdown/as_strided to learn more about as_strided.
def unfold_fast(x, kernel):
    """``F.unfold`` equivalent (padding 0, dilation 1, stride 1) built on ``torch.as_strided``.

    Args:
        x: 4-D tensor (N, C, H, W).
        kernel: (kH, kW) window size.

    Returns:
        Tensor of shape (N, C * kH * kW, L), where L = (H - kH + 1) * (W - kW + 1).
    """
    n, c, h, w = x.shape
    kh, kw = kernel[0], kernel[1]
    out_h, out_w = h - (kh - 1), w - (kw - 1)
    # The H/W strides below are hard-coded as (w, 1), which is only valid for
    # a dense row-major layout, hence the contiguous copy.
    x = x.contiguous()
    batch_stride, chan_stride = x.stride()[:2]
    windows = torch.as_strided(
        x,
        (n, c, out_h, out_w, kh, kw),
        (batch_stride, chan_stride, w, 1, w, 1),
    )
    patches = windows.reshape(n, c, out_h * out_w, kh * kw)
    return patches.transpose(2, 3).reshape(n, -1, out_h * out_w)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Only nearest/bilinear resampling is lowered on PyTorch XLA (as of Jan '23, see pytorch/xla#1227). This makes bicubic resampling very slow, because XLA falls back to CPU ops, which are super slow.
We build on the publicly available custom bicubic resampling from Sanghyun Son (https://github.com/sanghyun-son/bicubic_pytorch) and tailor it for XLA. Note that, by itself, Sanghyun's implementation is slower than torch.nn.functional.interpolate(..., mode='bicubic'), but it can be tailored to suit the needs of XLA.
The key issue with bicubic on XLA is that it uses torch.nn.functional.unfold(...), which itself is not lowered. We work around this by rebuilding that function from native PyTorch ops (ones that are lowered).