Skip to content

Instantly share code, notes, and snippets.

@arquolo
Last active May 25, 2021 13:17
Show Gist options
  • Save arquolo/2cc04195e9b1597df1f453fbd383b037 to your computer and use it in GitHub Desktop.
Save arquolo/2cc04195e9b1597df1f453fbd383b037 to your computer and use it in GitHub Desktop.
Fork from cheind/py-thin-plate-spline
from __future__ import annotations
import numpy as np
class TPS:
    """Thin-plate-spline fitting and evaluation for 2-D control points.

    Control points are stored as rows ``(x, y)``: column 0 is paired with the
    ``dx`` grid axis and column 1 with the ``dy`` grid axis everywhere in
    this class. All methods are static; the class is only a namespace.
    """

    @staticmethod
    def fit(c: np.ndarray, delta: np.ndarray, lambd: float = 0., reduced: bool = False) -> np.ndarray:
        """Solve for the spline parameters that map ``c`` onto ``delta``.

        Args:
            c: Control points, shape (N, 2), rows are ``(x, y)``.
            delta: Target values at the control points, shape (N, 2).
            lambd: Value added to the diagonal of the kernel matrix
                (regularisation); 0 gives exact interpolation.
            reduced: If true, drop the first kernel weight from the result.
                ``TPS.z`` restores it as minus the sum of the others.

        Returns:
            Array of shape (N + 3, 2), or (N + 2, 2) when ``reduced``:
            kernel weights followed by the three affine rows
            ``(a0, a_x, a_y)``.

        Raises:
            numpy.linalg.LinAlgError: If the linear system is singular.
        """
        n = len(c)

        # Kernel matrix between all pairs of control points, (N, N).
        k = TPS.ud(c, c)
        k += np.eye(n, dtype='f4') * lambd

        # Affine basis [1, x, y] evaluated at every control point, (N, 3).
        p = np.ones((n, 3), dtype='f4')
        p[:, 1:] = c

        # Block system [[K, P], [P^T, 0]] @ theta = [delta, 0].
        a = np.zeros((n + 3, n + 3), dtype='f4')
        a[:n, :n] = k
        a[:n, -3:] = p
        a[-3:, :n] = p.T

        v = np.zeros((n + 3, 2), dtype='f4')
        v[:n] = delta

        theta = np.linalg.solve(a, v)  # (N + 3, 2)
        return theta[1:] if reduced else theta

    @staticmethod
    def ud(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Evaluate the radial kernel between two point sets.

        Args:
            a: Points of shape (N, 2).
            b: Points of shape (M, 2).

        Returns:
            Array of shape (N, M) holding ``0.5 * r2 * log(r2 + 1e-12)``,
            where ``r2`` is the squared Euclidean distance of each pair.
        """
        r2 = np.square(a[:, None] - b[None, :]).sum(-1)
        # The small epsilon keeps log() finite where r2 == 0.
        return 0.5 * r2 * np.log(r2 + 1e-12)

    @staticmethod
    def ud2(dy: np.ndarray, dx: np.ndarray, dst: np.ndarray) -> np.ndarray:
        """Evaluate the radial kernel between a separable grid and points.

        Args:
            dy: Grid y coordinates, shape (H, 1, 1).
            dx: Grid x coordinates, shape (1, W, 1).
            dst: Control points, shape (N, 2), rows are ``(x, y)``.

        Returns:
            Array of shape (H, W, N) with the same kernel as ``TPS.ud``.
        """
        r2y = np.square(dy - dst[..., 1])  # (H, 1, N)
        r2x = np.square(dx - dst[..., 0])  # (1, W, N)
        r2 = r2y + r2x  # (H, W, N) via broadcasting
        return 0.5 * r2 * np.log(r2 + 1e-12)

    @staticmethod
    def z(dy: np.ndarray, dx: np.ndarray, dst: np.ndarray, theta: np.ndarray) -> np.ndarray:
        """Evaluate a fitted spline on a separable grid.

        Args:
            dy: Grid y coordinates, shape (H, 1, 1).
            dx: Grid x coordinates, shape (1, W, 1).
            dst: Control points used for fitting, shape (N, 2).
            theta: Parameters from ``TPS.fit``, shape (N + 3, 2) or, in
                reduced form, (N + 2, 2).

        Returns:
            Array of shape (H, W, 2) with the spline value at each grid node.
        """
        u = TPS.ud2(dy, dx, dst)  # (H, W, N)

        w = theta[:-3]  # (N, 2), or (N - 1, 2) in reduced form
        if theta.shape[0] == dst.shape[0] + 2:
            # Reduced form: the dropped first weight is minus the sum of the
            # remaining ones.
            w = np.concatenate((-w.sum(0, keepdims=True), w), axis=0)
        b = np.dot(u, w)  # (H, W, 2)

        # fit() orders the affine basis as [1, x, y], so a1 belongs to the
        # x coordinate (dx) and a2 to the y coordinate (dy).
        a0, a1, a2 = theta[-3:]  # each (2,)
        # (2) + (2) x (1, W, 1) + (2) x (H, 1, 1) + (H, W, 2)
        return a0 + a1 * dx + a2 * dy + b
def tps_warp(src, dst, src_hw: tuple[int, int], dst_hw: tuple[int, int],
             reduced: bool = False,
             pool: int = 1) -> tuple[np.ndarray, np.ndarray]:
    """Build dense x/y sampling maps from a thin-plate-spline fit.

    The spline is fitted to the displacement ``src - dst`` at the ``dst``
    control points and evaluated on a regular grid spanning [0, 1] in both
    axes, so the control points are presumably expected in normalised
    ``(x, y)`` coordinates (confirm against the caller).

    Args:
        src: Control points in the source image, shape (N, 2).
        dst: Matching control points in the destination image, shape (N, 2).
        src_hw: ``(height, width)`` used to scale the maps to source pixels.
        dst_hw: ``(height, width)`` of the output maps.
        reduced: Forwarded to ``TPS.fit``.
        pool: Downscale factor for the evaluation grid. To trade precision
            for performance, use ``pool > 1``; the coarse grid is then
            upsampled back to ``dst_hw`` with bicubic interpolation
            (requires OpenCV).

    Returns:
        ``(mx, my)``: two float32 arrays of shape ``dst_hw`` holding, for
        each destination pixel, the x and y coordinate in the source image.
    """
    theta = TPS.fit(dst, src - dst, reduced=reduced)

    h2, w2 = dst_hw
    h2p, w2p = h2 // pool, w2 // pool  # Use lower scale for perf
    dy = np.linspace(0, 1, h2p, dtype='f4')[:, None, None]  # (H', 1, 1)
    dx = np.linspace(0, 1, w2p, dtype='f4')[None, :, None]  # (1, W', 1)
    grid = TPS.z(dy, dx, dst, theta)  # (H', W', 2)

    # The spline models displacements; add the grid position back to obtain
    # absolute coordinates.
    grid[..., 1] += dy.squeeze(-1)
    grid[..., 0] += dx.squeeze(-1)

    # Restore the full output resolution.
    if pool > 1:
        # Imported here because the module does not import cv2 at the top;
        # this also keeps OpenCV optional when pool == 1.
        import cv2
        grid = cv2.resize(grid, (w2, h2), interpolation=cv2.INTER_CUBIC)

    # Scale normalised coordinates to source pixels.
    h1, w1 = src_hw
    my = (grid[..., 1] * h1).astype('f4')
    mx = (grid[..., 0] * w1).astype('f4')
    return mx, my
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment