Instantly share code, notes, and snippets.

# Birch-san/slerp.py

Last active February 24, 2024 12:11
Show Gist options
• Save Birch-san/230ac46f99ec411ed5907b0a3d728efa to your computer and use it in GitHub Desktop.
PyTorch implementation of spherical linear interpolation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 from torch import FloatTensor, LongTensor, Tensor, Size, lerp, zeros_like from torch.linalg import norm # adapted to PyTorch from: # https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c # most of the extra complexity is to support: # - many-dimensional vectors # - v0 or v1 with last dim all zeroes, or v0 ~colinear with v1 # - falls back to lerp() # - conditional logic implemented with parallelism rather than Python loops # - many-dimensional tensor for t # - you can ask for batches of slerp outputs by making t more-dimensional than the vectors # - slerp( # v0: torch.Size([2,3]), # v1: torch.Size([2,3]), # t: torch.Size([4,1,1]), # ) # - this makes it interface-compatible with lerp() def slerp(v0: FloatTensor, v1: FloatTensor, t: float|FloatTensor, DOT_THRESHOLD=0.9995): ''' Spherical linear interpolation Args: v0: Starting vector v1: Final vector t: Float value between 0.0 and 1.0 DOT_THRESHOLD: Threshold for considering the two vectors as colinear. Not recommended to alter this. 
Returns: Interpolation vector between v0 and v1 ''' assert v0.shape == v1.shape, "shapes of v0 and v1 must match" # Normalize the vectors to get the directions and angles v0_norm: FloatTensor = norm(v0, dim=-1) v1_norm: FloatTensor = norm(v1, dim=-1) v0_normed: FloatTensor = v0 / v0_norm.unsqueeze(-1) v1_normed: FloatTensor = v1 / v1_norm.unsqueeze(-1) # Dot product with the normalized vectors dot: FloatTensor = (v0_normed * v1_normed).sum(-1) dot_mag: FloatTensor = dot.abs() # if dp is NaN, it's because the v0 or v1 row was filled with 0s # If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp gotta_lerp: LongTensor = dot_mag.isnan() | (dot_mag > DOT_THRESHOLD) can_slerp: LongTensor = ~gotta_lerp t_batch_dim_count: int = max(0, t.dim()-v0.dim()) if isinstance(t, Tensor) else 0 t_batch_dims: Size = t.shape[:t_batch_dim_count] if isinstance(t, Tensor) else Size([]) out: FloatTensor = zeros_like(v0.expand(*t_batch_dims, *[-1]*v0.dim())) # if no elements are lerpable, our vectors become 0-dimensional, preventing broadcasting if gotta_lerp.any(): lerped: FloatTensor = lerp(v0, v1, t) out: FloatTensor = lerped.where(gotta_lerp.unsqueeze(-1), out) # if no elements are slerpable, our vectors become 0-dimensional, preventing broadcasting if can_slerp.any(): # Calculate initial angle between v0 and v1 theta_0: FloatTensor = dot.arccos().unsqueeze(-1) sin_theta_0: FloatTensor = theta_0.sin() # Angle at timestep t theta_t: FloatTensor = theta_0 * t sin_theta_t: FloatTensor = theta_t.sin() # Finish the slerp algorithm s0: FloatTensor = (theta_0 - theta_t).sin() / sin_theta_0 s1: FloatTensor = sin_theta_t / sin_theta_0 slerped: FloatTensor = s0 * v0 + s1 * v1 out: FloatTensor = slerped.where(can_slerp.unsqueeze(-1), out) return out

### Birch-san commented Feb 25, 2023

Example invocation:

```from torch import FloatTensor
import torch
from torch import tensor  # tensor() is called unqualified throughout this example

# NOTE: slerp() is the function defined in slerp.py; it must be in scope.
device=torch.device('cuda')
dtype=torch.float16

# spherical linear midpoint between [1,0] and [0,1] is [sine(pi/4), sine(pi/4)]
start: FloatTensor = tensor([1,0], dtype=dtype, device=device)
end: FloatTensor = tensor([0,1], dtype=dtype, device=device)
slerp(start, end, 0.5)
# tensor([0.7070, 0.7070], device='cuda:0', dtype=torch.float16)

# many-dimensional data supported; interpolation is performed row-wise
start: FloatTensor = tensor([[1,0], [2,0]], dtype=dtype, device=device)
end: FloatTensor = tensor([[0,1], [0,2]], dtype=dtype, device=device)
slerp(start, end, 0.5)
# tensor([[0.7070, 0.7070],
#         [1.4141, 1.4141]], device='cuda:0', dtype=torch.float16)

# any row where either vector has all zeroes, is computed via fallback to lerp():
start: FloatTensor = tensor([[0,0], [1,0]], dtype=dtype, device=device)
end: FloatTensor = tensor([[1,1], [0,1]], dtype=dtype, device=device)
# t=0.5 (the `time` tensor is only defined further down); first row is the lerp midpoint
slerp(start, end, 0.5)
# tensor([[0.5000, 0.5000],
#         [0.7070, 0.7070]], device='cuda:0', dtype=torch.float16)

# any row where both vectors are approx. colinear, is computed via fallback to lerp():
start: FloatTensor = tensor([1,1], dtype=dtype, device=device)
end: FloatTensor = tensor([2,2], dtype=dtype, device=device)
slerp(start, end, 0.5)
# tensor([1.5000, 1.5000], device='cuda:0', dtype=torch.float16)

# we can compute slerp at a batch of timepoints simultaneously,
# by providing a time tensor that is more-dimensional than start/end
start: FloatTensor = tensor([1,0], dtype=dtype, device=device)
end: FloatTensor = tensor([0,1], dtype=dtype, device=device)
time: FloatTensor = tensor([[0.25], [0.5]], dtype=torch.float16, device=device)
slerp(start, end, time)
# tensor([[0.9238, 0.3826],
#         [0.7070, 0.7070]], device='cuda:0', dtype=torch.float16)```