Skip to content

Instantly share code, notes, and snippets.

@davegreenwood
Created January 13, 2021 14:08
Show Gist options
  • Save davegreenwood/90619e569d8ba0cc37b727ba8b77de00 to your computer and use it in GitHub Desktop.
Save davegreenwood/90619e569d8ba0cc37b727ba8b77de00 to your computer and use it in GitHub Desktop.
Difference between two ways of constructing the rotation matrix that takes unit vector a onto unit vector b
# %%
import torch
import torch.nn.functional as F
# compute the rotation matrix that rotates a onto b
def angleBetweenRT(a, b):
    """Return the 3x3 rotation matrix that rotates vector ``a`` onto ``b``.

    The rotation is about the axis ``k = normalize(a x b)`` by the angle ``t``
    between the vectors, assembled with Rodrigues' formula
    ``R = cos(t) I + sin(t) [k]_x + (1 - cos(t)) k k^T``.

    The angle is computed as ``atan2(|a x b|, a . b)``. This never produces
    NaN (``acos(a . b)`` does as soon as rounding pushes the dot product just
    outside [-1, 1]). It is well conditioned near 0 and pi, and it is
    scale-invariant, so ``a`` and ``b`` need not be exactly unit length.

    Args:
        a: 1-D tensor with 3 elements, non-zero.
        b: 1-D tensor with 3 elements, non-zero, same dtype/device as ``a``.

    Returns:
        A ``(3, 3)`` ``torch.float64`` tensor on ``a``'s device such that
        ``R @ a`` points along ``b``.
    """
    axis = torch.cross(a, b, dim=0)
    sin_part = axis.norm()  # |a||b| sin(t)
    cos_part = torch.dot(a, b)  # |a||b| cos(t)
    # For (nearly) antiparallel inputs a x b vanishes and its direction is
    # rounding noise. Any axis perpendicular to a yields the required
    # half-turn, so build one from the basis vector least aligned with a.
    # A threshold of sqrt(eps) keeps both the noisy-axis error and the
    # fallback error small. For parallel inputs the zero axis is harmless:
    # t == 0, so R reduces to the identity.
    tol = torch.finfo(a.dtype).eps ** 0.5 * a.norm() * b.norm()
    if cos_part < 0 and sin_part <= tol:
        helper = torch.zeros_like(a)
        helper[torch.argmin(a.abs())] = 1.0
        axis = torch.cross(a, helper, dim=0)
    k = F.normalize(axis, dim=0)
    x, y, z = k
    theta = torch.atan2(sin_part, cos_part)
    cos_angle = torch.cos(theta)
    sin_angle = torch.sin(theta)
    zero = torch.zeros_like(x)
    # Skew-symmetric cross-product matrix [k]_x.
    k_cross = torch.stack([
        torch.stack([zero, -z, y]),
        torch.stack([z, zero, -x]),
        torch.stack([-y, x, zero]),
    ])
    eye = torch.eye(3, dtype=k.dtype, device=k.device)
    outer = k[:, None] * k[None, :]
    R = cos_angle * eye + sin_angle * k_cross + (1 - cos_angle) * outer
    return R.to(torch.float64)
# compute the rotation matrix that rotates a onto b (a and b must be unit vectors!)
def _angleBetweenRT(a, b):
# find rotation axis
rotvec = torch.cross(a, b)
# normalise axis
x, y, z = F.normalize(rotvec, dim=0)
# find angle in a different way!!!
# theta = torch.acos(torch.dot(a, b))
# cos_angle = torch.cos(theta)
# sin_angle = torch.sin(theta)
cos_angle = a @ b
sin_angle = torch.sqrt(1.0 - cos_angle ** 2)
# construct rotation matrix from axis-angle representation
R = torch.zeros(3, 3, dtype=torch.float32)
R[0, 0] = cos_angle + x*x*(1-cos_angle)
R[1, 0] = z * sin_angle + y*x*(1-cos_angle)
R[2, 0] = -y * sin_angle + z*x*(1-cos_angle)
R[0, 1] = -z * sin_angle + x*y*(1-cos_angle)
R[1, 1] = cos_angle + y*y*(1-cos_angle)
R[2, 1] = x * sin_angle + z*y*(1-cos_angle)
R[0, 2] = y * sin_angle + x*z*(1-cos_angle)
R[1, 2] = -x * sin_angle + y*z*(1-cos_angle)
R[2, 2] = cos_angle + z*z*(1-cos_angle)
return R
# Compare the two constructions on ten random pairs of unit vectors; each
# printed matrix is their element-wise difference.
for i in range(10):
    # Draw a first, then b, each normalised to unit length.
    a, b = [F.normalize(torch.randn(3), dim=0) for _draw in range(2)]
    r, _r = angleBetweenRT(a, b), _angleBetweenRT(a, b)
    print(r - _r)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment