Skip to content

Instantly share code, notes, and snippets.

@aormorningstar
Last active January 12, 2024 21:24
Show Gist options
  • Save aormorningstar/3e5dda91f155d7919ef6256cb057ceee to your computer and use it in GitHub Desktop.
Save aormorningstar/3e5dda91f155d7919ef6256cb057ceee to your computer and use it in GitHub Desktop.
Find the rotation matrix that aligns one three-dimensional vector with another.
import numpy as np
def rotation(v1, v2):
    """
    Compute a proper rotation matrix R that rotates v1 to align with v2.

    The returned R is orthogonal with determinant +1 and satisfies
    ``R @ (v1 / |v1|) == v2 / |v2|`` up to floating-point error. Only the
    directions of v1 and v2 matter; their magnitudes are ignored.

    Parameters
    ----------
    v1, v2 : array_like
        Nonzero 1d vectors of the same length (typically 3).

    Returns
    -------
    numpy.ndarray
        Square (n, n) float array, where n is the length of the inputs.

    Raises
    ------
    ValueError
        If either vector has zero length, or if the vectors point in
        opposite directions in one dimension (no rotation can do that).
    """
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    n1 = np.linalg.norm(v1)
    n2 = np.linalg.norm(v2)
    if n1 == 0.0 or n2 == 0.0:
        # a zero vector has no direction; dividing by its norm would give NaNs
        raise ValueError("v1 and v2 must be nonzero vectors")
    # unit vectors
    u = v1 / n1
    Ru = v2 / n2
    # dimension of the space and identity
    dim = u.size
    I = np.identity(dim)
    # the cos angle between the vectors
    c = np.dot(u, Ru)
    # threshold below which the vectors are treated as exactly opposite
    eps = 1.0e-10
    if np.abs(c + 1.0) < eps:
        # Opposite direction. Rodrigues' formula below divides by 1 + c,
        # which vanishes here, and the rotation plane is not unique.
        # Rotate by pi in the plane spanned by u and a unit vector p
        # perpendicular to u: R = I - 2 (u u^T + p p^T) sends u to -u,
        # fixes everything orthogonal to that plane, and has determinant +1
        # (unlike -I, whose determinant is -1 in odd dimensions).
        if dim < 2:
            raise ValueError("no rotation reverses a vector in one dimension")
        # The coordinate axis least aligned with u has |u_i| <= 1/sqrt(dim),
        # so removing its u-component leaves a well-conditioned remainder.
        e = I[np.argmin(np.abs(u))]
        p = e - np.dot(e, u) * u
        p /= np.linalg.norm(p)
        return I - 2.0 * (np.outer(u, u) + np.outer(p, p))
    # K generates the rotation within the plane of u and Ru. The formula is
    # exact for every c > -1; for parallel vectors K vanishes and R == I.
    K = np.outer(Ru, u) - np.outer(u, Ru)
    # Rodrigues' formula
    return I + K + (K @ K) / (1.0 + c)
import numpy as np
from rotation import rotation
from unittest import TestCase
# To run the tests, execute "python -m unittest" from the command line.
class TestRotation(TestCase):
    """Test the rotation function."""

    def setUp(self):
        self.dim = 3
        # A fixed seed draws the same vectors on every run, so any failure
        # can be reproduced exactly.
        self.rng = np.random.default_rng(12345)

    def _assert_rotates(self, R, v1, v2):
        """Assert R is orthogonal and maps the direction of v1 onto v2's."""
        # orthogonality: R preserves lengths and angles
        np.testing.assert_allclose(R @ R.T, np.identity(self.dim), atol=1e-10)
        # alignment: the unit vector along v1 lands on the unit vector along v2
        np.testing.assert_allclose(
            R @ (v1 / np.linalg.norm(v1)), v2 / np.linalg.norm(v2), atol=1e-10
        )

    def test_random_vectors(self):
        # number of samples
        ns = 100
        for i in range(ns):
            # subTest reports which sample failed and keeps checking the rest
            with self.subTest(sample=i):
                v1 = self.rng.standard_normal(self.dim)
                v2 = self.rng.standard_normal(self.dim)
                self._assert_rotates(rotation(v1, v2), v1, v2)

    def test_aligned(self):
        # one random vector rotated onto itself
        v = self.rng.standard_normal(self.dim)
        self._assert_rotates(rotation(v, v), v, v)

    def test_antialigned(self):
        # one random vector rotated onto its negation
        v = self.rng.standard_normal(self.dim)
        self._assert_rotates(rotation(v, -v), v, -v)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment