Skip to content

Instantly share code, notes, and snippets.

@7shi
Last active June 25, 2024 14:33
Show Gist options
  • Save 7shi/912ad67f3edd07278b61ecb2fb0acf26 to your computer and use it in GitHub Desktop.
Save 7shi/912ad67f3edd07278b61ecb2fb0acf26 to your computer and use it in GitHub Desktop.
[py] calculate quaternions and octonions
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
import sys
def is_number(n):
    """Return True when ``n`` is a plain int or float (bool counts, being an int subclass)."""
    return isinstance(n, int) or isinstance(n, float)
def conv_n(n):
    """Wrap a plain int/float in ``N``; return any other value unchanged."""
    if not is_number(n):
        return n
    return N(n)
def conv_ns(ns):
    """Apply ``conv_n`` to every element of ``ns`` and return the results as a list."""
    return list(map(conv_n, ns))
class N:
    """Leaf node of the expression tree: a number or a symbol name.

    ``value`` holds whatever was passed in (e.g. an int/float for numbers,
    a str for symbols created by ``syms``). The ``+``, ``-`` and ``*``
    operators build ``Sum``/``Product`` nodes, wrapping plain Python numbers
    on either side with ``conv_n``.
    """

    def __init__(self, value):
        self.value = value

    # --- textual forms -------------------------------------------------
    def __str__(self):
        # Debug form such as N('x') or N(2).
        return "N(" + repr(self.value) + ")"

    def __repr__(self):
        # Human-readable infix form.
        return to_string(self)

    # --- addition / subtraction ----------------------------------------
    def __add__(self, other):
        return Sum(self, conv_n(other))

    def __radd__(self, other):
        return Sum(conv_n(other), self)

    def __sub__(self, other):
        negated = -conv_n(other)
        return self + negated

    def __rsub__(self, other):
        return conv_n(other) + (-self)

    # --- negation / multiplication -------------------------------------
    def __neg__(self):
        # Builds Product(N(-1), self) via __rmul__.
        return -1 * self

    def __mul__(self, other):
        return Product(self, conv_n(other))

    def __rmul__(self, other):
        return Product(conv_n(other), self)
class Sum:
    """N-ary sum node holding its terms in ``values`` (a list).

    Accepts either ``Sum(term1, term2, ...)`` or a single iterable of terms,
    e.g. ``Sum(t for t in ...)``. Adding another Sum concatenates the terms
    instead of nesting.
    """

    def __init__(self, *values):
        single_iterable = len(values) == 1 and hasattr(values[0], '__iter__')
        self.values = list(values[0] if single_iterable else values)

    def __str__(self):
        return "Sum(" + ", ".join(str(v) for v in self.values) + ")"

    def __repr__(self):
        return to_string(self)

    def __add__(self, other):
        # Splice the terms of another Sum; otherwise append a single term.
        tail = other.values if isinstance(other, Sum) else [conv_n(other)]
        return Sum(*self.values, *tail)

    def __radd__(self, other):
        head = other.values if isinstance(other, Sum) else [conv_n(other)]
        return Sum(*head, *self.values)

    def __neg__(self):
        # Builds Product(N(-1), self) via __rmul__.
        return -1 * self

    def __sub__(self, other):
        negated = -conv_n(other)
        return self + negated

    def __rsub__(self, other):
        return conv_n(other) + (-self)

    def __mul__(self, other):
        return Product(self, conv_n(other))

    def __rmul__(self, other):
        return Product(conv_n(other), self)
class Product:
    """N-ary product node holding its factors, in order, in ``values``.

    Factor order is significant, because the quaternion/octonion rules in this
    file are non-commutative (e.g. I*J -> K but J*I -> -K). Accepts either
    ``Product(f1, f2, ...)`` or a single iterable of factors. Multiplying by
    another Product concatenates the factors instead of nesting.
    """

    def __init__(self, *values):
        if len(values) == 1 and hasattr(values[0], '__iter__'):
            self.values = list(values[0])
        else:
            self.values = list(values)

    def __str__(self):
        return f"Product({', '.join(map(str, self.values))})"

    def __repr__(self):
        return to_string(self)

    def __add__(self, other):
        return Sum(self, conv_n(other))

    def __radd__(self, other):
        return Sum(conv_n(other), self)

    def __neg__(self):
        # -(-1 * rest) simplifies to rest instead of stacking another -1.
        if self.values and eq_expr(self.values[0], N(-1)):
            rest = self.values[1:]
            # -(-1) is 1; an empty Product would render as "" and expand to 0.
            return Product(*rest) if rest else N(1)
        return -1 * self

    def __sub__(self, other):
        return self + (-conv_n(other))

    def __rsub__(self, other):
        return conv_n(other) + (-self)

    def __mul__(self, other):
        if isinstance(other, Product):
            return Product(*self.values, *other.values)
        return Product(*self.values, conv_n(other))

    def __rmul__(self, other):
        if isinstance(other, Product):
            return Product(*other.values, *self.values)
        return Product(conv_n(other), *self.values)
def syms(symbols_str):
    """Create one symbol ``N(name)`` per comma-separated name and return them as a tuple.

    Names are used verbatim (surrounding whitespace is not stripped).
    """
    return tuple(map(N, symbols_str.split(",")))
# Define a..z as N("a"), ..., N("z") so single-letter symbols can be used
# directly in expressions, including strings passed to eval() such as
# "(a + b) * (c + d)".
a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z = syms("a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z")
def to_string(expr):
    """Render an expression tree as human-readable infix text.

    - N: ``str`` of its value.
    - Sum: terms joined with " + "; once text has been emitted, a term whose
      text starts with "-" is joined with " - " instead; nested Sums are
      always parenthesised.
    - Product: factors joined with " * "; a leading N(-1) followed by more
      factors becomes a "-" prefix; Sum/Product factors are parenthesised, as
      is any factor whose text starts with "-" once something (including the
      "-" prefix) has already been emitted.
    - Anything else: ``str(expr)``.
    """
    if isinstance(expr, N):
        return str(expr.value)
    elif isinstance(expr, Sum):
        ret = ""
        for term in expr.values:
            s = to_string(term)
            if isinstance(term, Sum):
                if ret:
                    ret += " + "
                ret += f"({s})"
            elif ret and s.startswith("-"):
                # Turn "a + -b" into "a - b".
                ret += f" - {s[1:]}"
            else:
                if ret:
                    ret += " + "
                ret += s
        return ret
    elif isinstance(expr, Product):
        factors = expr.values
        ret = ""
        # Render (-1) * x * ... as "-x * ...".
        if len(factors) > 1 and eq_expr(factors[0], N(-1)):
            ret += "-"
            factors = factors[1:]
        for i, factor in enumerate(factors):
            if i:
                ret += " * "
            f = to_string(factor)
            # Parenthesise compound factors and negatives that would otherwise
            # read ambiguously, e.g. "K * (-1) * K".
            if isinstance(factor, (Sum, Product)) or (ret and f.startswith("-")):
                ret += f"({f})"
            else:
                ret += f
        return ret
    return str(expr)
# Test function
def test_to_string():
    """Check to_string on numbers, sums and (nested) products."""
    print("Testing improved to_string function:")
    cases = (
        # (expression, expected text, failure message)
        (N(5), "5", "Basic number test failed"),
        (Product(N(2), N('x')), "2 * x", "Simple product test failed"),
        (Product(N(2), Product(N('x'), N('y'))), "2 * (x * y)", "Nested product test failed"),
        (Product(N(2), Product(N('x'), N('y')), Product(N('a'), N('b'))),
         "2 * (x * y) * (a * b)", "Multiple nested products test failed"),
        (Sum(N(1), N(2)), "1 + 2", "Simple sum test failed"),
        (Sum(N(1), Product(N(2), N('x'))), "1 + 2 * x", "Sum with product test failed"),
        (Product(N(2), Sum(N('x'), N('y'))), "2 * (x + y)", "Product with sum test failed"),
    )
    for node, expected, message in cases:
        assert to_string(node) == expected, message
    print("All to_string tests passed!")
def product_to_sum(xs, ys):
    """Multiply every x in ``xs`` by every y in ``ys``.

    Returns a list of Products, one per (x, y) pair in ``itertools.product``
    order. Product operands are spliced in as their factors, so the results
    are flat.
    """
    from itertools import product as cartesian

    def factors_of(node):
        # A Product contributes its factors; anything else is a single factor.
        return list(node.values) if isinstance(node, Product) else [node]

    return [Product(*factors_of(x), *factors_of(y)) for x, y in cartesian(xs, ys)]
def expand(expr):
    """Distribute products over sums, recursively.

    - Sum: each term is expanded.
    - Product: each factor is expanded, then the factors are multiplied out
      with ``product_to_sum``. Returns a Sum of Products, or the single
      resulting term when there is only one.
    - Anything else is returned unchanged.

    An empty Sum factor (zero) makes the whole product an empty Sum.
    """
    if isinstance(expr, Sum):
        return Sum(expand(term) for term in expr.values)
    if isinstance(expr, Product):
        # None means "no factor seen yet". It must be kept distinct from [],
        # which is the (zero) result of multiplying by an empty Sum.
        sums = None
        for factor in expr.values:
            factor = expand(factor)
            terms = factor.values if isinstance(factor, Sum) else [factor]
            sums = list(terms) if sums is None else product_to_sum(sums, terms)
        if sums is None:
            # NOTE(review): an empty product is mathematically 1; an empty Sum
            # is returned here to keep existing results unchanged.
            return Sum()
        return Sum(*sums) if len(sums) != 1 else sums[0]
    return expr
def run_tests1():
    """Evaluate sample expressions, expand and flatten them, and print pass/fail marks.

    The expression strings are eval()'d, so they see the module-level symbols
    (a..z, N). Locals here deliberately avoid single-letter names so that eval
    does not pick them up instead of those symbols.
    """
    cases = (
        ('N(2) * (N(3) + N(4))', '2 * 3 + 2 * 4'),
        ('(N(1) + N(2)) * (N(3) + N(4))', '1 * 3 + 1 * 4 + 2 * 3 + 2 * 4'),
        ('N(2) * N(3) * (N(4) + N(5))', '2 * 3 * 4 + 2 * 3 * 5'),
        ('N(1) + N(2) * (N(3) + N(4))', '1 + 2 * 3 + 2 * 4'),
        ('(a + b) * (c + d)', 'a * c + a * d + b * c + b * d'),
        ('x * (N(1) + N(2))', 'x * 1 + x * 2'),
        ('a + b * (c + d)', 'a + b * c + b * d'),
        ('f * (g + h)', 'f * g + f * h'),
        ('N(1) * 2 + 3', '1 * 2 + 3'),
        ('2 * N(3) + 4', '2 * 3 + 4'),
        ('(N(1) + 2) * (3 + N(4))', '1 * 3 + 1 * 4 + 2 * 3 + 2 * 4'),
    )
    for source_text, wanted in cases:
        rendered = to_string(flatten(expand(eval(source_text))))
        mark = '✅' if rendered == wanted else '❌'
        print(f"{mark} {source_text} => {rendered}")
def flatten(expr):
    """Merge nested containers of the same kind, e.g. Product(a, Product(b, c)) -> Product(a, b, c).

    Works recursively; a container left with exactly one child is replaced by
    that child. Non-container nodes are returned unchanged.
    """
    if not isinstance(expr, (Sum, Product)):
        return expr
    kind = type(expr)
    merged = []
    for child in map(flatten, expr.values):
        if isinstance(child, kind):
            merged.extend(child.values)
        else:
            merged.append(child)
    return merged[0] if len(merged) == 1 else kind(*merged)
def eq_expr(expr1, expr2):
    """Structural equality of two expression trees.

    - Nodes of different types are unequal.
    - N nodes compare their ``value`` with ``==``.
    - Sum terms are compared as a multiset (order-insensitive).
    - Product factors are compared in order (products are non-commutative).
    - Any other objects fall back to ``==``.
    """
    if type(expr1) is not type(expr2):
        return False
    if isinstance(expr1, N):
        return expr1.value == expr2.value
    if isinstance(expr1, Sum):
        if len(expr1.values) != len(expr2.values):
            return False
        # Match every term of expr1 against a distinct, still-unmatched term of
        # expr2. Sorting by str() could pair the wrong terms when nested sums
        # list their terms in different orders.
        unmatched = list(expr2.values)
        for t1 in expr1.values:
            for idx, t2 in enumerate(unmatched):
                if eq_expr(t1, t2):
                    del unmatched[idx]
                    break
            else:
                return False
        return True
    if isinstance(expr1, Product):
        if len(expr1.values) != len(expr2.values):
            return False
        return all(eq_expr(f1, f2) for f1, f2 in zip(expr1.values, expr2.values))
    return expr1 == expr2
def forward(expr, *values):
    """Move selected factors to the front of every Product in ``expr``.

    With ``values`` given (plain numbers are wrapped in N), the factors equal
    to one of them are moved. Without ``values``, numeric N factors are moved.
    Moved factors keep their relative order. The remaining factors are
    processed recursively and keep their order. Sums are walked term by term.
    """
    targets = conv_ns(values)

    def wants_front(factor):
        if targets:
            return any(eq_expr(factor, target) for target in targets)
        return isinstance(factor, N) and is_number(factor.value)

    def walk(node):
        if isinstance(node, Sum):
            return Sum(walk(term) for term in node.values)
        if not isinstance(node, Product):
            return node
        front, rest = [], []
        for factor in node.values:
            if wants_front(factor):
                front.append(factor)
            else:
                rest.append(walk(factor))
        return Product(*front, *rest)

    return walk(expr)
def backward(expr, *values):
    """Move factors equal to any of ``values`` to the end of every Product in ``expr``.

    Plain numbers in ``values`` are wrapped in N. Moved factors keep their
    relative order, which matters for non-commuting units. Other factors are
    processed recursively and keep their order. Sums are walked term by term.
    """
    targets = conv_ns(values)

    def moved(factor):
        return any(eq_expr(factor, target) for target in targets)

    def walk(node):
        if isinstance(node, Sum):
            return Sum(walk(term) for term in node.values)
        if not isinstance(node, Product):
            return node
        keep, tail = [], []
        for factor in node.values:
            if moved(factor):
                tail.append(factor)
            else:
                keep.append(walk(factor))
        return Product(*keep, *tail)

    return walk(expr)
def find(nodes1, nodes2):
    """Return the first index at which ``nodes2`` occurs as a contiguous run in ``nodes1``.

    Elements are compared with ``eq_expr``. Returns -1 when there is no match;
    an empty ``nodes2`` matches at index 0.
    """
    width = len(nodes2)
    starts = (
        start
        for start in range(len(nodes1) - width + 1)
        if all(eq_expr(nodes1[start + offset], node) for offset, node in enumerate(nodes2))
    )
    return next(starts, -1)
def replace(expr, pattern, replacement):
    """Return ``expr`` with occurrences of ``pattern`` replaced by ``replacement``.

    - A node equal to ``pattern`` (per ``eq_expr``) is replaced as a whole.
    - When ``pattern`` is a Sum/Product, contiguous runs of its children inside
      a container of the same type are also replaced. If ``replacement`` is of
      that same type, its children are spliced in; otherwise it is inserted as
      one child.
    - Matched replacements are not searched again; all other children are
      processed recursively.
    - Containers left with a single child collapse to that child.

    Plain numbers for ``pattern`` or ``replacement`` are wrapped in N.
    """
    pattern = conv_n(pattern)
    replacement = conv_n(replacement)
    # Splice the replacement's children when it has the pattern's container type.
    same = isinstance(pattern, type(replacement))

    def rec(node):
        if eq_expr(node, pattern):
            return replacement
        if not isinstance(node, (Sum, Product)):
            return node
        items = node.values
        # An empty pattern would match everywhere without advancing (an
        # infinite loop), so only non-empty patterns are matched in sequences.
        if isinstance(node, type(pattern)) and pattern.values:
            width = len(pattern.values)
            last = len(items) - width
            result = []
            idx = 0
            while idx < len(items):
                if idx <= last and all(
                    eq_expr(items[idx + k], p) for k, p in enumerate(pattern.values)
                ):
                    if same:
                        result += replacement.values
                    else:
                        result.append(replacement)
                    idx += width
                else:
                    result.append(rec(items[idx]))
                    idx += 1
        else:
            result = [rec(item) for item in items]
        return type(node)(*result) if len(result) != 1 else result[0]

    return rec(expr)
# Test function
def test_replace():
    """Check replace() on simple, nested and splicing replacements."""
    I, J, K = syms("I,J,K")
    checks = (
        # (label, expression, pattern, replacement, expected text, failure message)
        # 1: simple replacement
        ("IJ replaced", Product(I, J), Product(I, J), K, "K", "Simple replacement failed"),
        # 2: replacing a nested Product
        ("I(JK)I replaced", Product(I, Product(J, K), I), Product(J, K), I,
         "I * I * I", "Nested Product flattening failed"),
        # 3: a Product replacement is spliced into the surrounding Product
        ("KJI double replaced", Product(K, J, I), Product(J, I), Product(N(-1), K),
         "K * (-1) * K", "Nested Product flattening failed"),
    )
    for label, source, pattern, substitute, expected, message in checks:
        rendered = to_string(replace(source, pattern, substitute))
        print(f"{label}: {rendered}")
        assert rendered == expected, message
    print("All replace tests passed!")
def run_tests2():
    """Check flatten, backward and replace on small hand-built expressions."""

    def check(before, after, label, expected, message):
        # Show the input and the transformed result, then verify the rendering.
        print(f"Original: {to_string(before)}")
        rendered = to_string(after)
        print(f"{label}: {rendered}")
        assert rendered == expected, message

    print("Running tests...")

    # flatten
    print("\nTesting flatten:")
    a, b, c = syms("a,b,c")
    nested = Product(a, Product(b, c))
    check(nested, flatten(nested), "Flattened", "a * b * c", "flatten test 1 failed")
    mixed = Sum(Product(a, Product(b, c)), Product(Product(a, b), c))
    check(mixed, flatten(mixed), "Flattened", "a * b * c + a * b * c", "flatten test 2 failed")

    # backward
    print("\nTesting back:")
    x, y, z = syms("x,y,z")
    unordered = Sum(Product(x, y, z), Product(y, x, z))
    check(unordered, backward(unordered, x), "Backed", "y * z * x + y * z * x", "back test failed")

    # replace: (x1 + y1*i)(x2 + y2*i) with i*i -> -1
    print("\nTesting replace:")
    i, x1, x2, y1, y2 = syms("i,x1,x2,y1,y2")
    prepared = backward(expand((x1 + y1 * i) * (x2 + y2 * i)), i)
    check(prepared, replace(prepared, Product(i, i), N(-1)), "Replaced",
          "x1 * x2 + x1 * y2 * i + y1 * x2 * i + y1 * y2 * (-1)", "replace test failed")

    print("\nAll tests passed!")
def collect(expr, *symbols, forward=False, forward_other=False):
    """Group the terms of every Sum in ``expr`` by the ``symbols`` they contain.

    Each Product term is split into the factors equal to one of ``symbols``
    and the remaining factors (its coefficient). Terms whose symbol factors
    render identically are merged into ``Product(Sum(coefficients), *symbols)``.
    A term made only of symbol factors gets coefficient N(1).

    Args:
        expr: expression tree (N / Sum / Product).
        *symbols: symbols to collect by; plain numbers are wrapped in N.
        forward: if True, put the symbol factors before the coefficient.
        forward_other: if True, put the non-collected terms before the
            collected groups instead of after them.

    Returns:
        The regrouped expression. Groups are ordered by the rendered text of
        their symbol factors, and a Sum left with one term collapses to it.
    """
    symbols = conv_ns(symbols)

    def is_symbol(factor):
        return any(eq_expr(factor, sym) for sym in symbols)

    def coefficient(factors):
        # An empty Product would render as "" or "()" and corrupt the result.
        if not factors:
            return N(1)
        return factors[0] if len(factors) == 1 else Product(*factors)

    def rec(node):
        if isinstance(node, Product):
            return Product(rec(factor) for factor in node.values)
        if not isinstance(node, Sum):
            return node
        collected = {}       # rendered symbol product -> list of coefficients
        collected_syms = {}  # rendered symbol product -> its symbol factors
        other_terms = []
        for term in node.values:
            term = rec(term)
            if not isinstance(term, Product):
                other_terms.append(term)
                continue
            picked, rest = [], []
            for factor in term.values:
                if is_symbol(factor):
                    picked.append(factor)
                else:
                    rest.append(factor)
            if not picked:
                other_terms.append(term)
                continue
            key = to_string(Product(*picked))
            if key not in collected:
                collected[key] = []
                collected_syms[key] = picked
            collected[key].append(coefficient(rest))
        result = list(other_terms) if forward_other else []
        for key in sorted(collected):
            coeffs = collected[key]
            ks = Sum(*coeffs) if len(coeffs) != 1 else coeffs[0]
            if forward:
                result.append(Product(*collected_syms[key], ks))
            else:
                result.append(Product(ks, *collected_syms[key]))
        if not forward_other:
            result += other_terms
        return Sum(*result) if len(result) != 1 else result[0]

    return rec(expr)
def test_collect():
    """Check collect() groups terms by the symbol i."""
    x1, x2, y1, y2, i = syms("x1,x2,y1,y2,i")
    cases = (
        # (expression, expected result)
        (
            Sum(Product(x1, x2), Product(x1, y2, i), Product(y1, x2, i), Product(y1, y2, N(-1))),
            Sum(Product(Sum(Product(x1, y2), Product(y1, x2)), i),
                Product(x1, x2), Product(y1, y2, N(-1))),
        ),
        (
            Sum(Product(N(2), x1, i), Product(N(3), x2, i), Product(N(4), x1), Product(N(5), x2)),
            Sum(Product(Sum(Product(N(2), x1), Product(N(3), x2)), i),
                Product(N(4), x1), Product(N(5), x2)),
        ),
    )
    for number, (source, expected) in enumerate(cases, start=1):
        heading = f"Test case {number}:"
        print(heading if number == 1 else "\n" + heading)
        print("Original expression:")
        print(to_string(source))
        grouped = collect(source, i)
        print("\nAfter collecting terms with respect to i:")
        print(to_string(grouped))
        assert to_string(grouped) == to_string(expected), f"Test case {number} failed. Expected {to_string(expected)}, but got {to_string(grouped)}"
        print(f"Test case {number} passed.")
    print("\nAll test cases passed successfully!")
def simplify(expr):
    """Expand ``expr`` and combine like terms.

    For each term of the expanded Sum, leading numeric factors are multiplied
    into a coefficient. Terms with equal residuals (per ``eq_expr``) have their
    coefficients added, and purely numeric terms combine into one constant.
    Terms whose coefficient ends up 0 are dropped.

    Returns N(0) when nothing is left, the single term when one is left, and
    otherwise a Sum. A non-Sum container has its children simplified instead.
    """

    def split_coefficient(term):
        # Returns (numeric coefficient, residual); the residual is None for a
        # purely numeric term and a bare node when one factor remains.
        if isinstance(term, N) and is_number(term.value):
            return term.value, None
        if not isinstance(term, Product):
            return 1, term
        coeff = 1
        factors = list(term.values)
        while factors and isinstance(factors[0], N) and is_number(factors[0].value):
            coeff *= factors.pop(0).value
        if not factors:
            return coeff, None
        if len(factors) == 1:
            # Unwrap so that c*x and x share the residual x.
            return coeff, factors[0]
        return coeff, Product(*factors)

    def same_residual(r1, r2):
        if r1 is None or r2 is None:
            return r1 is r2
        return eq_expr(r1, r2)

    expr = forward(flatten(expand(flatten(expand(expr)))))
    if isinstance(expr, Sum):
        groups = []  # [coefficient, residual] pairs, in first-seen order
        for term in expr.values:
            coeff, residual = split_coefficient(term)
            for group in groups:
                if same_residual(group[1], residual):
                    group[0] += coeff
                    break
            else:
                groups.append([coeff, residual])
        terms = [
            N(c) if r is None else r if c == 1 else c * r
            for c, r in groups
            if c
        ]
    elif hasattr(expr, "values"):
        terms = [simplify(x) for x in expr.values]
    else:
        return expr
    if not terms:
        return N(0)
    if len(terms) == 1:
        return terms[0]
    return type(expr)(*terms)
# Numeric clean-up rules copied into both the quaternion (qreps) and octonion
# (oreps) rule tables. Each entry is a (pattern, replacement) pair for
# replace(): 1*1 -> 1, 1*(-1) -> -1, (-1)*(-1) -> 1.
commons = [
    (N(1) * 1, N(1)),
    (N(1) * -1, N(-1)),
    (N(-1) * -1, N(1)),
]
# Define the quaternion basis units.
I, J, K = syms("I,J,K")
quaternions = [I, J, K]
# (pattern, replacement) rewrite rules for quaternions; init_q() appends the
# unit and cross-product rules after the shared numeric rules.
qreps = [*commons]
def init_q():
    """Append the quaternion rewrite rules to ``qreps``.

    For each unit q: 1*q -> q and q*q -> -1. Then the cyclic products
    IJ -> K, JK -> I, KI -> J, each followed by its reversed product with a
    -1 factor. Calling this again appends duplicate rules.
    """
    for unit in quaternions:
        qreps.extend([(N(1) * unit, unit), (unit * unit, N(-1))])
    for left, right, out in ((I, J, K), (J, K, I), (K, I, J)):
        qreps.extend([(left * right, out), (right * left, N(-1) * out)])
# Populate qreps at import time.
init_q()
def replace_q(expr):
    """Reduce quaternion products in ``expr`` until nothing changes.

    Each round moves numeric factors to the front and I/J/K to the back of
    every product (keeping their relative order), then applies every rule in
    ``qreps`` in turn. Returns the first expression that a round leaves
    unchanged (per ``eq_expr``).
    """
    current = expr
    while True:
        candidate = backward(forward(current), *quaternions)
        for pattern, replacement in qreps:
            candidate = replace(candidate, pattern, replacement)
        if eq_expr(candidate, current):
            return current
        current = candidate
def conj_q(expr):
    """Replace each quaternion unit I, J, K in ``expr`` by its negation ``(-1) * unit``."""
    result = expr
    for unit in quaternions:
        negated = N(-1) * unit
        result = replace(result, unit, negated)
    return result
# Test function
def test_replace_q():
    """Check quaternion reduction on products of I, J and K."""
    print("Testing quaternion replacements:")
    cases = (
        # (label, expression, expected text, failure message)
        ("IJ", Product(I, J), "K", "IJ should be K"),
        ("JI", Product(J, I), "-K", "JI should be -K"),
        ("(I + J)K", expand(Product(Sum(I, J), K)), "-J + I", "(I + J)K should be -J + I"),
        ("I(J + K)", expand(Product(I, Sum(J, K))), "K - J", "I(J + K) should be K - J"),
        ("IJK", Product(I, J, K), "-1", "IJK should be -1"),
        ("IJKJI", Product(I, J, K, J, I), "K", "IJKJI should be K"),
        ("IIJJKK", Product(I, I, J, J, K, K), "-1", "IIJJKK should be -1"),
    )
    for label, source, expected, message in cases:
        rendered = to_string(replace_q(source))
        print(f"{label} = {rendered}")
        assert rendered == expected, message
    print("All quaternion tests passed!")
# Define the octonion basis units.
e1, e2, e3, e4, e5, e6, e7 = syms("e1,e2,e3,e4,e5,e6,e7")
octonions = [e1, e2, e3, e4, e5, e6, e7]
# Multiplication table as index triples. For a triple [a, b, c] read
# cyclically, mulo() gives e_a * e_b = +e_c and e_b * e_a = -e_c. Each
# unordered pair of distinct indices 1..7 occurs in exactly one triple.
triads = [
    [1, 2, 3],
    [1, 4, 5],
    [1, 7, 6],
    [2, 4, 6],
    [2, 5, 7],
    [3, 4, 7],
    [3, 6, 5]
]
def mulo(a, b):
    """Multiply octonion basis units given by index (0 = real unit, 1..7 = e1..e7).

    Returns ``[sign, index]`` such that e_a * e_b = sign * e_index, with
    index 0 meaning the real unit. Squares of imaginary units give ``[-1, 0]``.

    Raises:
        ValueError: if the pair is not covered by ``triads`` (e.g. an index
            outside 0..7). Previously this silently returned ``[0, 0]``.
    """
    if a == 0:
        return [1, b]
    if b == 0:
        return [1, a]
    if a == b:
        return [-1, 0]
    for triad in triads:
        if a in triad and b in triad:
            c = next(x for x in triad if x not in (a, b))
            # +1 when b follows a in the triad's cyclic order, -1 otherwise.
            s = 1 if (triad.index(b) - triad.index(a)) % 3 == 1 else -1
            return [s, c]
    raise ValueError(f"octonion unit indices must be in 0..7, got {a} and {b}")
# (pattern, replacement) rewrite rules for octonions; init_o() appends the
# unit and product rules after the shared numeric rules.
oreps = [*commons]
def init_o():
    """Append the octonion rewrite rules to ``oreps``.

    For each unit e_i: the rule 1*e_i -> e_i, followed by e_i*e_j -> mulo(i, j)
    for every j in 1..7. Calling this again appends duplicate rules.
    """
    for row, left in enumerate(octonions, start=1):
        oreps.append((N(1) * left, left))
        for col, right in enumerate(octonions, start=1):
            sign, index = mulo(row, col)
            if index == 0:
                rhs = N(sign)
            elif sign == 1:
                rhs = octonions[index - 1]
            else:
                rhs = N(sign) * octonions[index - 1]
            oreps.append((left * right, rhs))
# Populate oreps at import time.
init_o()
def replace_o(expr):
    """Reduce octonion products in ``expr`` until nothing changes.

    Each round moves numeric factors to the front and e1..e7 to the back of
    every product (keeping their relative order), then applies every rule in
    ``oreps`` in turn. Returns the first expression that a round leaves
    unchanged (per ``eq_expr``).
    """
    current = expr
    while True:
        candidate = backward(forward(current), *octonions)
        for pattern, replacement in oreps:
            candidate = replace(candidate, pattern, replacement)
        if eq_expr(candidate, current):
            return current
        current = candidate
def conj_o(expr):
    """Negate every octonion unit e1..e7 occurring in ``expr``.

    - Inside a Product, a unit gets an N(-1) factor spliced in just before it,
      e.g. ``S * e1`` -> ``S * (-1) * e1``.
    - A unit standing alone (e.g. a Sum term, or the whole expression) becomes
      ``Product(N(-1), unit)``.
    - Single-child Sum/Product nodes collapse to their child.

    This matches the octonion conjugate for expressions in which each term
    contains at most one unit.
    """

    def is_unit(node):
        return isinstance(node, N) and any(eq_expr(node, o) for o in octonions)

    def collapse(kind, items):
        return items[0] if len(items) == 1 else kind(*items)

    def rec(node):
        if is_unit(node):
            return Product(N(-1), node)
        if isinstance(node, Product):
            factors = []
            for factor in node.values:
                if is_unit(factor):
                    factors += [N(-1), factor]
                else:
                    factors.append(rec(factor))
            return collapse(Product, factors)
        if isinstance(node, Sum):
            return collapse(Sum, [rec(term) for term in node.values])
        return node

    return rec(expr)
def arrange_o(expr):
    """Reduce octonion products and group the terms by unit.

    Non-unit terms go first. Within each product, factors are then reordered
    so that numeric constants come first, followed by x1..x7.
    """
    grouped = collect(replace_o(expr), *octonions, forward_other=True)
    return forward(forward(grouped, *xs))
# Test function
def test_replace_o():
    """Check octonion reduction on a few products of e1..e7."""
    print("Testing octonion replacements:")
    cases = (
        # (label, expression, expected text, failure message)
        ("e1 * e2", Product(e1, e2), "e3", "e1 * e2 should be e3"),
        ("e2 * e1", Product(e2, e1), "-e3", "e2 * e1 should be -e3"),
        ("(e1 + e2) * e4", expand(Product(Sum(e1, e2), e4)), "e5 + e6",
         "(e1 + e2) * e4 should be e5 + e6"),
        ("e1 * (e2 + e3)", expand(Product(e1, Sum(e2, e3))), "e3 - e2",
         "e1 * (e2 + e3) should be e3 - e2"),
        ("e1 * e2 * e4", Product(e1, e2, e4), "e7", "e1 * e2 * e4 should be e7"),
    )
    for label, source, expected, message in cases:
        rendered = to_string(replace_o(source))
        print(f"{label} = {rendered}")
        assert rendered == expected, message
    print("All octonion tests passed!")
# Trigonometric quantities used by replace_sc() and set_ov(). They are opaque
# symbols (N with string names), not numeric values.
S=N("sinθ")
C=N("cosθ")
S2=N("sin2θ")
C2=N("cos2θ")
def replace_sc(expr):
    """Apply trigonometric rewrites to ``expr``, in this order.

    1. 1*sinθ -> sinθ and 1*cosθ -> cosθ.
    2. sinθ*cosθ -> cosθ*sinθ, the canonical order.
    3. cos²θ + sin²θ -> 1 and -cosθ sinθ + cosθ sinθ -> 0.
    4. 2 cosθ sinθ -> sin2θ and cos²θ - sin²θ -> cos2θ.

    Matching follows replace(), so the terms must appear contiguously and in
    this order.
    """
    rules = (
        (1*S, S),
        (1*C, C),
        (S*C, C*S),
        (C*C+S*S, 1),
        (-1*C*S+C*S, 0),
        (C*S+C*S, S2),
        (-1*S*S+C*C, C2),
    )
    for pattern, substitute in rules:
        expr = replace(expr, pattern, substitute)
    return expr
# Coefficient symbols x1..x7 of the general imaginary octonion V built in
# set_ov(); arrange_o() moves them toward the front of products.
x1, x2, x3, x4, x5, x6, x7 = syms("x1,x2,x3,x4,x5,x6,x7")
xs = [x1, x2, x3, x4, x5, x6, x7]
def set_ov():
    """Compute (O*V)*O_ and O*(V*O_) symbolically and print the results.

    O = cosθ + sinθ*e1, O_ is its conjugate via conj_o(), and V is the general
    imaginary octonion x1*e1 + ... + x7*e7. The results are stored in the
    module globals O, O_, V, ov, ov_o, vo and o_vo.
    """
    def f(expr):
        # Normalise a triple product: expand, reduce and group by octonion unit,
        # group the coefficients by x1..x7 and then by -1, and apply the
        # trigonometric identities of replace_sc().
        expr = arrange_o(expand(expr))
        expr = flatten(collect(expr, *xs))
        expr = flatten(collect(expr, -1))
        expr = forward(expr)
        expr = replace_sc(expr)
        # Presumably removes a 0*x1 term left after replace_sc() turns
        # coefficient sums into 0 and 1 -- TODO confirm.
        expr = replace(expr, 0*x1 + 1*x1*e1, x1*e1)
        expr = forward(forward(expr, *xs))
        return expr
    global O, O_, V, ov, ov_o, vo, o_vo
    O = C + S*e1
    O_ = forward(conj_o(O))
    V = x1*e1 + x2*e2 + x3*e3 + x4*e4 + x5*e5 + x6*e6 + x7*e7
    # (O V) O_: the raw O*V product is used for ov_o before ov is re-arranged
    # for display.
    ov = replace_o(expand(O * V))
    ov_o = f(ov * O_)
    ov = arrange_o(ov)
    # O (V O_): same pattern with the other bracketing.
    vo = replace_o(expand(V * O_))
    o_vo = f(O * vo)
    vo = arrange_o(vo)
    print("O =", repr(O))
    print("O_ =", repr(O_))
    print("V =", repr(V))
    print("ov =", repr(ov))
    print("ov_o =", repr(ov_o))
    print("vo =", repr(vo))
    print("o_vo =", repr(o_vo))
def run_all_tests():
    """Run the file's self-test functions in a fixed order."""
    suites = (
        test_to_string,
        test_replace,
        run_tests1,
        run_tests2,
        test_collect,
        test_replace_q,
        test_replace_o,
    )
    for suite in suites:
        suite()
def main():
    """Print a short introduction, run the self-tests, then start the REPL."""
    intro = (
        "シンボル計算ライブラリデモ",
        "N(...)で単一の要素を表現し、+ と * で加算と乗算を表現します。",
        "例: N(2) * (N(3) + N(4)) または (a + b) * (c + d)",
        "\nテストケース結果:",
    )
    for text in intro:
        print(text)
    run_all_tests()
    repl()
def repl():
    """Read expressions from stdin and print each one rendered and expanded.

    Each non-blank line is evaluated with ``eval`` in this module's namespace,
    so symbols such as a..z, I, J, K and N are available. Blank lines are
    skipped instead of being reported as syntax errors. Any ``Exception`` is
    printed, and the loop continues until end of input.
    """
    prompt = "\n数式を入力してください(終了するには [Ctrl]+[D]):"
    print(prompt)
    for line in sys.stdin:
        source = line.strip()
        if source:
            try:
                # SECURITY: eval executes arbitrary Python code. This is only
                # acceptable for trusted, interactive input -- never feed it
                # external data.
                expr = eval(source)
                result = to_string(expr)
                print(f"結果: {result}")
                expanded = expand(expr)
                expanded_result = to_string(expanded)
                print(f"展開結果: {expanded_result}")
            except Exception as e:
                print(f"エラー: {str(e)}")
        print(prompt)
# When executed as a script: run the self-tests, then the interactive REPL.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment