Skip to content

Instantly share code, notes, and snippets.

@Birch-san
Last active February 24, 2024 12:11
Show Gist options
  • Star 10 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Birch-san/230ac46f99ec411ed5907b0a3d728efa to your computer and use it in GitHub Desktop.
Save Birch-san/230ac46f99ec411ed5907b0a3d728efa to your computer and use it in GitHub Desktop.
PyTorch implementation of spherical linear interpolation
from torch import FloatTensor, LongTensor, Tensor, Size, lerp, zeros_like
from torch.linalg import norm
# adapted to PyTorch from:
# https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
# most of the extra complexity is to support:
# - many-dimensional vectors
# - v0 or v1 with last dim all zeroes, or v0 ~colinear with v1
#   - such rows fall back to lerp()
#   - the conditional logic is implemented with parallelism (masks) rather than Python loops
# - many-dimensional tensor for t
#   - you can ask for batches of slerp outputs by making t more-dimensional than the vectors, e.g.
#       slerp(
#         v0: torch.Size([2,3]),
#         v1: torch.Size([2,3]),
#         t:  torch.Size([4,1,1]),
#       )
#   - this makes it interface-compatible with lerp()
def slerp(v0: FloatTensor, v1: FloatTensor, t: float|FloatTensor, DOT_THRESHOLD=0.9995):
'''
Spherical linear interpolation
Args:
v0: Starting vector
v1: Final vector
t: Float value between 0.0 and 1.0
DOT_THRESHOLD: Threshold for considering the two vectors as
colinear. Not recommended to alter this.
Returns:
Interpolation vector between v0 and v1
'''
assert v0.shape == v1.shape, "shapes of v0 and v1 must match"
# Normalize the vectors to get the directions and angles
v0_norm: FloatTensor = norm(v0, dim=-1)
v1_norm: FloatTensor = norm(v1, dim=-1)
v0_normed: FloatTensor = v0 / v0_norm.unsqueeze(-1)
v1_normed: FloatTensor = v1 / v1_norm.unsqueeze(-1)
# Dot product with the normalized vectors
dot: FloatTensor = (v0_normed * v1_normed).sum(-1)
dot_mag: FloatTensor = dot.abs()
# if dp is NaN, it's because the v0 or v1 row was filled with 0s
# If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
gotta_lerp: LongTensor = dot_mag.isnan() | (dot_mag > DOT_THRESHOLD)
can_slerp: LongTensor = ~gotta_lerp
t_batch_dim_count: int = max(0, t.dim()-v0.dim()) if isinstance(t, Tensor) else 0
t_batch_dims: Size = t.shape[:t_batch_dim_count] if isinstance(t, Tensor) else Size([])
out: FloatTensor = zeros_like(v0.expand(*t_batch_dims, *[-1]*v0.dim()))
# if no elements are lerpable, our vectors become 0-dimensional, preventing broadcasting
if gotta_lerp.any():
lerped: FloatTensor = lerp(v0, v1, t)
out: FloatTensor = lerped.where(gotta_lerp.unsqueeze(-1), out)
# if no elements are slerpable, our vectors become 0-dimensional, preventing broadcasting
if can_slerp.any():
# Calculate initial angle between v0 and v1
theta_0: FloatTensor = dot.arccos().unsqueeze(-1)
sin_theta_0: FloatTensor = theta_0.sin()
# Angle at timestep t
theta_t: FloatTensor = theta_0 * t
sin_theta_t: FloatTensor = theta_t.sin()
# Finish the slerp algorithm
s0: FloatTensor = (theta_0 - theta_t).sin() / sin_theta_0
s1: FloatTensor = sin_theta_t / sin_theta_0
slerped: FloatTensor = s0 * v0 + s1 * v1
out: FloatTensor = slerped.where(can_slerp.unsqueeze(-1), out)
return out
@Birch-san
Copy link
Author

Example invocation:

# Example usage of slerp(); as written it requires a CUDA device.
from torch import FloatTensor, tensor
import torch

device=torch.device('cuda')
dtype=torch.float16

# spherical linear midpoint between [1,0] and [0,1] is [sine(pi/4), sine(pi/4)]
start: FloatTensor = tensor([1,0], dtype=dtype, device=device)
end: FloatTensor = tensor([0,1], dtype=dtype, device=device)
slerp(start, end, 0.5)
# tensor([0.7070, 0.7070], device='cuda:0', dtype=torch.float16)

# many-dimensional data supported; interpolation is performed row-wise
start: FloatTensor = tensor([[1,0], [2,0]], dtype=dtype, device=device)
end: FloatTensor = tensor([[0,1], [0,2]], dtype=dtype, device=device)
slerp(start, end, 0.5)
# tensor([[0.7070, 0.7070],
#         [1.4141, 1.4141]], device='cuda:0', dtype=torch.float16)

# any row where either vector has all zeroes, is computed via fallback to lerp():
start: FloatTensor = tensor([[0,0], [1,0]], dtype=dtype, device=device)
end: FloatTensor = tensor([[1,1], [0,1]], dtype=dtype, device=device)
# t=0.5 matches the output shown below (`time` is not defined yet at this point)
slerp(start, end, 0.5)
# tensor([[0.5000, 0.5000],
#         [0.7070, 0.7070]], device='cuda:0', dtype=torch.float16)

# any row where both vectors are approx. colinear, is computed via fallback to lerp():
start: FloatTensor = tensor([1,1], dtype=dtype, device=device)
end: FloatTensor = tensor([2,2], dtype=dtype, device=device)
slerp(start, end, 0.5)
# tensor([1.5000, 1.5000], device='cuda:0', dtype=torch.float16)

# we can compute slerp at a batch of timepoints simultaneously,
# by providing a time tensor that is more-dimensional than start/end
start: FloatTensor = tensor([1,0], dtype=dtype, device=device)
end: FloatTensor = tensor([0,1], dtype=dtype, device=device)
time: FloatTensor = tensor([[0.25], [0.5]], dtype=torch.float16, device=device)
slerp(start, end, time)
# tensor([[0.9238, 0.3826],
#         [0.7070, 0.7070]], device='cuda:0', dtype=torch.float16)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment