Last active
November 8, 2019 06:21
-
-
Save OlegZero13/8e2ca12e67ffd9b4bf020cc1a9b4f215 to your computer and use it in GitHub Desktop.
Model of a quaternion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import sin, cos, atan2, asin, sqrt
class Quaternion:
    """A quaternion ``w + x*i + y*j + z*k`` with real coefficients.

    Orientation helpers convert to and from yaw/pitch/roll angles in radians,
    using the Z-Y-X convention: yaw about z, pitch about y, roll about x.

    Instances define ``__eq__`` without ``__hash__`` and are therefore
    unhashable.
    """

    def __init__(self, w: float, x: float, y: float, z: float) -> None:
        self.w = w
        self.x = x
        self.y = y
        self.z = z

    @classmethod
    def create_from_ypr(cls, yaw: float, pitch: float, roll: float) -> "Quaternion":
        """Build the unit quaternion for the given yaw, pitch and roll (radians)."""
        return cls(*cls._ypr_to_coords(yaw, pitch, roll))

    @classmethod
    def create_identity(cls) -> "Quaternion":
        """Return the multiplicative identity ``1 + 0i + 0j + 0k``."""
        return cls(1, 0, 0, 0)

    def coords(self) -> tuple:
        """Return the coefficients as a ``(w, x, y, z)`` tuple."""
        return self.w, self.x, self.y, self.z

    @staticmethod
    def _ypr_to_coords(yaw: float, pitch: float, roll: float) -> tuple:
        """Convert yaw/pitch/roll angles (radians) to ``(w, x, y, z)``."""
        # Half-angle sines/cosines live in dedicated names so that computing
        # the output component ``y`` cannot clobber the half-yaw that ``z``
        # still needs.
        cy, sy = cos(0.5 * yaw), sin(0.5 * yaw)
        cp, sp = cos(0.5 * pitch), sin(0.5 * pitch)
        cr, sr = cos(0.5 * roll), sin(0.5 * roll)
        w = cy * cp * cr + sy * sp * sr
        x = cy * cp * sr - sy * sp * cr
        y = sy * cp * sr + cy * sp * cr
        z = sy * cp * cr - cy * sp * sr
        return w, x, y, z

    @staticmethod
    def _coords_to_ypr(w: float, x: float, y: float, z: float) -> tuple:
        """Convert coefficients to ``(yaw, pitch, roll)`` in radians.

        This is the inverse of ``_ypr_to_coords`` for unit quaternions. Pitch
        is limited to ``[-pi/2, pi/2]`` by ``asin``.
        """
        # Roll: rotation about x.
        t0 = 2.0 * (w * x + y * z)
        t1 = 1.0 - 2.0 * (x * x + y * y)
        roll = atan2(t0, t1)
        # Pitch: rotation about y. Clamp so rounding error near gimbal lock
        # cannot push the asin argument outside [-1, 1].
        t2 = 2.0 * (w * y - z * x)
        t2 = 1.0 if t2 > 1.0 else t2
        t2 = -1.0 if t2 < -1.0 else t2
        pitch = asin(t2)
        # Yaw: rotation about z.
        t3 = 2.0 * (w * z + x * y)
        t4 = 1.0 - 2.0 * (y * y + z * z)
        yaw = atan2(t3, t4)
        return yaw, pitch, roll

    def __repr__(self):
        return "Quaternion({}, {}, {}, {})".format(
            self.w, self.x, self.y, self.z)

    def __str__(self):
        return "Q = {:.2f} + {:.2f}i + {:.2f}j + {:.2f}k".format(
            self.w, self.x, self.y, self.z)

    def __add__(self, other: "Quaternion") -> "Quaternion":
        """Coefficient-wise sum of two quaternions."""
        if not isinstance(other, Quaternion):
            return NotImplemented
        return Quaternion(*(a + b for a, b in zip(self.coords(), other.coords())))

    def __sub__(self, other: "Quaternion") -> "Quaternion":
        """Coefficient-wise difference of two quaternions."""
        if not isinstance(other, Quaternion):
            return NotImplemented
        return Quaternion(*(a - b for a, b in zip(self.coords(), other.coords())))

    def __mul__(self, other):
        """Hamilton product with a quaternion, or scaling by an int/float.

        Raises:
            TypeError: if ``other`` is neither a Quaternion nor an int/float.
        """
        if isinstance(other, Quaternion):
            w = self.w * other.w - self.x * other.x - self.y * other.y - self.z * other.z
            x = self.w * other.x + self.x * other.w + self.y * other.z - self.z * other.y
            y = self.w * other.y + self.y * other.w + self.z * other.x - self.x * other.z
            z = self.w * other.z + self.z * other.w + self.x * other.y - self.y * other.x
            return Quaternion(w, x, y, z)
        if isinstance(other, (int, float)):
            return Quaternion(*(other * c for c in self.coords()))
        raise TypeError("Operation undefined.")

    def __rmul__(self, other):
        """Scalar-on-the-left scaling (``2 * q``).

        Raises:
            TypeError: if ``other`` is not an int/float.
        """
        if isinstance(other, (int, float)):
            return Quaternion(*(other * c for c in self.coords()))
        raise TypeError("Operation undefined.")

    def __matmul__(self, other: "Quaternion") -> "Quaternion":
        """Element-wise (Hadamard) product of the coefficients: ``q1 @ q2``."""
        return Quaternion(*(a * b for a, b in zip(self.coords(), other.coords())))

    def __eq__(self, other):
        """Return True when all four coefficients are exactly equal.

        Use ``almost_equal`` for floating-point results.
        """
        if not isinstance(other, Quaternion):
            return NotImplemented
        return self.coords() == other.coords()

    def almost_equal(self, other: "Quaternion", eps: float = 1e-1) -> bool:
        """Return True if every coefficient differs by less than ``eps``.

        Note that the default tolerance (0.1) is very loose.
        """
        return all(abs(a - b) < eps for a, b in zip(self.coords(), other.coords()))

    def norm(self) -> float:
        """Euclidean norm ``sqrt(w^2 + x^2 + y^2 + z^2)``."""
        return sqrt(sum(c ** 2 for c in self.coords()))

    def conjugate(self) -> "Quaternion":
        """Return ``w - x*i - y*j - z*k``."""
        return Quaternion(self.w, -self.x, -self.y, -self.z)

    def normalize(self) -> "Quaternion":
        """Return this quaternion scaled to unit norm.

        Raises:
            ZeroDivisionError: for the zero quaternion.
        """
        n = self.norm()  # computed once instead of once per coefficient
        return Quaternion(*(c / n for c in self.coords()))

    def inverse(self) -> "Quaternion":
        """Return the multiplicative inverse ``conj(q) / |q|**2``.

        Raises:
            ZeroDivisionError: for the zero quaternion.
        """
        # Divide by the *squared* norm; dividing by |q| is only correct for
        # unit quaternions.
        n2 = sum(c ** 2 for c in self.coords())
        return Quaternion(*(c / n2 for c in self.conjugate().coords()))

    def dot(self, other: "Quaternion") -> "Quaternion":
        """Alias for ``self @ other`` (element-wise coefficient product).

        NOTE(review): despite the name this returns a Quaternion, not the
        scalar 4-D dot product; ``sum(self.dot(other).coords())`` gives that.
        """
        return self.__matmul__(other)

    def distance(self, other: "Quaternion") -> "Quaternion":
        """Return ``normalize(self) * normalize(other).inverse()``.

        The result is the identity when both operands normalize to the same
        quaternion.
        """
        return self.normalize() * other.normalize().inverse()
if __name__ == '__main__':
    # Self-checks for the Quaternion class. Floating-point results are
    # compared with an explicit tight tolerance rather than exact equality.
    tol = 1e-9
    identity = Quaternion.create_identity()

    q1 = Quaternion.create_from_ypr(0.1, 0.2, 0.3)
    ypr2 = Quaternion._coords_to_ypr(*q1.coords())
    q2 = Quaternion.create_from_ypr(*ypr2)

    # yaw/pitch/roll -> quaternion -> yaw/pitch/roll round-trips.
    assert all(abs(a - b) < tol for a, b in zip(ypr2, (0.1, 0.2, 0.3)))
    assert q1.almost_equal(q2, eps=tol)

    # Conjugation is an involution, and a non-real quaternion is not equal
    # to its own conjugate.
    assert q1 == q1.conjugate().conjugate()
    assert q1 != q1.conjugate()

    assert q1.inverse().inverse().almost_equal(q1, eps=tol)
    assert q1.distance(q1).almost_equal(identity, eps=tol)
    assert q1.distance(q2).almost_equal(q2.distance(q1), eps=tol)
    # For unit q: conj(q) * q^-1 == conj(q)^2 == conj(q * conj(q)^-1).
    assert q1.conjugate().distance(q1).almost_equal(
        q1.distance(q1.conjugate()).conjugate(), eps=tol)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment