Skip to content

Instantly share code, notes, and snippets.

@a-canela
Created March 25, 2023 12:30
Show Gist options
  • Save a-canela/8e0ac530b5b9f1e4c445091d71d2b427 to your computer and use it in GitHub Desktop.
Save a-canela/8e0ac530b5b9f1e4c445091d71d2b427 to your computer and use it in GitHub Desktop.
2D/3D cubic (bicubic/tricubic) interpolation implemented with PyTorch
import torch
def spline_2d_torch(points: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
    """Bicubic interpolation of a 2D grid of C-channel samples at N query points.

    Uses the Keys cubic-convolution kernel with a = -0.75 (four taps per axis).

    Args:
        points: tensor of shape (N, 2) with normalized coordinates. Column 0
            indexes the first grid axis (size H) and column 1 the second axis
            (size W). -1 maps to index 0 and +1 to the last index of that axis.
            Values outside [-1, 1] are clamped; extra columns are ignored.
        Z: tensor of shape (H, W, C) holding the grid samples. The index
            tensors derived from ``points`` are used to index it, so it is
            expected to live on the same device as ``points``.

    Returns:
        Tensor of shape (N, C) with the interpolated values.

    Samples outside the grid are treated as zero (zero padding). Near the
    border, the result is therefore influenced by those zeros rather than by
    replicated edge values.
    """
    K = 4
    N, _ = points.shape
    H, W, C = Z.shape

    def _cubic_weights(t):
        # Keys kernel (a = -0.75) weights for the taps at offsets -1, 0, +1, +2
        # from floor(coord); t is the fractional part of the coordinate.
        return (
            ((-0.75 * t + 3.0) * (t + 1.0) - 6.0) * (t + 1.0) + 3.0,
            (1.25 * t - 2.25) * t * t + 1.0,
            (-1.25 * t - 1.0) * (1.0 - t) * (1.0 - t) + 1.0,
            ((0.75 * t + 2.25) * (2.0 - t) - 6.0) * (2.0 - t) + 3.0,
        )

    # Scale each axis by its own extent so non-square grids map [-1, 1] onto
    # [0, H - 1] x [0, W - 1] instead of reusing H for both axes.
    extent = points.new_tensor([H - 1, W - 1])
    coords = (torch.clamp(points[:, :2], -1.0, 1.0) + 1.0) * 0.5 * extent
    vcoords = coords.long()  # coords >= 0, so truncation equals floor
    frac = coords - vcoords

    # One zero sample before and two after each axis keep the taps
    # floor(coord) - 1 .. floor(coord) + 2 in bounds (shifted by +1 into Z0).
    Z0 = Z.new_zeros((H + 3, W + 3, C))
    Z0[1:-2, 1:-2] = Z

    offsets = torch.arange(K, dtype=vcoords.dtype, device=vcoords.device)
    yc = vcoords[:, 0].view(N, 1)                       # (N, 1), broadcasts over taps
    xc = vcoords[:, 1].view(N, 1) + offsets.view(1, K)  # (N, K)

    wy = _cubic_weights(frac[:, 0])
    wx = _cubic_weights(frac[:, 1])

    # Interpolate along y for each of the K x-taps -> (N, K, C).
    zy = (
        Z0[yc, xc] * wy[0].view(N, 1, 1)
        + Z0[yc + 1, xc] * wy[1].view(N, 1, 1)
        + Z0[yc + 2, xc] * wy[2].view(N, 1, 1)
        + Z0[yc + 3, xc] * wy[3].view(N, 1, 1)
    )
    # Interpolate along x -> (N, C).
    z = (
        zy[:, 0] * wx[0].view(N, 1)
        + zy[:, 1] * wx[1].view(N, 1)
        + zy[:, 2] * wx[2].view(N, 1)
        + zy[:, 3] * wx[3].view(N, 1)
    )
    return z
def spline_3d_torch(points: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
    """Tricubic interpolation of a 3D grid of C-channel samples at N query points.

    Uses the Keys cubic-convolution kernel with a = -0.75 (four taps per axis),
    applied separably along the first, second and third grid axes.

    Args:
        points: tensor of shape (N, 3) with normalized coordinates. Columns
            0, 1 and 2 index the grid axes of size D, H and W respectively.
            -1 maps to index 0 and +1 to the last index of that axis. Values
            outside [-1, 1] are clamped; extra columns are ignored.
        Z: tensor of shape (D, H, W, C) holding the grid samples. The index
            tensors derived from ``points`` are used to index it, so it is
            expected to live on the same device as ``points``.

    Returns:
        Tensor of shape (N, C) with the interpolated values.

    Samples outside the grid are treated as zero (zero padding). Near the
    border, the result is therefore influenced by those zeros rather than by
    replicated edge values.
    """
    K = 4
    N, _ = points.shape
    D, H, W, C = Z.shape

    def _cubic_weights(t):
        # Keys kernel (a = -0.75) weights for the taps at offsets -1, 0, +1, +2
        # from floor(coord); t is the fractional part of the coordinate.
        return (
            ((-0.75 * t + 3.0) * (t + 1.0) - 6.0) * (t + 1.0) + 3.0,
            (1.25 * t - 2.25) * t * t + 1.0,
            (-1.25 * t - 1.0) * (1.0 - t) * (1.0 - t) + 1.0,
            ((0.75 * t + 2.25) * (2.0 - t) - 6.0) * (2.0 - t) + 3.0,
        )

    # Scale each axis by its own extent so non-cubic volumes map [-1, 1] onto
    # [0, D - 1] x [0, H - 1] x [0, W - 1] instead of reusing D for all axes.
    extent = points.new_tensor([D - 1, H - 1, W - 1])
    coords = (torch.clamp(points[:, :3], -1.0, 1.0) + 1.0) * 0.5 * extent
    vcoords = coords.long()  # coords >= 0, so truncation equals floor
    frac = coords - vcoords

    # One zero sample before and two after each axis keep the taps
    # floor(coord) - 1 .. floor(coord) + 2 in bounds (shifted by +1 into Z0).
    Z0 = Z.new_zeros((D + 3, H + 3, W + 3, C))
    Z0[1:-2, 1:-2, 1:-2] = Z

    offsets = torch.arange(K, dtype=vcoords.dtype, device=vcoords.device)
    zc = vcoords[:, 0].view(N, 1, 1)                          # (N, 1, 1)
    yc = vcoords[:, 1].view(N, 1, 1) + offsets.view(1, K, 1)  # (N, K, 1)
    xc = vcoords[:, 2].view(N, 1, 1) + offsets.view(1, 1, K)  # (N, 1, K)

    wz = _cubic_weights(frac[:, 0])
    wy = _cubic_weights(frac[:, 1])
    wx = _cubic_weights(frac[:, 2])

    # Interpolate along the first axis for every (y, x) tap pair -> (N, K, K, C).
    zz = (
        Z0[zc, yc, xc] * wz[0].view(N, 1, 1, 1)
        + Z0[zc + 1, yc, xc] * wz[1].view(N, 1, 1, 1)
        + Z0[zc + 2, yc, xc] * wz[2].view(N, 1, 1, 1)
        + Z0[zc + 3, yc, xc] * wz[3].view(N, 1, 1, 1)
    )
    # Interpolate along the second axis -> (N, K, C).
    zy = (
        zz[:, 0] * wy[0].view(N, 1, 1)
        + zz[:, 1] * wy[1].view(N, 1, 1)
        + zz[:, 2] * wy[2].view(N, 1, 1)
        + zz[:, 3] * wy[3].view(N, 1, 1)
    )
    # Interpolate along the third axis -> (N, C).
    z = (
        zy[:, 0] * wx[0].view(N, 1)
        + zy[:, 1] * wx[1].view(N, 1)
        + zy[:, 2] * wx[2].view(N, 1)
        + zy[:, 3] * wx[3].view(N, 1)
    )
    return z
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment