Last active
June 25, 2024 14:33
-
-
Save 7shi/912ad67f3edd07278b61ecb2fb0acf26 to your computer and use it in GitHub Desktop.
[py] calculate quaternions and octonions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
def is_number(n):
    """Return True if *n* is a plain Python int or float (bool counts, being an int)."""
    numeric_types = (int, float)
    return isinstance(n, numeric_types)
def conv_n(n):
    """Wrap a plain number in N; return any other value unchanged."""
    if not is_number(n):
        return n
    return N(n)
def conv_ns(ns):
    """Apply conv_n to every element of *ns* and return the results as a list."""
    return list(map(conv_n, ns))
class N:
    """Leaf node of the expression tree: a number or a symbol name.

    Arithmetic operators do not evaluate anything; they build Sum and
    Product nodes. Plain Python numbers on either side are wrapped with
    conv_n first.
    """

    def __init__(self, value):
        # The wrapped payload, e.g. 2, 1.5 or "x".
        self.value = value

    def __str__(self):
        # Debug form showing the node type, e.g. N(2) or N('x').
        return "N(" + repr(self.value) + ")"

    def __repr__(self):
        # Human-readable infix form.
        return to_string(self)

    def __add__(self, other):
        return Sum(self, conv_n(other))

    def __radd__(self, other):
        return Sum(conv_n(other), self)

    def __neg__(self):
        # Negation is represented as multiplication by -1.
        return -1 * self

    def __sub__(self, other):
        negated = -conv_n(other)
        return self + negated

    def __rsub__(self, other):
        left = conv_n(other)
        return left + (-self)

    def __mul__(self, other):
        return Product(self, conv_n(other))

    def __rmul__(self, other):
        return Product(conv_n(other), self)
class Sum:
    """N-ary sum node; ``values`` holds the terms in order.

    Both ``Sum(a, b, c)`` and ``Sum(iterable)`` are accepted: a single
    argument that has ``__iter__`` is expanded into the term list.
    """

    def __init__(self, *values):
        unpack = len(values) == 1 and hasattr(values[0], "__iter__")
        self.values = list(values[0] if unpack else values)

    def __str__(self):
        inner = ", ".join(str(term) for term in self.values)
        return "Sum(" + inner + ")"

    def __repr__(self):
        return to_string(self)

    def __add__(self, other):
        # Adding another Sum concatenates the term lists instead of nesting.
        tail = other.values if isinstance(other, Sum) else [conv_n(other)]
        return Sum(*self.values, *tail)

    def __radd__(self, other):
        head = other.values if isinstance(other, Sum) else [conv_n(other)]
        return Sum(*head, *self.values)

    def __neg__(self):
        # Negation is represented as multiplication by -1.
        return -1 * self

    def __sub__(self, other):
        negated = -conv_n(other)
        return self + negated

    def __rsub__(self, other):
        left = conv_n(other)
        return left + (-self)

    def __mul__(self, other):
        return Product(self, conv_n(other))

    def __rmul__(self, other):
        return Product(conv_n(other), self)
class Product:
    """N-ary product node; ``values`` holds the factors in order.

    Factor order is preserved because the units (quaternions, octonions)
    multiplied here do not commute. Both ``Product(a, b)`` and
    ``Product(iterable)`` are accepted: a single argument that has
    ``__iter__`` is expanded into the factor list.
    """

    def __init__(self, *values):
        unpack = len(values) == 1 and hasattr(values[0], "__iter__")
        self.values = list(values[0] if unpack else values)

    def __str__(self):
        inner = ", ".join(str(factor) for factor in self.values)
        return "Product(" + inner + ")"

    def __repr__(self):
        return to_string(self)

    def __add__(self, other):
        return Sum(self, conv_n(other))

    def __radd__(self, other):
        return Sum(conv_n(other), self)

    def __neg__(self):
        # Drop a leading N(-1) factor instead of stacking another -1 on it.
        leads_with_minus_one = bool(self.values) and eq_expr(self.values[0], N(-1))
        if leads_with_minus_one:
            return Product(*self.values[1:])
        return -1 * self

    def __sub__(self, other):
        negated = -conv_n(other)
        return self + negated

    def __rsub__(self, other):
        left = conv_n(other)
        return left + (-self)

    def __mul__(self, other):
        # Multiplying by another Product concatenates the factor lists.
        tail = other.values if isinstance(other, Product) else [conv_n(other)]
        return Product(*self.values, *tail)

    def __rmul__(self, other):
        head = other.values if isinstance(other, Product) else [conv_n(other)]
        return Product(*head, *self.values)
def syms(symbols_str):
    """Create one symbol N(name) per comma-separated name in *symbols_str*.

    Whitespace around each name is stripped, so ``"a, b"`` yields the same
    symbols as ``"a,b"``.

    Returns:
        A tuple of N instances, in the order the names appear.
    """
    return tuple(N(name.strip()) for name in symbols_str.split(","))
# Define the module-level single-letter symbols a..z as N("a"), ..., N("z").
(a, b, c, d, e, f, g, h, i, j, k, l, m,
 n, o, p, q, r, s, t, u, v, w, x, y, z) = syms(
    ",".join(chr(code) for code in range(ord("a"), ord("z") + 1)))
def to_string(expr):
    """Render an expression tree as an infix string.

    - N renders its value with str().
    - Sum joins its terms with " + ". A term whose text starts with "-"
      is joined with " - " instead, and nested Sums are parenthesized.
    - Product joins its factors with " * ". A leading N(-1) factor (when
      other factors follow) becomes a unary "-". Nested Sum/Product
      factors are parenthesized, as is a factor whose text starts with
      "-" when something precedes it.
    - An empty Sum renders as "0" and an empty Product as "1" (the
      additive and multiplicative identities) rather than "".
    - Any other value falls back to str().
    """
    if isinstance(expr, N):
        return str(expr.value)
    elif isinstance(expr, Sum):
        if not expr.values:
            # An empty sum is the additive identity.
            return "0"
        ret = ""
        for term in expr.values:
            s = to_string(term)
            if isinstance(term, Sum):
                # Keep nested sums visibly grouped.
                if ret:
                    ret += " + "
                ret += f"({s})"
            elif ret and s.startswith("-"):
                # Write "a - b" rather than "a + -b".
                ret += f" - {s[1:]}"
            else:
                if ret:
                    ret += " + "
                ret += s
        return ret
    elif isinstance(expr, Product):
        factors = expr.values
        if not factors:
            # An empty product is the multiplicative identity.
            return "1"
        ret = ""
        if len(factors) > 1 and eq_expr(factors[0], N(-1)):
            # Show a leading -1 factor as a unary minus.
            ret += "-"
            factors = factors[1:]
        for i, factor in enumerate(factors):
            if i:
                ret += " * "
            f = to_string(factor)
            # Parenthesize nested nodes, and negative factors that would
            # otherwise directly follow "-" or " * ".
            if isinstance(factor, (Sum, Product)) or (ret and f.startswith("-")):
                ret += f"({f})"
            else:
                ret += f
        return ret
    return str(expr)
# Test function
def test_to_string():
    """Check to_string against a table of small hand-built expressions."""
    print("Testing improved to_string function:")
    cases = [
        # (expression, expected text, failure message)
        (N(5), "5", "Basic number test failed"),
        (Product(N(2), N('x')), "2 * x", "Simple product test failed"),
        (Product(N(2), Product(N('x'), N('y'))),
         "2 * (x * y)", "Nested product test failed"),
        (Product(N(2), Product(N('x'), N('y')), Product(N('a'), N('b'))),
         "2 * (x * y) * (a * b)", "Multiple nested products test failed"),
        (Sum(N(1), N(2)), "1 + 2", "Simple sum test failed"),
        (Sum(N(1), Product(N(2), N('x'))), "1 + 2 * x", "Sum with product test failed"),
        (Product(N(2), Sum(N('x'), N('y'))), "2 * (x + y)", "Product with sum test failed"),
    ]
    for expr, expected, message in cases:
        assert to_string(expr) == expected, message
    print("All to_string tests passed!")
def product_to_sum(xs, ys):
    """Distribute two term lists: return the list of every x * y as a flat Product.

    Pairs are produced with each x taken in turn against all ys (row-major
    order). When an operand is already a Product, its factors are spliced
    in rather than nested.
    """
    def factors_of(term):
        return list(term.values) if isinstance(term, Product) else [term]

    ys = list(ys)  # may be iterated once per x
    return [Product(*factors_of(x), *factors_of(y)) for x in xs for y in ys]
def expand(expr):
    """Distribute products over sums, recursively.

    A Sum expands each of its terms. A Product expands each factor and
    multiplies the resulting term lists out with product_to_sum. The
    result is a Sum of products, or the single term when only one
    remains. Any other value is returned unchanged.

    Degenerate cases follow the algebra: an empty Product (the value 1)
    is returned as is, and a Product with an empty-Sum factor (the value
    0) expands to an empty Sum.
    """
    if isinstance(expr, Sum):
        return Sum(expand(term) for term in expr.values)
    elif isinstance(expr, Product):
        if not expr.values:
            # Nothing to distribute; keep the multiplicative identity.
            return expr
        sums = None  # expanded terms so far; None until the first factor
        for factor in expr.values:
            factor = expand(factor)
            terms = factor.values if isinstance(factor, Sum) else [factor]
            if sums is None:
                # Copy so later steps never alias the factor's own list.
                sums = list(terms)
            else:
                # An empty list (a zero factor) stays empty from here on.
                sums = product_to_sum(sums, terms)
        return Sum(*sums) if len(sums) != 1 else sums[0]
    return expr
def run_tests1():
    """Expand and flatten sample expressions, printing a pass/fail mark per case.

    Each input is a Python source string passed to eval(), so names such as
    a, b, x and N resolve to this module's globals. Results are only
    printed; a mismatch does not raise.
    """
    cases = (
        ('N(2) * (N(3) + N(4))', '2 * 3 + 2 * 4'),
        ('(N(1) + N(2)) * (N(3) + N(4))', '1 * 3 + 1 * 4 + 2 * 3 + 2 * 4'),
        ('N(2) * N(3) * (N(4) + N(5))', '2 * 3 * 4 + 2 * 3 * 5'),
        ('N(1) + N(2) * (N(3) + N(4))', '1 + 2 * 3 + 2 * 4'),
        ('(a + b) * (c + d)', 'a * c + a * d + b * c + b * d'),
        ('x * (N(1) + N(2))', 'x * 1 + x * 2'),
        ('a + b * (c + d)', 'a + b * c + b * d'),
        ('f * (g + h)', 'f * g + f * h'),
        ('N(1) * 2 + 3', '1 * 2 + 3'),
        ('2 * N(3) + 4', '2 * 3 + 4'),
        ('(N(1) + 2) * (3 + N(4))', '1 * 3 + 1 * 4 + 2 * 3 + 2 * 4'),
    )
    for source, want in cases:
        rendered = to_string(flatten(expand(eval(source))))
        mark = '✅' if rendered == want else '❌'
        print(f"{mark} {source} => {rendered}")
def flatten(expr):
    """Merge nested nodes of the same kind, e.g. Sum(a, Sum(b, c)) -> Sum(a, b, c).

    Applies recursively to Sum and Product. A node left with exactly one
    child is replaced by that child. Other values are returned unchanged.
    """
    if not isinstance(expr, (Sum, Product)):
        return expr
    kind = type(expr)
    merged = []
    for child in map(flatten, expr.values):
        if isinstance(child, kind):
            merged.extend(child.values)
        else:
            merged.append(child)
    if len(merged) == 1:
        return merged[0]
    return kind(*merged)
def eq_expr(expr1, expr2):
    """Structural equality of two expression trees.

    Both nodes must have exactly the same type. N compares values. Sum
    compares terms order-insensitively (both term lists are sorted by
    str() and then compared pairwise). Product compares factors in order.
    Any other value falls back to ``==``.
    """
    if type(expr1) is not type(expr2):
        return False
    if isinstance(expr1, N):
        return expr1.value == expr2.value
    if isinstance(expr1, (Sum, Product)):
        left, right = expr1.values, expr2.values
        if len(left) != len(right):
            return False
        if isinstance(expr1, Sum):
            # Addition is commutative, so term order is ignored.
            left, right = sorted(left, key=str), sorted(right, key=str)
        return all(eq_expr(p, q) for p, q in zip(left, right))
    return expr1 == expr2
def forward(expr, *values):
    """Move selected factors to the front of every Product in *expr*.

    With *values* given (plain numbers are wrapped via conv_n), factors
    structurally equal to one of them are moved. With no *values*,
    numeric N factors are moved. Both groups keep their relative order.
    Factors that stay behind are processed recursively; moved factors are
    not. Sums are traversed term by term; other values are unchanged.
    """
    values = conv_ns(values)

    if values:
        def goes_first(factor):
            return any(eq_expr(factor, value) for value in values)
    else:
        def goes_first(factor):
            return isinstance(factor, N) and is_number(factor.value)

    def walk(node):
        if isinstance(node, Sum):
            return Sum(walk(term) for term in node.values)
        if not isinstance(node, Product):
            return node
        front, rest = [], []
        for factor in node.values:
            if goes_first(factor):
                front.append(factor)
            else:
                rest.append(walk(factor))
        return Product(*front, *rest)

    return walk(expr)
def backward(expr, *values):
    """Move factors equal to any of *values* to the end of every Product.

    Plain numbers in *values* are wrapped via conv_n. Both the moved
    factors and the remaining ones keep their relative order. Remaining
    factors are processed recursively. Sums are traversed term by term;
    other values are unchanged.
    """
    targets = conv_ns(values)

    def is_target(factor):
        return any(eq_expr(factor, target) for target in targets)

    def walk(node):
        if isinstance(node, Sum):
            return Sum(walk(term) for term in node.values)
        if isinstance(node, Product):
            kept, moved = [], []
            for factor in node.values:
                if is_target(factor):
                    moved.append(factor)
                else:
                    kept.append(walk(factor))
            return Product(*kept, *moved)
        return node

    return walk(expr)
def find(nodes1, nodes2):
    """Return the first index where *nodes2* occurs as a contiguous run in *nodes1*.

    Elements are compared with eq_expr. Returns 0 for an empty *nodes2*
    and -1 when there is no match.
    """
    width = len(nodes2)
    for start in range(len(nodes1) - width + 1):
        window = nodes1[start:start + width]
        if all(eq_expr(got, want) for got, want in zip(window, nodes2)):
            return start
    return -1
def replace(expr, pattern, replacement):
    """Return *expr* with occurrences of *pattern* replaced by *replacement*.

    A subtree that is structurally equal to *pattern* (eq_expr) is
    replaced as a whole. When *pattern* is a non-empty Sum or Product,
    contiguous runs of its children inside a node of the same type are
    also replaced; for example, pattern I * J matches inside I * J * K.
    If *pattern* is an instance of the replacement's type, the
    replacement's children are spliced in. Otherwise the replacement is
    inserted as a single child. Inserted replacements are not searched
    again. Nodes left with one child are collapsed to that child.

    Plain numbers for *pattern* and *replacement* are wrapped with conv_n.
    """
    pattern = conv_n(pattern)
    replacement = conv_n(replacement)
    # Children to match as a run. It stays empty for non-container or
    # empty patterns: an empty run would "match" at every position
    # without advancing, which used to loop forever.
    sub = pattern.values if isinstance(pattern, (Sum, Product)) else []
    size = len(sub)
    splice = isinstance(pattern, type(replacement))

    def rec(node):
        if eq_expr(node, pattern):
            return replacement
        if not isinstance(node, (Sum, Product)):
            return node
        children = node.values
        if size and isinstance(node, type(pattern)):
            result = []
            pos = 0
            last = len(children) - size
            while pos < len(children):
                if pos <= last and all(eq_expr(children[pos + k], p) for k, p in enumerate(sub)):
                    if splice:
                        result += replacement.values
                    else:
                        result.append(replacement)
                    pos += size
                else:
                    result.append(rec(children[pos]))
                    pos += 1
        else:
            result = [rec(child) for child in children]
        return type(node)(*result) if len(result) != 1 else result[0]

    return rec(expr)
# Test function
def test_replace():
    """Exercise replace() on products of the symbols I, J, K."""
    I, J, K = syms("I,J,K")
    checks = [
        # (label, expression, pattern, replacement, expected, message)
        ("IJ replaced", Product(I, J), Product(I, J), K,
         "K", "Simple replacement failed"),
        # Nested Product flattening 1
        ("I(JK)I replaced", Product(I, Product(J, K), I), Product(J, K), I,
         "I * I * I", "Nested Product flattening failed"),
        # Nested Product flattening 2
        ("KJI double replaced", Product(K, J, I), Product(J, I), Product(N(-1), K),
         "K * (-1) * K", "Nested Product flattening failed"),
    ]
    for label, expr, pattern, replacement, expected, message in checks:
        outcome = replace(expr, pattern, replacement)
        print(f"{label}: {to_string(outcome)}")
        assert to_string(outcome) == expected, message
    print("All replace tests passed!")
def run_tests2():
    """Check flatten, backward and replace on small symbolic examples."""
    print("Running tests...")

    def show(label, before, after):
        print(f"Original: {to_string(before)}")
        print(f"{label}: {to_string(after)}")

    # flatten
    print("\nTesting flatten:")
    a, b, c = syms("a,b,c")
    nested = Product(a, Product(b, c))
    flat = flatten(nested)
    show("Flattened", nested, flat)
    assert to_string(flat) == "a * b * c", "flatten test 1 failed"
    mixed = Sum(Product(a, Product(b, c)), Product(Product(a, b), c))
    flat_mixed = flatten(mixed)
    show("Flattened", mixed, flat_mixed)
    assert to_string(flat_mixed) == "a * b * c + a * b * c", "flatten test 2 failed"

    # backward
    print("\nTesting back:")
    x, y, z = syms("x,y,z")
    terms = Sum(Product(x, y, z), Product(y, x, z))
    moved = backward(terms, x)
    show("Backed", terms, moved)
    assert to_string(moved) == "y * z * x + y * z * x", "back test failed"

    # replace: (x1 + y1*i)(x2 + y2*i) with i*i -> -1
    print("\nTesting replace:")
    i, x1, x2, y1, y2 = syms("i,x1,x2,y1,y2")
    product = backward(expand((x1 + y1 * i) * (x2 + y2 * i)), i)
    reduced = replace(product, Product(i, i), N(-1))
    show("Replaced", product, reduced)
    assert to_string(reduced) == "x1 * x2 + x1 * y2 * i + y1 * x2 * i + y1 * y2 * (-1)", "replace test failed"
    print("\nAll tests passed!")
def collect(expr, *symbols, forward=False, forward_other=False):
    """Group the terms of every Sum in *expr* by their factors from *symbols*.

    Within each Sum, the factors of a term that equal one of *symbols*
    (plain numbers are wrapped with conv_n) form its key. Terms with the
    same key are combined into ``(sum of coefficients) * key``, or into
    ``key * (sum of coefficients)`` when *forward* is true.

    Special cases:
    - A term made only of matching factors, including a bare symbol term,
      gets the coefficient N(1).
    - Terms with no matching factor are kept unchanged. They are placed
      after the collected groups, or before them when *forward_other* is
      true.

    Groups are emitted sorted by the to_string() text of their key.
    Products are traversed, so nested Sums are collected too.

    Note: the keyword *forward* shadows the module-level forward()
    function inside this body; that function is not called here.
    """
    symbols = conv_ns(symbols)

    def is_symbol(factor):
        return any(eq_expr(factor, sym) for sym in symbols)

    def rec(node):
        if isinstance(node, Product):
            return Product(rec(factor) for factor in node.values)
        if not isinstance(node, Sum):
            return node
        collected = {}       # key text -> list of coefficients
        collected_syms = {}  # key text -> matching factors of the first such term
        other_terms = []
        for term in node.values:
            term = rec(term)
            factors = term.values if isinstance(term, Product) else [term]
            matched, others = [], []
            for factor in factors:
                (matched if is_symbol(factor) else others).append(factor)
            if not matched:
                other_terms.append(term)
                continue
            key = to_string(Product(*matched))
            if key not in collected:
                collected[key] = []
                collected_syms[key] = matched
            if not others:
                coeff = N(1)  # previously Product(), which renders as ""
            elif len(others) == 1:
                coeff = others[0]
            else:
                coeff = Product(*others)
            collected[key].append(coeff)

        result = list(other_terms) if forward_other else []
        for key in sorted(collected):
            coeffs = collected[key]
            ks = Sum(*coeffs) if len(coeffs) != 1 else coeffs[0]
            keyed = collected_syms[key]
            result.append(Product(*keyed, ks) if forward else Product(ks, *keyed))
        if not forward_other:
            result += other_terms
        return Sum(*result) if len(result) != 1 else result[0]

    return rec(expr)
def test_collect():
    """Check collect() on two sums of products, collecting the symbol i."""
    x1, x2, y1, y2, i = syms("x1,x2,y1,y2,i")

    def check(number, header, expr, expected):
        print(header)
        print("Original expression:")
        print(to_string(expr))
        collected = collect(expr, i)
        print("\nAfter collecting terms with respect to i:")
        print(to_string(collected))
        assert to_string(collected) == to_string(expected), \
            f"Test case {number} failed. Expected {to_string(expected)}, but got {to_string(collected)}"
        print(f"Test case {number} passed.")

    # Test case 1
    check(
        1, "Test case 1:",
        Sum(Product(x1, x2), Product(x1, y2, i), Product(y1, x2, i), Product(y1, y2, N(-1))),
        Sum(Product(Sum(Product(x1, y2), Product(y1, x2)), i),
            Product(x1, x2),
            Product(y1, y2, N(-1))),
    )
    # Test case 2
    check(
        2, "\nTest case 2:",
        Sum(Product(N(2), x1, i), Product(N(3), x2, i), Product(N(4), x1), Product(N(5), x2)),
        Sum(Product(Sum(Product(N(2), x1), Product(N(3), x2)), i),
            Product(N(4), x1),
            Product(N(5), x2)),
    )
    print("\nAll test cases passed successfully!")
def simplify(expr):
    """Expand *expr* and combine like terms of the resulting Sum.

    The expression is expanded and flattened twice, and numeric factors
    are moved to the front of each product. What happens next depends on
    the result:

    - For a Sum, each term is split into a numeric coefficient and the
      remaining monomial. Terms whose monomials are structurally equal
      (eq_expr) have their coefficients added; bare numbers are summed
      into one constant; zero-coefficient terms are dropped.
    - For other nodes with ``values``, each child is simplified.

    An empty result becomes N(0), and a single remaining term is returned
    unwrapped.
    """
    expr = forward(flatten(expand(flatten(expand(expr)))))
    if isinstance(expr, Sum):
        groups = []  # [coefficient, monomial]; monomial None = constant term
        for term in expr.values:
            coeff, mono = _split_coefficient(term)
            for group in groups:
                if eq_expr(group[1], mono):
                    group[0] += coeff
                    break
            else:
                groups.append([coeff, mono])
        terms = [_scale(coeff, mono) for coeff, mono in groups if coeff]
    elif hasattr(expr, "values"):
        terms = [simplify(child) for child in expr.values]
    else:
        return expr
    if not terms:
        return N(0)
    if len(terms) == 1:
        return terms[0]
    return type(expr)(*terms)


def _split_coefficient(term):
    """Split *term* into (numeric coefficient, monomial), with None as a constant's monomial."""
    if isinstance(term, N) and is_number(term.value):
        return term.value, None
    if not isinstance(term, Product):
        return 1, term
    coeff = 1
    rest = []
    for factor in term.values:
        if isinstance(factor, N) and is_number(factor.value):
            coeff *= factor.value
        else:
            rest.append(factor)
    if not rest:
        return coeff, None
    # Unwrap a single leftover factor so 2*x and x compare equal as x.
    return coeff, rest[0] if len(rest) == 1 else Product(*rest)


def _scale(coeff, mono):
    """Rebuild a term from a coefficient and a monomial (None means a constant)."""
    if mono is None:
        return N(coeff)
    return mono if coeff == 1 else coeff * mono
# Numeric (pattern, replacement) rules that fold products of 1 and -1.
# They seed both the quaternion (qreps) and octonion (oreps) rule lists.
# All four sign combinations are listed: forward() keeps numeric factors
# in their original relative order, so both 1 * -1 and -1 * 1 can occur.
commons = [
    (N(1) * 1, N(1)),
    (N(1) * -1, N(-1)),
    (N(-1) * 1, N(-1)),
    (N(-1) * -1, N(1)),
]
# Quaternion basis units I, J, K (the real unit is the number 1).
I, J, K = syms("I,J,K")
quaternions = [I, J, K]
# Quaternion rewrite rules start as a copy of the shared numeric rules.
qreps = list(commons)
def init_q():
    """Append the quaternion multiplication rules to qreps.

    First, for each unit q: 1 * q -> q and q * q -> -1. Then the cyclic
    products I*J = K, J*K = I and K*I = J, together with their swapped
    counterparts J*I = -K, K*J = -I and I*K = -J. Calling this again
    appends duplicate rules.
    """
    for unit in quaternions:
        qreps.extend([(N(1) * unit, unit), (unit * unit, N(-1))])
    for left, right, result in ((I, J, K), (J, K, I), (K, I, J)):
        qreps.append((left * right, result))
        qreps.append((right * left, N(-1) * result))


init_q()
def replace_q(expr):
    """Rewrite *expr* with the quaternion rules until it stops changing.

    Each round does two things:
    - It moves numeric factors to the front and the units I, J, K to the
      back of every product, keeping their relative order.
    - It applies every rule in qreps, in list order.

    The loop stops when a round leaves the expression unchanged under
    eq_expr. There is no iteration limit.
    """
    def one_round(node):
        node = backward(forward(node), *quaternions)
        for pattern, replacement in qreps:
            node = replace(node, pattern, replacement)
        return node

    current, following = expr, one_round(expr)
    while not eq_expr(following, current):
        current, following = following, one_round(following)
    return current
def conj_q(expr):
    """Negate every occurrence of the units I, J, K in *expr*.

    Each unit is replaced by the nested product -1 * unit. This gives the
    quaternion conjugate of a linear combination a + b*I + c*J + d*K. For
    products of units it is not the conjugate, because conjugation also
    reverses the order of the factors.
    """
    result = expr
    for unit in quaternions:
        result = replace(result, unit, N(-1) * unit)
    return result
# Test function
def test_replace_q():
    """Check replace_q on products of quaternion units."""
    print("Testing quaternion replacements:")
    cases = [
        # (label, expression, expand first?, expected, message)
        ("IJ", Product(I, J), False, "K", "IJ should be K"),
        ("JI", Product(J, I), False, "-K", "JI should be -K"),
        ("(I + J)K", Product(Sum(I, J), K), True, "-J + I", "(I + J)K should be -J + I"),
        ("I(J + K)", Product(I, Sum(J, K)), True, "K - J", "I(J + K) should be K - J"),
        ("IJK", Product(I, J, K), False, "-1", "IJK should be -1"),
        ("IJKJI", Product(I, J, K, J, I), False, "K", "IJKJI should be K"),
        ("IIJJKK", Product(I, I, J, J, K, K), False, "-1", "IIJJKK should be -1"),
    ]
    for label, expr, needs_expand, expected, message in cases:
        if needs_expand:
            expr = expand(expr)
        text = to_string(replace_q(expr))
        print(f"{label} = {text}")
        assert text == expected, message
    print("All quaternion tests passed!")
# Octonion imaginary units e1..e7 (the real unit is the number 1).
e1, e2, e3, e4, e5, e6, e7 = syms(",".join(f"e{k}" for k in range(1, 8)))
octonions = [e1, e2, e3, e4, e5, e6, e7]
# Index triples [a, b, c] with e_a * e_b = e_c. mulo() also accepts cyclic
# rotations of a triple and gives a minus sign for the reversed order.
triads = [
    [1, 2, 3], [1, 4, 5], [1, 7, 6], [2, 4, 6],
    [2, 5, 7], [3, 4, 7], [3, 6, 5],
]
def mulo(a, b):
    """Multiply two octonion basis units given by index.

    Index 0 is the real unit 1; indices 1..7 are e1..e7.

    Returns:
        [sign, index] such that e_a * e_b = sign * e_index.

    Raises:
        ValueError: if a or b is outside 0..7, or if no triad in
            ``triads`` contains both indices. This replaces the old
            silent [0, 0] result, which meant "zero".
    """
    for idx in (a, b):
        if idx not in range(8):
            raise ValueError(f"octonion unit index out of range 0..7: {idx!r}")
    if a == 0:
        return [1, b]
    if b == 0:
        return [1, a]
    if a == b:
        return [-1, 0]  # e_a * e_a = -1 for every imaginary unit
    for triad in triads:
        if a in triad and b in triad:
            c = next(x for x in triad if x != a and x != b)
            # (a, b) in the triad's cyclic order gives +1; reversed gives -1.
            s = 1 if (triad.index(b) - triad.index(a)) % 3 == 1 else -1
            return [s, c]
    raise ValueError(f"no triad contains octonion units {a} and {b}")
# Octonion rewrite rules start as a copy of the shared numeric rules.
oreps = list(commons)


def init_o():
    """Append the octonion multiplication rules to oreps.

    For each unit e_a, add 1 * e_a -> e_a. Then, for every e_b, add
    e_a * e_b -> sign * e_c with (sign, c) from mulo(a, b). The right-hand
    side is a bare number when c == 0 and the unit alone when sign == 1.
    """
    def as_term(sign, index):
        if index == 0:
            return N(sign)
        unit = octonions[index - 1]
        return unit if sign == 1 else N(sign) * unit

    for a, left in enumerate(octonions, start=1):
        oreps.append((N(1) * left, left))
        for b, right in enumerate(octonions, start=1):
            sign, index = mulo(a, b)
            oreps.append((left * right, as_term(sign, index)))


init_o()
def replace_o(expr):
    """Rewrite *expr* with the octonion rules until it stops changing.

    Each round does two things:
    - It moves numeric factors to the front and the units e1..e7 to the
      back of every product, keeping their relative order.
    - It applies every rule in oreps, in list order.

    The loop stops when a round leaves the expression unchanged under
    eq_expr. There is no iteration limit.

    Note: octonion multiplication is not associative. In a product of
    three or more units, the order of the rules in oreps decides which
    adjacent pair is reduced first; there is no explicit bracketing.
    """
    def one_round(node):
        node = backward(forward(node), *octonions)
        for pattern, replacement in oreps:
            node = replace(node, pattern, replacement)
        return node

    current, following = expr, one_round(expr)
    while not eq_expr(following, current):
        current, following = following, one_round(following)
    return current
def conj_o(expr):
    """Negate every occurrence of the octonion units e1..e7 in *expr*.

    - A unit that is a factor of a Product gets a -1 factor inserted just
      before it (S * e1 -> S * (-1) * e1).
    - A unit standing on its own, at the top level or as a Sum term,
      becomes Product(N(-1), unit). The old pattern Product(o) never
      matched a unit in these positions.

    Sums and Products left with one child collapse to that child. This is
    the octonion conjugate only for linear combinations of the units; for
    products of units it is not, because conjugation also reverses factor
    order.
    """
    def is_unit(node):
        return any(eq_expr(node, unit) for unit in octonions)

    def collapse(kind, items):
        return kind(*items) if len(items) != 1 else items[0]

    def rec(node):
        if is_unit(node):
            return N(-1) * node
        if isinstance(node, Product):
            factors = []
            for factor in node.values:
                if is_unit(factor):
                    factors += [N(-1), factor]
                else:
                    factors.append(rec(factor))
            return collapse(Product, factors)
        if isinstance(node, Sum):
            return collapse(Sum, [rec(term) for term in node.values])
        return node

    return rec(expr)
def arrange_o(expr):
    """Reduce octonion products in *expr* and tidy the factor order.

    The steps are applied in this order:
    1. replace_o.
    2. Collect terms by octonion unit, keeping unit-free terms first.
    3. Move the x1..x7 symbols (xs) to the front of each product.
    4. Move numeric factors in front of those.
    """
    steps = (
        replace_o,
        lambda node: collect(node, *octonions, forward_other=True),
        lambda node: forward(node, *xs),
        forward,
    )
    for step in steps:
        expr = step(expr)
    return expr
# Test function
def test_replace_o():
    """Check replace_o on products of octonion units."""
    print("Testing octonion replacements:")
    cases = [
        # (label, expression, expand first?, expected, message)
        ("e1 * e2", Product(e1, e2), False, "e3", "e1 * e2 should be e3"),
        ("e2 * e1", Product(e2, e1), False, "-e3", "e2 * e1 should be -e3"),
        ("(e1 + e2) * e4", Product(Sum(e1, e2), e4), True,
         "e5 + e6", "(e1 + e2) * e4 should be e5 + e6"),
        ("e1 * (e2 + e3)", Product(e1, Sum(e2, e3)), True,
         "e3 - e2", "e1 * (e2 + e3) should be e3 - e2"),
        ("e1 * e2 * e4", Product(e1, e2, e4), False, "e7", "e1 * e2 * e4 should be e7"),
    ]
    for label, expr, needs_expand, expected, message in cases:
        if needs_expand:
            expr = expand(expr)
        text = to_string(replace_o(expr))
        print(f"{label} = {text}")
        assert text == expected, message
    print("All octonion tests passed!")
# Symbols for sin and cos of θ and of 2θ.
S, C, S2, C2 = (N(name) for name in ("sinθ", "cosθ", "sin2θ", "cos2θ"))
def replace_sc(expr):
    """Apply a fixed sequence of trigonometric rewrites to *expr*.

    The rewrites, in order:
    - 1*sinθ -> sinθ
    - 1*cosθ -> cosθ
    - sinθ*cosθ -> cosθ*sinθ
    - cos²θ + sin²θ -> 1
    - -cosθ sinθ + cosθ sinθ -> 0
    - cosθ sinθ + cosθ sinθ -> sin2θ
    - -sin²θ + cos²θ -> cos2θ

    Each rule is one replace() pass. A Sum pattern therefore matches
    either a whole Sum with the same terms (in any order) or a contiguous
    run of terms in exactly the order written here.
    """
    rules = [
        (1 * S, S),
        (1 * C, C),
        (S * C, C * S),
        (C * C + S * S, 1),
        (-1 * C * S + C * S, 0),
        (C * S + C * S, S2),
        (-1 * S * S + C * C, C2),
    ]
    for pattern, replacement in rules:
        expr = replace(expr, pattern, replacement)
    return expr
# Coefficients x1..x7, used for the octonion V in set_ov().
x1, x2, x3, x4, x5, x6, x7 = syms(",".join(f"x{k}" for k in range(1, 8)))
xs = [x1, x2, x3, x4, x5, x6, x7]
def set_ov():
    """Compute and print the products O*V, (O*V)*O_, V*O_ and O*(V*O_).

    O = cosθ + sinθ e1, O_ is O with its unit negated (conj_o), and
    V = x1 e1 + ... + x7 e7. The results are stored in the module globals
    O, O_, V, ov, ov_o, vo and o_vo.
    """
    def finish(expr):
        # The pipeline:
        # 1. Expand and reduce.
        # 2. Collect by x_k, then by -1.
        # 3. Apply the trig identities.
        # 4. Move x_k and numeric factors to the front.
        expr = arrange_o(expand(expr))
        expr = flatten(collect(expr, *xs))
        expr = flatten(collect(expr, -1))
        expr = replace_sc(forward(expr))
        expr = replace(expr, 0 * x1 + 1 * x1 * e1, x1 * e1)
        return forward(forward(expr, *xs))

    global O, O_, V, ov, ov_o, vo, o_vo
    O = C + S * e1
    O_ = forward(conj_o(O))
    V = Sum(*(coeff * unit for coeff, unit in zip(xs, octonions)))
    ov = replace_o(expand(O * V))
    ov_o = finish(ov * O_)
    ov = arrange_o(ov)
    vo = replace_o(expand(V * O_))
    o_vo = finish(O * vo)
    vo = arrange_o(vo)
    for label, value in (("O", O), ("O_", O_), ("V", V), ("ov", ov),
                         ("ov_o", ov_o), ("vo", vo), ("o_vo", o_vo)):
        print(f"{label} =", repr(value))
def run_all_tests():
    """Run every self-test in this module, in a fixed order."""
    for test in (test_to_string, test_replace, run_tests1, run_tests2,
                 test_collect, test_replace_q, test_replace_o):
        test()
def main():
    """Print an introduction, run the self-tests, then start the interactive loop."""
    intro = (
        "シンボル計算ライブラリデモ",
        "N(...)で単一の要素を表現し、+ と * で加算と乗算を表現します。",
        "例: N(2) * (N(3) + N(4)) または (a + b) * (c + d)",
        "\nテストケース結果:",
    )
    for line in intro:
        print(line)
    run_all_tests()
    repl()
def repl():
    """Read expressions from stdin and print each one rendered and expanded.

    Each non-blank line is evaluated in this module's global namespace, so
    the predefined symbols (a..z, I/J/K, e1..e7, ...) and N/Sum/Product
    are available. Unlike before, the loop's own locals (line, expr, ...)
    are not exposed. Blank lines just re-prompt. Errors are printed and
    the loop continues until end of input ([Ctrl]+[D]).

    SECURITY: eval() runs arbitrary Python. This is only for an
    interactive, trusted user; never connect it to untrusted input.
    """
    prompt = "\n数式を入力してください(終了するには [Ctrl]+[D]):"
    print(prompt)
    for line in sys.stdin:
        source = line.strip()
        if source:
            try:
                expr = eval(source, globals())
                print(f"結果: {to_string(expr)}")
                print(f"展開結果: {to_string(expand(expr))}")
            except Exception as err:
                # Top-level boundary of the interactive loop: report and continue.
                print(f"エラー: {err}")
        print(prompt)
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment