Skip to content

Instantly share code, notes, and snippets.

@leogao2
Forked from dvschultz/pytorch-tensor-slerp.py
Created February 4, 2023 00:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save leogao2/53544f822a5f133d1d4b5bc54c8f8838 to your computer and use it in GitHub Desktop.
# modified from https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
def slerp(t, v0, v1, dot_threshold=0.9995):
    """Spherical linear interpolation between v0 and v1 (batched).

    Interpolation happens along the last axis; every leading axis is treated
    as a batch dimension, so inputs of shape (D,), (B, D), (B, N, D), ... are
    all accepted.

    Args:
        t (float or np.ndarray): Interpolation factor(s), nominally in
            [0.0, 1.0]. A scalar, or an array broadcastable to
            ``v0.shape[:-1]``.
        v0 (np.ndarray or torch.Tensor): Starting vector(s), shape (..., D).
        v1 (np.ndarray or torch.Tensor): Final vector(s), same shape as v0.
        dot_threshold (float): When |cos(angle)| between the normalized
            vectors exceeds this value, the vectors are treated as
            (anti)parallel and plain linear interpolation is used instead,
            because sin(angle) is too close to zero to divide by.

    Returns:
        np.ndarray or torch.Tensor: Interpolated vector(s), same shape as v0.
        If either input was a torch tensor, the result is a torch tensor on
        that input's device (v0's device takes precedence); otherwise it is
        a NumPy array.
    """
    # Anything that is not an ndarray is assumed to be a torch tensor; remember
    # its device so the result can be returned where the caller's data lives.
    device = None
    if not isinstance(v0, np.ndarray):
        device = v0.device
        v0 = v0.detach().cpu().numpy()
    if not isinstance(v1, np.ndarray):
        if device is None:
            device = v1.device
        v1 = v1.detach().cpu().numpy()

    # Unit directions are used only to measure the angle; the final blend
    # uses the original (un-normalized) vectors.
    v0_unit = v0 / np.linalg.norm(v0, axis=-1, keepdims=True)
    v1_unit = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)

    # Clip: rounding can push the cosine slightly outside [-1, 1], where
    # arccos would return NaN.
    dot = np.clip(np.sum(v0_unit * v1_unit, axis=-1), -1.0, 1.0)
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    theta_t = theta_0 * t

    # For nearly (anti)parallel vectors sin(theta_0) ~ 0: use lerp weights
    # there, and a dummy denominator so np.where's eager branch never
    # divides by zero.
    near_parallel = np.abs(dot) > dot_threshold
    safe_sin = np.where(near_parallel, 1.0, sin_theta_0)
    s0 = np.where(near_parallel, 1.0 - t, np.sin(theta_0 - theta_t) / safe_sin)
    s1 = np.where(near_parallel, t, np.sin(theta_t) / safe_sin)

    # Trailing axis so the per-vector weights broadcast over the last dim D.
    v2 = np.expand_dims(s0, -1) * v0 + np.expand_dims(s1, -1) * v1

    if device is not None:
        return torch.from_numpy(v2).to(device)
    return v2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment