Skip to content

Instantly share code, notes, and snippets.

@speezepearson
Created April 7, 2014 23:30
Show Gist options
  • Save speezepearson/10073302 to your computer and use it in GitHub Desktop.
Save speezepearson/10073302 to your computer and use it in GitHub Desktop.
A module to deal with n-dimensional positions and vectors.
import numpy
from math import cos, sin, atan2, sqrt, pi
def floateq(x, y, tolerance=1e-10):
    """Return whether x and y differ by strictly less than ``tolerance``.

    The comparison is absolute, not relative. Given NumPy arrays it works
    elementwise and returns an array of booleans.
    """
    difference = x - y
    return abs(difference) < tolerance
class Vector(numpy.ndarray):
    """An n-dimensional vector of floats.

    A thin subclass of numpy.ndarray. Elementwise arithmetic (+, -, and
    * or / by scalars) comes from NumPy. Equality is tolerance-based (see
    __eq__), and geometric helpers are added on top.
    """

    def __new__(cls, coords, polar=False):
        # Return an (uninitialized) float array large enough to hold all
        # the coordinates; __init__ fills in the values.
        # (For the uninitiated: __new__ is an implicit static method that
        # receives the class as its first argument. When you say
        #     v = Vector((1, 2, 3))
        # that's roughly equivalent to
        #     v = Vector.__new__(Vector, (1, 2, 3))
        #     v.__init__((1, 2, 3))
        # )
        # The builtin ``float`` is the dtype because NumPy 1.24 removed the
        # ``numpy.float`` alias for it.
        return numpy.ndarray.__new__(cls, shape=(len(coords),), dtype=float)

    def __init__(self, coordinates, polar=False):
        """Fill in the coordinates.

        With ``polar=True``, ``coordinates`` is ``(r, theta)`` in 2d or
        ``(r, theta, phi)`` in 3d. It is converted with
        polar_to_rectangular(). Any other length raises ValueError.
        """
        if polar:
            if len(coordinates) in (2, 3):
                self[:] = polar_to_rectangular(coordinates)
            else:
                raise ValueError("can only use polar coordinates in 2d or 3d")
        else:
            self[:] = coordinates[:]

    def __str__(self):
        # >>> print(Vector((1,2,3)))
        # <1.0, 2.0, 3.0>
        # tolist() gives plain Python floats. list(self) would give NumPy
        # scalars, which NumPy >= 2 prints as "np.float64(1.0)".
        list_version = str(self.tolist())
        return '<' + list_version[1:-1] + '>'

    def __repr__(self):
        # >>> Vector((1,2,3))
        # Vector([1.0, 2.0, 3.0])
        return "Vector({0!r})".format(self.tolist())

    def __hash__(self):
        # NOTE(review): __eq__ is tolerance-based, so two Vectors that
        # compare equal can still have different coordinate sums and hence
        # different hashes. Sets and dicts will not merge nearly-equal
        # Vectors. Vectors are also mutable, so don't mutate one used as a
        # key.
        return hash(sum(self))

    def __eq__(self, other):
        """Return whether every coordinate is within floateq's default
        tolerance of the matching coordinate of ``other``.

        Vectors of different lengths are never equal. Non-Vectors get
        NotImplemented so Python can try the reflected comparison.
        """
        if not isinstance(other, Vector):
            return NotImplemented
        # Without this check NumPy broadcasting would make a 1-vector
        # "equal" to any longer vector repeating its coordinate, and
        # non-broadcastable lengths would raise instead of returning False.
        if len(self) != len(other):
            return False
        return bool(floateq(self.view(numpy.ndarray),
                            other.view(numpy.ndarray)).all())

    def __ne__(self, other):
        # ndarray defines its own elementwise __ne__, which would otherwise
        # be used. Python 2 never derives != from __eq__, and on Python 3
        # the inherited ndarray method wins. Without this, `v != w` would
        # return an array instead of a bool.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def _getx(self):
        return self[0]

    def _setx(self, x):
        self[0] = x

    def _gety(self):
        return self[1]

    def _sety(self, y):
        self[1] = y

    def _getz(self):
        return self[2]

    def _setz(self, z):
        self[2] = z

    # Read/write access to the first three coordinates.
    x = property(_getx, _setx)
    y = property(_gety, _sety)
    z = property(_getz, _setz)

    @property
    def r(self):
        """The Euclidean length of the vector."""
        return sqrt(sum(coord**2 for coord in self))

    @property
    def theta(self):
        """2d: angle from the +x axis. 3d: polar angle from the +z axis
        (the convention polar_to_rectangular uses)."""
        if len(self) == 2:
            return atan2(self.y, self.x)
        elif len(self) == 3:
            return atan2(sqrt(self.x**2 + self.y**2), self.z)
        else:
            raise ValueError("theta attribute only valid for 2d or 3d vectors")

    @property
    def phi(self):
        """3d only: azimuthal angle of the (x, y) projection from +x."""
        if len(self) == 3:
            return atan2(self.y, self.x)
        else:
            raise ValueError("phi attribute only defined for 3d vectors")

    def __nonzero__(self):
        """Returns whether any coordinates are nonzero.

        Examples:
        >>> assert Vector((0, 0))
        Traceback (most recent call last):
        ...
        AssertionError
        >>> assert Vector((0, -1))
        """
        return not floateq(self.view(numpy.ndarray), 0).all()

    # Python 3 spells the truth-value hook __bool__. Without this alias the
    # inherited ndarray.__bool__ raises "truth value of an array ... is
    # ambiguous" for Vectors longer than 1 (e.g. in is_scalar_multiple_of).
    __bool__ = __nonzero__

    def dot(self, other):
        """Returns the dot product with another Vector.

        Raises ValueError if the vectors are different lengths.

        Examples:
        >>> assert Vector((1,2,3)).dot(Vector((0,-2,1))) == -1
        >>> assert Vector((3,)).dot(Vector((-1.5,))) == -4.5
        """
        return numpy.vdot(self, other)

    def cross(self, other):
        """Returns the cross product with another Vector.

        Raises ValueError if either vector is not three-dimensional.

        Examples:
        >>> assert Vector((1,2,3)).cross(Vector((0,-2,1))) == Vector((8,-1,-2))
        >>> Vector((2,3)).cross(Vector((0,1)))
        Traceback (most recent call last):
        ...
        ValueError: can only cross 3-vectors
        """
        if len(self) == len(other) == 3:
            result = self.copy()
            result[:] = (self.y*other.z - self.z*other.y,
                         self.z*other.x - self.x*other.z,
                         self.x*other.y - self.y*other.x)
            return result
        else:
            raise ValueError("can only cross 3-vectors")

    def magsquared(self):
        """Returns the square of the magnitude of the vector."""
        return self.dot(self)

    def magnitude(self):
        """Returns the magnitude of the Vector.

        Examples:
        >>> assert floateq(Vector((1,-5,4)).magnitude(), (1+25+16)**.5)
        >>> assert Vector((0,0,0)).magnitude() == 0
        """
        return sqrt(self.magsquared())

    def unit_vector(self):
        """Returns a unit vector in the direction of the Vector.

        If the vector is the zero vector, acts as numpy does when dividing
        a float by zero (nan plus a warning, or FloatingPointError under
        numpy.seterr(all='raise')).

        Examples:
        >>> assert Vector((1,0,0)).unit_vector() == Vector((1,0,0))
        >>> v = Vector((5, 3, -20))
        >>> assert v.unit_vector() == v / v.magnitude()
        >>> _ = numpy.seterr(all='raise')
        >>> Vector((0,0)).unit_vector()
        Traceback (most recent call last):
        ...
        FloatingPointError: invalid value encountered in divide
        """
        return self / self.magnitude()

    def cos_with(self, other):
        """Returns the cosine of the angle made with another Vector.

        Examples:
        >>> assert floateq(Vector((1,0,0)).cos_with(Vector((0,1,0))), 0)
        >>> assert floateq(Vector((1,0,0)).cos_with(Vector((-3,0,0))), -1)
        """
        result = self.dot(other) / (self.magnitude()*other.magnitude())
        # Sometimes, due to floating-point error, we get something not
        # between 1 and -1. Let's fix that.
        if result > 1:
            return 1
        elif result < -1:
            return -1
        return result

    def sin_with(self, other):
        """Returns the sine of the (unsigned) angle made with another Vector.

        The angle lies in [0, pi], so the result is never negative.
        """
        # cos_with already clamps to [-1, 1], so 1 - cos**2 cannot be
        # negative and its square root cannot exceed 1. No further clamping
        # is needed.
        return sqrt(1 - self.cos_with(other)**2)

    def component_parallel_to(self, other):
        """Returns the component of the vector parallel to the given vector.

        Example:
        >>> v1 = Vector((1, 1))
        >>> v2 = Vector((3, 0))
        >>> assert v1.component_parallel_to(v2) == Vector((1, 0))
        >>> assert v2.component_parallel_to(v1) == Vector((1.5, 1.5))
        """
        return other * self.dot(other)/other.dot(other)

    def component_perpendicular_to(self, other):
        """Returns the component perpendicular to the given vector.

        Example:
        >>> v1 = Vector((1,1))
        >>> v2 = Vector((3,0))
        >>> assert v1.component_perpendicular_to(v2) == Vector((0, 1))
        >>> assert v2.component_perpendicular_to(v1) == Vector((1.5, -1.5))
        """
        return self - self.component_parallel_to(other)

    def is_scalar_multiple_of(self, other):
        """Returns whether the vector is a scalar multiple of another.

        Examples:
        >>> v0 = Vector((0,0,0))
        >>> v1 = Vector((1,0,0))
        >>> v2 = Vector((2,3,4))
        >>> v3 = Vector((4,6,8))
        >>> assert not v3.is_scalar_multiple_of(v1)
        >>> assert v3.is_scalar_multiple_of(v2)
        >>> assert v2.is_scalar_multiple_of(v3)
        >>> assert not v1.is_scalar_multiple_of(v0)
        >>> assert v0.is_scalar_multiple_of(v1)
        """
        if other:
            perp = self.component_perpendicular_to(other)
            return floateq(perp.dot(perp), 0)
        else:
            # Only the zero vector is a multiple of the zero vector.
            return self == other

    def get_scalar_multiple_of(self, other):
        """Returns what multiple the vector is of the given vector.

        If the vector is not a scalar multiple of the given vector,
        the behavior is undefined. If the given vector is the zero
        vector, acts as NumPy does when dividing by zero.

        Examples:
        >>> v0 = Vector((0,0,0))
        >>> v1 = Vector((1,0,0))
        >>> v2 = Vector((2,3,4))
        >>> v3 = Vector((4,6,8))
        >>> assert v3.get_scalar_multiple_of(v2) == 2
        >>> assert v2.get_scalar_multiple_of(v3) == 0.5
        >>> assert v0.get_scalar_multiple_of(v1) == 0
        >>> _ = numpy.seterr(all='raise')
        >>> v1.get_scalar_multiple_of(v0)  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ...
        FloatingPointError: invalid value encountered in ...
        >>> v0.get_scalar_multiple_of(v0)  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ...
        FloatingPointError: invalid value encountered in ...
        """
        # The exact wording of NumPy's scalar-division error differs
        # between versions ("double_scalars" vs "scalar divide"), hence the
        # ELLIPSIS in the doctests above.
        return self.dot(other) / other.dot(other)

    def rotated(self, angle, direction=None):
        """Returns the Vector rotated by the given angle.

        If two-dimensional, it's rotated counterclockwise by the given angle.
        If three-dimensional, it's rotated counterclockwise by the given angle
        around the given direction vector.

        Examples:
        >>> Vector((1,2)).rotated(pi/2) == Vector((-2,1))
        True
        >>> Vector((0,0,1)).rotated(pi/2, Vector((1,0,0))) == Vector((0,-1,0))
        True
        >>> Vector((1,1,1)).rotated(pi, Vector((-1,-1,0)).unit_vector()) == Vector((1,1,-1))
        True
        """
        if len(self) == 2:
            c, s = cos(angle), sin(angle)
            return Vector((c*self.x - s*self.y, c*self.y + s*self.x))
        elif len(self) == 3:
            if direction is None:
                raise ValueError("direction for 3D rotation not given")
            # Rodrigues' rotation: keep the part along the axis and rotate
            # the perpendicular part within the plane normal to the axis.
            parallel = self.component_parallel_to(direction)
            perpendicular = self - parallel
            crossed = direction.unit_vector().cross(perpendicular)
            return parallel + perpendicular*cos(angle) + crossed*sin(angle)
        else:
            raise ValueError("can only rotate 2D or 3D vectors")
class Position(numpy.ndarray):
    """A position in n-space, built for compatibility with Vectors.

    Position - Position gives a Vector. Position + Vector and
    Position - Vector give a Position. Positions compare equal when every
    coordinate matches within floateq's default tolerance.
    """

    def __new__(cls, coords, polar=False):
        # Allocate an (uninitialized) float array of the right length;
        # __init__ fills in the values. The builtin ``float`` is the dtype
        # because NumPy 1.24 removed the ``numpy.float`` alias for it.
        return numpy.ndarray.__new__(cls, dtype=float,
                                     shape=(len(coords),))

    def __init__(self, coords, polar=False):
        """Fill in the coordinates.

        With ``polar=True``, ``coords`` is ``(r, theta)`` in 2d or
        ``(r, theta, phi)`` in 3d. It is converted with
        polar_to_rectangular(). Any other length raises ValueError.
        """
        if polar:
            if len(coords) in (2, 3):
                self[:] = polar_to_rectangular(coords)
            else:
                raise ValueError("can only use polar coordinates in 2d or 3d")
        else:
            self[:] = coords[:]

    def __str__(self):
        # e.g. Position(1.0, 2.0). tolist() yields plain Python floats;
        # iterating self directly yields NumPy scalars, which NumPy >= 2
        # prints as "np.float64(1.0)".
        return "Position{0}".format(tuple(self.tolist()))

    def __repr__(self):
        # e.g. Position([1.0, 2.0]). This matches Vector's repr form and
        # evaluates back to an equal Position. The old form embedded
        # NumPy's "array([...])" repr.
        return "Position({0!r})".format(self.tolist())

    def __hash__(self):
        # NOTE(review): __eq__ is tolerance-based, so Positions that compare
        # equal can still hash differently. Positions are also mutable, so
        # don't mutate one used as a set member or dict key.
        return hash(sum(self))

    def __eq__(self, other):
        """Return whether every coordinate is within floateq's default
        tolerance of the matching coordinate of ``other``.

        Positions of different lengths are never equal. Non-Positions get
        NotImplemented.
        """
        if not isinstance(other, Position):
            return NotImplemented
        # Guard against NumPy broadcasting a 1-element Position against a
        # longer one (which would otherwise report equality).
        if len(self) != len(other):
            return False
        return bool(floateq(self.view(numpy.ndarray),
                            other.view(numpy.ndarray)).all())

    def __ne__(self, other):
        # ndarray's own __ne__ is elementwise; derive a single bool from ==.
        return not self == other

    def __add__(self, other):
        """Return this Position displaced by a Vector.

        Any other operand type gives NotImplemented.
        """
        if isinstance(other, Vector):
            return Position(self.view(numpy.ndarray) +
                            other.view(numpy.ndarray))
        return NotImplemented

    def __radd__(self, other):
        # NOTE(review): for ``Vector + Position`` the left operand's
        # ndarray.__add__ appears to run first, without deferring here, so
        # the result is likely a Vector rather than a Position. Confirm
        # before relying on it.
        return self.__add__(other)

    def __sub__(self, other):
        """Returns a Vector or Position subtracted from the Position.

        Examples:
        >>> assert Position((1,2,3)) - Vector((-1,-2,-3)) == Position((2,4,6))
        >>> assert Position((2,4,6)) - Position((1,2,3)) == Vector((1,2,3))
        """
        if isinstance(other, Position):
            return Vector(self.view(numpy.ndarray) - other.view(numpy.ndarray))
        if isinstance(other, Vector):
            return self + (-other)
        return NotImplemented

    def moved(self, displacement):
        """Return the Position displaced by the given Vector."""
        return self+displacement

    def scaled(self, factor, center):
        """Return the Position scaled by ``factor`` away from ``center``."""
        return center + (self-center)*factor

    def rotated(self, angle, center, direction=None):
        """Returns the point rotated by an angle around a center/axis.

        In 2d, ``direction`` must be omitted. In 3d it gives the rotation
        axis through ``center``.

        Examples:
        >>> assert Position((0,4)).rotated(pi/2, Position((0,0))) == Position((-4,0))
        >>> assert Position((0,4)).rotated(pi/2, Position((0,2))) == Position((-2,2))
        >>> assert Position((1,2,4)).rotated(pi, Position((0,0,0)), Vector((0,0,2))) == Position((-1,-2,4))
        >>> assert Position((0,0,4)).rotated(pi, Position((0,0,2)), Vector((0,1,1))) == Position((0,2,2))
        """
        if len(self) == 2:
            if direction is not None:
                raise ValueError("direction only needed for 3d rotations")
            return center + (self-center).rotated(angle)
        elif len(self) == 3:
            if direction is None:
                raise ValueError("direction is needed for 3d rotations")
            return center + (self-center).rotated(angle, direction)
        raise ValueError("only 2d and 3d positions can be rotated")

    def _getx(self):
        return self[0]

    def _setx(self, x):
        self[0] = x

    def _gety(self):
        return self[1]

    def _sety(self, y):
        self[1] = y

    def _getz(self):
        return self[2]

    def _setz(self, z):
        self[2] = z

    # Read/write access to the first three coordinates.
    x = property(_getx, _setx)
    y = property(_gety, _sety)
    z = property(_getz, _setz)

    @property
    def r(self):
        """Euclidean distance from the origin."""
        return sqrt(sum(coord**2 for coord in self))

    @property
    def theta(self):
        """2d: angle from the +x axis. 3d: polar angle from the +z axis."""
        if len(self) == 2:
            return atan2(self.y, self.x)
        elif len(self) == 3:
            return atan2(sqrt(self.x**2 + self.y**2), self.z)
        else:
            raise ValueError("theta attribute only valid for 2d or 3d positions")

    @property
    def phi(self):
        """3d only: azimuthal angle of the (x, y) projection from +x."""
        if len(self) == 3:
            return atan2(self.y, self.x)
        else:
            raise ValueError("phi attribute only defined for 3d positions")
def collinear(p1, p2, *others):
    """Returns whether all given Positions (at least 2) are collinear.

    Points coinciding with p1 (within floateq's tolerance on squared
    distance) place no constraint on the line. In particular, any two
    points are collinear.

    Examples:
    >>> assert collinear(Position((1,0,0)), Position((0,1,0)),
    ...                  Position((-1,2,0)), Position((-2, 3, 0)))
    >>> assert not collinear(Position((0,0)), Position((0,0)),
    ...                      Position((1,0)), Position((0,1)))
    """
    # Displacement of every other point from p1.
    displacements = [p - p1 for p in (p2,) + others]
    # Take the line's direction from the first displacement that is not
    # (numerically) zero. A zero direction would make the test below pass
    # for every point, so a repeated first point followed by three
    # non-collinear points would wrongly be reported collinear.
    direction = None
    for d in displacements:
        if not floateq(d.dot(d), 0):
            direction = d
            break
    if direction is None:
        # Every point coincides with p1.
        return True
    direction_sq = direction.dot(direction)
    # Cauchy-Schwarz: (a.b)**2 == |a|**2 * |b|**2 exactly when a and b are
    # parallel. Comparing squares avoids square roots.
    return all(floateq(direction.dot(d)**2, direction_sq * d.dot(d))
               for d in displacements)
def dist(p1, p2):
    """Return the straight-line (Euclidean) distance between two Positions."""
    separation = p1 - p2
    return separation.magnitude()
def polar_to_rectangular(coordinates):
    """Convert polar (2d) or spherical (3d) coordinates to a rectangular tuple.

    2d input is ``(r, theta)`` with theta measured from the +x axis.
    3d input is ``(r, theta, phi)`` with theta the polar angle from +z and
    phi the azimuth from +x. Any other length raises ValueError.
    """
    dimension = len(coordinates)
    if dimension not in (2, 3):
        raise ValueError("polar coordinates are only valid in 2d or 3d")
    if dimension == 2:
        radius, angle = coordinates
        return (radius*cos(angle), radius*sin(angle))
    radius, theta, phi = coordinates
    # Length of the projection onto the xy-plane.
    planar = radius*sin(theta)
    return (planar*cos(phi), planar*sin(phi), radius*cos(theta))
if __name__ == '__main__':
    import unittest
    import doctest

    # Single-argument print(...) behaves identically on Python 2 and 3. The
    # bare ``print "..."`` statement is a SyntaxError on Python 3 that
    # stops the whole module from compiling.
    print("Testing docstrings...")
    doctest.testmod()
    print("Finished testing docstrings.")

    class TestVectors(unittest.TestCase):
        def setUp(self):
            # Several tests below expect ArithmeticError from 0/0, which
            # NumPy only raises under seterr(all='raise'). Turn that on
            # explicitly, instead of depending on the doctests above having
            # left it on, and restore the previous settings afterwards.
            self._saved_numpy_err = numpy.seterr(all='raise')

        def tearDown(self):
            numpy.seterr(**self._saved_numpy_err)

        def test_add(self):
            self.assertEqual(Vector((1,)) + Vector((2,)),
                             Vector((3,)))
            self.assertEqual(Vector((1, 4.1)) + Vector((-5.3, 3)),
                             Vector((-4.3, 7.1)))
            self.assertEqual(Vector((1, 1, 1)) + Vector((-1, -1, -1)),
                             Vector((0, 0, 0)))

        def test_subtract(self):
            self.assertEqual(Vector((1,)) - Vector((2,)),
                             Vector((-1,)))
            self.assertEqual(Vector((1, 4.1)) - Vector((-5.3, 3)),
                             Vector((6.3, 1.1)))
            self.assertEqual(Vector((1, 1, 1)) - Vector((1, 1, 1)),
                             Vector((0, 0, 0)))

        def test_dot(self):
            self.assertAlmostEqual(Vector((1, 2, 3)).dot(Vector((0, -2, 1))), -1)
            self.assertAlmostEqual(Vector((3,)).dot(Vector((-1.5,))), -4.5)

        def test_cross(self):
            self.assertEqual(Vector((1, 2, 3)).cross(Vector((0, -2, 1))),
                             Vector((8, -1, -2)))
            with self.assertRaises(ValueError):
                Vector((2, 3)).cross(Vector((0, 1)))

        def test_magnitude(self):
            self.assertAlmostEqual(Vector((1, -5, 4)).magnitude(),
                                   sqrt(1 + (-5)**2 + 4**2))
            self.assertAlmostEqual(Vector((0, 0, 0)).magnitude(), 0)

        def test_unit_vector(self):
            self.assertEqual(Vector((5, 3, -20)).unit_vector(),
                             Vector((5, 3, -20))/Vector((5, 3, -20)).magnitude())
            with self.assertRaises(ArithmeticError):
                Vector((0, 0)).unit_vector()

        def test_trig(self):
            v1 = Vector((1, 0, 0))
            v2 = Vector((0, 1, 0))
            v3 = Vector((0, 1, 0))
            v4 = Vector((1, 0))
            v5 = Vector((1, 1))
            self.assertAlmostEqual(v1.cos_with(v2), 0)
            self.assertAlmostEqual(v1.sin_with(v2), 1)
            self.assertAlmostEqual(v1.cos_with(v3), 0)
            self.assertAlmostEqual(v1.sin_with(v3), 1)
            self.assertAlmostEqual(v4.cos_with(v5), sqrt(2)/2)
            self.assertAlmostEqual(v4.sin_with(v5), sqrt(2)/2)
            with self.assertRaises(ValueError):
                v1.cos_with(v4)
            with self.assertRaises(ArithmeticError):
                Vector((1, 0)).cos_with(Vector((0, 0)))
            with self.assertRaises(ArithmeticError):
                Vector((0, 0)).cos_with(Vector((1, 0)))

        def test_components(self):
            v1 = Vector((1, 1))
            v2 = Vector((3, 0))
            self.assertEqual(v1.component_parallel_to(v2), Vector((1, 0)))
            self.assertEqual(v2.component_parallel_to(v1), Vector((1.5, 1.5)))
            self.assertEqual(v1.component_perpendicular_to(v2), Vector((0, 1)))
            self.assertEqual(v2.component_perpendicular_to(v1),
                             Vector((1.5, -1.5)))
            with self.assertRaises(ValueError):
                v1.component_parallel_to(Vector((1, 2, 3)))
            with self.assertRaises(ArithmeticError):
                v1.component_parallel_to(Vector((0, 0)))

        def test_multiples(self):
            v0 = Vector((0, 0, 0))
            v1 = Vector((1, 0, 0))
            v2 = Vector((2, 3, 4))
            v3 = Vector((4, 6, 8))
            self.assertFalse(v3.is_scalar_multiple_of(v1))
            self.assertTrue(v3.is_scalar_multiple_of(v2))
            self.assertTrue(v2.is_scalar_multiple_of(v3))
            self.assertTrue(v0.is_scalar_multiple_of(v1))
            self.assertFalse(v1.is_scalar_multiple_of(v0))
            self.assertTrue(v0.is_scalar_multiple_of(v0))
            self.assertAlmostEqual(v3.get_scalar_multiple_of(v2), 2)
            self.assertAlmostEqual(v2.get_scalar_multiple_of(v3), .5)
            self.assertAlmostEqual(v2.get_scalar_multiple_of(v2), 1)
            self.assertAlmostEqual(v0.get_scalar_multiple_of(v1), 0)
            with self.assertRaises(FloatingPointError):
                v1.get_scalar_multiple_of(v0)
            with self.assertRaises(FloatingPointError):
                v0.get_scalar_multiple_of(v0)

        def test_rotated(self):
            self.assertEqual(Vector((1, 2)).rotated(pi/2), Vector((-2, 1)))
            self.assertEqual(Vector((0, 0, 1)).rotated(pi/2,
                                                       Vector((1, 0, 0))),
                             Vector((0, -1, 0)))
            self.assertEqual(
                Vector((1, 1, 1)).rotated(pi, Vector((-1, -1, 0))),
                Vector((1, 1, -1)))
            with self.assertRaises(ValueError):
                Vector((1, 1, 1)).rotated(pi)
            with self.assertRaises(ValueError):
                Vector((1, 1, 1)).rotated(pi, Vector((1, 1)))

        def test_conversions(self):
            # This test previously built these vectors but asserted nothing.
            v1 = Vector((1, 0, 0))
            v2 = Vector((0, 1, 0))
            v3 = Vector((0, 1, 1))
            v4 = Vector((0, 1, -1))
            # (vector, expected r, expected theta, expected phi)
            cases = [
                (v1, 1, pi/2, 0),
                (v2, 1, pi/2, pi/2),
                (v3, sqrt(2), pi/4, pi/2),
                (v4, sqrt(2), 3*pi/4, pi/2),
            ]
            for v, r, theta, phi in cases:
                self.assertAlmostEqual(v.r, r)
                self.assertAlmostEqual(v.theta, theta)
                self.assertAlmostEqual(v.phi, phi)
                # Converting back from spherical form recovers the vector.
                self.assertEqual(Vector((v.r, v.theta, v.phi), polar=True), v)
            self.assertEqual(Vector((sqrt(2), pi/4), polar=True),
                             Vector((1, 1)))

    class TestPosition(unittest.TestCase):
        def test_equality(self):
            self.assertEqual(Position((1, 4, -1)), Position((1, 4, -1)))
            self.assertNotEqual(Position((1, 4, -1)), Position((1, 4, -2)))

        def test_rotated(self):
            p2 = Position((1, 2))
            p3 = Position((1, 2, 1))
            self.assertEqual(p2.rotated(pi/2, Position((0, 0))),
                             Position((-2, 1)))
            self.assertEqual(p2.rotated(pi, Position((0, 1))),
                             Position((-1, 0)))
            self.assertEqual(p3.rotated(pi, Position((0, 1, 0)),
                                        Vector((1, 0, 1))),
                             Position((1, 0, 1)))

        def test_collinear(self):
            self.assertTrue(collinear(Position((1,)), Position((5,)), Position((99,))))
            self.assertTrue(collinear(Position((1, 0, 1)),
                                      Position((2, 1, 2)),
                                      Position((-5, -6, -5))))
            self.assertFalse(collinear(Position((0, 0, 0)),
                                       Position((10, -10, 5)),
                                       Position((100, -100, 51))))

        def test_dist(self):
            self.assertAlmostEqual(dist(Position((1, 2, 3)),
                                        Position((4, -1, -4))),
                                   sqrt(3**2 + 3**2 + 7**2))
            self.assertAlmostEqual(dist(Position((1,)), Position((-410,))),
                                   411)

    unittest.main()
@Baxayesh
Copy link

It was helpful for me, Thank you❤️❤️

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment