-
-
Save eric-wieser/823a115df88b6e96e21fdedebbdfa9c0 to your computer and use it in GitHub Desktop.
Geometric algebra operators within sympy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Derived from sympy's matexpr.py. Very incomplete.
from sympy.core import S
from sympy.core import Basic, Expr, Add, Mul, Symbol
from sympy.core.operations import AssocOp
from sympy.core.decorators import _sympifyit, call_highest_priority
from sympy.core.sympify import _sympify
from sympy.core.compatibility import string_types
class GaExpr(Expr):
    """
    Base class for all multivector expressions, which overloads operators.

    Note that even the `+` operator must be overloaded, so that the result of
    ``(a + b)`` itself is a :class:`GaExpr`.

    Operator mapping (both operand orders are handled):

    * ``a + b`` and ``a - b`` build a :class:`GaAdd`
    * ``a * b`` builds a :class:`GaMul`
    * ``a ^ b`` builds a :class:`GaWedge` (wedge product)
    * ``a | b`` builds a :class:`GaDot` (dot product)
    * ``-a`` builds ``GaMul(-1, a)`` and evaluates it with ``doit()``

    Every binary operator first sympifies the other operand (returning
    ``NotImplemented`` if that fails), then, via ``call_highest_priority``,
    defers to the *reflected* method of the other operand when that operand
    has a higher ``_op_priority``.
    """

    _iterable = False
    is_commutative = False
    is_number = False
    is_symbol = False
    is_scalar = False

    def __neg__(self):
        return GaMul(S.NegativeOne, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__radd__')
    def __add__(self, other):
        return GaAdd(self, other)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__add__')
    def __radd__(self, other):
        return GaAdd(other, self)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        return GaAdd(self, -other)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        return GaAdd(other, -self)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        return GaMul(self, other)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        return GaMul(other, self)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rxor__')
    def __xor__(self, other):
        return GaWedge(self, other)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__xor__')
    def __rxor__(self, other):
        return GaWedge(other, self)

    # The priority hooks for ``|`` must name the dot-product methods.  They
    # previously named ``__rxor__`` / ``__xor__``, so a higher-priority
    # operand would have been asked for a *wedge* instead of a dot.
    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__ror__')
    def __or__(self, other):
        return GaDot(self, other)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__or__')
    def __ror__(self, other):
        return GaDot(other, self)
class GaAdd(GaExpr, Add):
    """Sum of multivector expressions, as produced by ``a + b`` / ``a - b``.

    Construction and flattening are inherited from sympy's ``Add``.  Listing
    ``GaExpr`` first among the bases makes its operator overloads take
    precedence, so further arithmetic on the sum stays multivector-aware.
    """
class GaMul(GaExpr, Mul):
    """Product of multivector expressions, as produced by ``a * b``.

    Construction and flattening are inherited from sympy's ``Mul``.  Listing
    ``GaExpr`` first among the bases makes its operator overloads take
    precedence, so further arithmetic on the product stays multivector-aware.
    """
class GaDot(GaExpr):
    """Dot product ``a | b`` of two multivector expressions.

    The dot product is not associative, so exactly two operands are stored
    and nested dot products are never flattened.
    """

    def __new__(cls, a, b, **kwargs):
        # sympy expression trees require Basic arguments.  The ``|`` operator
        # already sympifies via ``_sympifyit``; do the same here so that
        # direct construction such as ``GaDot(2, x)`` is equally well-formed.
        a = _sympify(a)
        b = _sympify(b)
        return Basic.__new__(cls, a, b, **kwargs)
class GaWedge(GaExpr, AssocOp):
    """Wedge product ``a ^ b ^ ...`` of multivector expressions.

    The product is associative, so ``AssocOp.flatten`` splices nested wedges
    into one flat argument list.  The result collapses to zero when any
    factor is zero, or when the same factor occurs twice.

    NOTE(review): the repeated-factor rule (``a ^ a == 0``) only holds when
    every factor is a vector; e.g. a bivector wedged with itself need not
    vanish.  Confirm callers only wedge vectors.
    """

    is_commutative = False
    # The empty wedge product is the scalar 1.  ``AssocOp.__new__`` drops any
    # argument that *is* the identity, so ``1 ^ a`` simplifies to ``a``.
    # This used to be ``GaAdd()``, which evaluates to zero, so ``a ^ 0``
    # silently became ``a`` instead of ``0``.
    identity = S.One

    @classmethod
    def flatten(cls, args):
        """Flatten nested wedges and apply the zero rules described above.

        Returns the ``(c_part, nc_part, order_symbols)`` triple expected by
        ``AssocOp``.  A zero result is returned as a single ``S.Zero`` factor,
        which ``AssocOp`` unwraps into plain zero.
        """
        c_part, nc_part, order_symbols = super(GaWedge, cls).flatten(args)
        seen = set()
        for factor in nc_part:
            # a ^ 0 == 0, and a ^ ... ^ a == 0 by antisymmetry (vectors only)
            if factor is S.Zero or factor in seen:
                return [], [S.Zero], None
            seen.add(factor)
        return c_part, nc_part, order_symbols
class GaSymbol(GaExpr):
    """A named multivector symbol, e.g. ``GaSymbol('a')``.

    A string name is wrapped in a sympy ``Symbol``.  The (possibly wrapped)
    name becomes the expression's single argument.
    """

    is_commutative = False
    is_symbol = True

    def __new__(cls, name):
        wrapped = Symbol(name) if isinstance(name, string_types) else name
        return Basic.__new__(cls, wrapped)

    @property
    def name(self):
        """The ``name`` of the wrapped symbol stored in ``args[0]``."""
        head = self.args[0]
        return head.name
# Now monkey-patch sympy's LaTeX printer so that it knows how to print the new types.
from sympy.printing.latex import LatexPrinter as _LatexPrinter | |
def _ga_product_latex(printer, expr, separator):
    r"""Render ``expr.args`` with ``printer``, joined by LaTeX ``separator``.

    An operand is wrapped in ``\left( ... \right)`` in two cases:

    * the printer says it needs brackets inside a sum (e.g. it is an ``Add``);
    * it is itself a wedge or dot product.

    The second case is needed because ``(a | b) | c`` and ``a | (b | c)``
    would otherwise both print as ``a \cdot b \cdot c``.
    """
    pieces = []
    for term in expr.args:
        term_tex = printer._print(term)
        # NOTE(review): geometric products (GaMul) are not bracketed here;
        # confirm this matches the intended operator precedence.
        if printer._needs_add_brackets(term) or isinstance(term, (GaWedge, GaDot)):
            term_tex = r"\left(%s\right)" % term_tex
        pieces.append(term_tex)
    return separator.join(pieces)


def _print_GaWedge(self, expr):
    r"""LaTeX for a wedge product: operands joined by ``\wedge``."""
    return _ga_product_latex(self, expr, r" \wedge ")


def _print_GaDot(self, expr):
    r"""LaTeX for a dot product: operands joined by ``\cdot``."""
    return _ga_product_latex(self, expr, r" \cdot ")
def _print_GaSymbol(self, expr):
    """Print a GaSymbol via ``_print_Symbol``, styled by ``ga_symbol_style``."""
    chosen_style = self._settings['ga_symbol_style']
    return self._print_Symbol(expr, style=chosen_style)
# Give ``ga_symbol_style`` a default so that ``_print_GaSymbol`` can always
# read it from ``self._settings``.  Then attach one printing method per new
# expression type; sums and products reuse sympy's own Add/Mul printers.
_LatexPrinter._default_settings['ga_symbol_style'] = 'plain'
for _attr, _impl in [
    ('_print_GaSymbol', _print_GaSymbol),
    ('_print_GaWedge', _print_GaWedge),
    ('_print_GaDot', _print_GaDot),
    ('_print_GaAdd', _LatexPrinter._print_Add),
    ('_print_GaMul', _LatexPrinter._print_Mul),
]:
    setattr(_LatexPrinter, _attr, _impl)
# Do not leave the loop variables behind as module-level names.
del _attr, _impl
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Using a very basic example: