Computer algebra in 99 lines of Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Usage: | |
# python cas.py simplify "(x+y)**30*(2*x)+123*z**8" | |
# python cas.py solve "3*x**2 == -x**2 + 5*x" | |
import math, sys | |
from ast import * | |
from functools import reduce | |
from collections import Counter, defaultdict | |
def parse_mv(mv):
    """Render a polynomial mapping as a Python AST expression.

    ``mv`` maps monomial keys to coefficients. A key is a tuple of
    ``(variable_name, exponent)`` pairs -- e.g. ``(('x', 2), ('y', 1))`` for
    ``x**2*y`` -- and the empty tuple ``()`` is the constant term.
    Zero coefficients are dropped. The remaining terms are emitted in
    descending key order, joined by ``+``/``-``. An empty or all-zero
    mapping renders as ``Constant(0)``.
    """
    def power(name, exp):
        # ``x**1`` is written as a bare ``x``.
        return Name(name) if exp == 1 else BinOp(Name(name), Pow(), Constant(exp))

    def term(es, c, first=False):
        # Build |c| * v1**e1 * v2**e2 * ...; the sign is carried by a leading
        # unary minus on the first term, or by the Add/Sub that joins it.
        fs = [power(n, e) for n, e in es]
        # A coefficient of 1 is implicit, except for the constant term, which
        # would otherwise have no factors at all (the ``1`` in ``x + 1``).
        if abs(c) != 1 or not fs:
            fs = [Constant(abs(c))] + fs
        if first and c < 0:
            fs[0] = UnaryOp(USub(), fs[0])
        return reduce(lambda x, y: BinOp(x, Mult(), y), fs[1:], fs[0])

    # Keys are unique, so reverse=True orders exactly like reversed(sorted()).
    terms = sorted(((es, c) for es, c in mv.items() if c), reverse=True)
    if not terms:
        return Constant(0)
    ret = term(*terms[0], first=True)
    for es, c in terms[1:]:
        ret = BinOp(ret, Add() if c > 0 else Sub(), term(es, c))
    return ret
def add(mv1, mv2):
    """Return the sum of two polynomial mappings as a new defaultdict(int).

    Coefficients are read with ``.get`` so neither argument is modified.
    Indexing a defaultdict with a missing key would insert a 0 entry into it,
    and indexing a plain dict would raise KeyError. Keys whose coefficients
    cancel remain in the result with a 0 coefficient.
    """
    return defaultdict(int, {k: mv1.get(k, 0) + mv2.get(k, 0)
                             for k in [*mv1, *mv2]})
def sub(mv1, mv2):
    """Return ``mv1 - mv2`` for polynomial mappings as a new defaultdict(int).

    Coefficients are read with ``.get`` so neither argument is modified.
    Indexing a defaultdict with a missing key would insert a 0 entry into it,
    and indexing a plain dict would raise KeyError. Keys whose coefficients
    cancel remain in the result with a 0 coefficient.
    """
    return defaultdict(int, {k: mv1.get(k, 0) - mv2.get(k, 0)
                             for k in [*mv1, *mv2]})
def pow(mv1, mv2):
    """Raise polynomial ``mv1`` to the power described by polynomial ``mv2``.

    ``mv2`` must reduce to a non-negative integer constant. Every
    non-constant monomial in it must have a zero coefficient, so ``x**(y-y)``
    is accepted. ``x**0`` yields the constant polynomial 1. The power is
    computed by repeated multiplication with ``mul``.

    Raises AssertionError for symbolic, negative or non-integer exponents.
    """
    nonzero = {k: c for k, c in mv2.items() if c}
    exponent = nonzero.get((), 0)
    assert set(nonzero) <= {()} and isinstance(exponent, int) and exponent >= 0, \
        'Non-negative integer powers only'
    one = defaultdict(int, {(): 1})
    return reduce(lambda acc, _: mul(acc, mv1), range(exponent), one)
def mul(mv1, mv2):
    """Multiply two polynomial mappings term by term (distributive law).

    Every monomial of ``mv1`` is paired with every monomial of ``mv2``.
    Exponents of shared variables are added, the combined key is re-sorted
    into canonical order, and the product of the two coefficients is
    accumulated under that key. Returns a new defaultdict(int).
    """
    product = defaultdict(int)
    pairs = ((left, right) for left in mv1.items() for right in mv2.items())
    for (mono_a, coef_a), (mono_b, coef_b) in pairs:
        exponents = dict(mono_a)
        for var, power in dict(mono_b).items():
            exponents[var] = exponents.get(var, 0) + power
        product[tuple(sorted(exponents.items()))] += coef_a * coef_b
    return product
def eval(tree):
    """Evaluate an expression AST into a polynomial mapping (defaultdict(int)).

    Supported nodes are the binary ``+ - * **`` operators, unary ``+`` and
    ``-``, variable names, and int/float constants. A name ``v`` becomes
    ``{(('v', 1),): 1}`` and a constant ``k`` becomes ``{(): k}``.

    Raises ValueError for any other node or operator (e.g. ``/``, calls,
    nested comparisons). Previously these either returned None or raised a
    bare KeyError.
    """
    if isinstance(tree, BinOp):
        # Validate the operator before touching the dispatch table.
        if not isinstance(tree.op, (Add, Sub, Mult, Pow)):
            raise ValueError(f'Unsupported operator: {type(tree.op).__name__}')
        binops = {Add: add, Sub: sub, Mult: mul, Pow: pow}
        left, right = eval(tree.left), eval(tree.right)
        return binops[type(tree.op)](left, right)
    if isinstance(tree, Name):
        return defaultdict(int, {((tree.id, 1),): 1})
    if isinstance(tree, Constant):
        if not isinstance(tree.value, (int, float)):
            raise ValueError(f'Unsupported constant: {tree.value!r}')
        return defaultdict(int, {(): tree.value})
    if isinstance(tree, UnaryOp) and isinstance(tree.op, (UAdd, USub)):
        mv = eval(tree.operand)
        if isinstance(tree.op, UAdd):
            return mv
        # Negate every coefficient (equivalent to multiplying by -1).
        return defaultdict(int, {k: -c for k, c in mv.items()})
    raise ValueError(f'Unsupported expression: {unparse(tree)}')
def roots_1st(cs):
    """Return the root of ``cs[1]*x + cs[0] == 0`` as an AST, i.e. ``-cs[0]/cs[1]``.

    ``cs`` maps exponent -> coefficient and must contain keys 0 and 1.
    Integer coefficients use exact integer arithmetic: the result is a
    ``Constant`` when the division is exact, otherwise a reduced fraction
    ``num / den`` with a positive denominator. Float division would lose
    precision for large coefficients (e.g. 10**17 + 1). Non-integer
    coefficients keep the float-based check.
    """
    num, den = -cs[0], cs[1]
    if isinstance(num, int) and isinstance(den, int):
        g = math.gcd(num, den)
        num, den = num // g, den // g
        if den < 0:
            num, den = -num, -den
        if den == 1:
            return Constant(num)
        return BinOp(Constant(num), Div(), Constant(den))
    q = num / den
    if q == int(q):
        return Constant(int(q))
    return BinOp(Constant(num), Div(), Constant(den))
def roots_2nd(cs):
    """Return the quadratic-formula roots of ``cs[2]*x**2 + cs[1]*x + cs[0]`` as an AST.

    ``cs`` maps exponent -> coefficient. The result has the shape
    ``(-cs[1] + pm * sqrt(|d|)) / (2*cs[2])``, where ``d`` is the
    discriminant and ``pm`` is a symbolic plus-or-minus placeholder.
    The ``-cs[1]`` term is omitted when ``cs[1]`` is 0, and ``pm`` becomes
    ``pm * i`` when ``d < 0``. ``sqrt(|d|)`` is written as an integer
    constant when ``|d|`` is a perfect square.
    """
    d = cs[1]**2 - 4*cs[2]*cs[0]
    da = abs(d)
    if isinstance(da, int):
        # Exact test: float sqrt misclassifies large non-squares as squares.
        root = math.isqrt(da)
        is_square = root * root == da
    else:
        root = math.sqrt(da)
        is_square = root == int(root)
    sqrt_node = (Constant(int(root)) if is_square
                 else Call(Name('sqrt'), [Constant(da)], []))
    tree = BinOp(Name('pm'), Mult(), Name('i')) if d < 0 else Name('pm')
    tree = BinOp(tree, Mult(), sqrt_node)
    tree = BinOp(Constant(-cs[1]), Add(), tree) if cs[1] else tree
    return BinOp(tree, Div(), Constant(2*cs[2]))
def roots(mv):
    """Solve ``mv == 0`` for a univariate polynomial of degree 1 or 2.

    ``mv`` maps monomial keys (tuples of ``(variable, exponent)`` pairs,
    ``()`` for the constant) to coefficients. Monomials with a zero
    coefficient are ignored, so a variable that cancelled out
    (``x + y - y``) does not make the polynomial look multivariate.
    Returns the AST ``var == <roots>``.

    Raises AssertionError if the degree is not 1 or 2, or if more than one
    variable remains.
    """
    terms = {k: c for k, c in mv.items() if c}
    deg = max((e for k in terms for _, e in k), default=0)
    assert 1 <= deg <= 2, 'Degree 1 or 2 only'
    # by_var[variable][exponent] -> coefficient, built from non-zero terms only.
    by_var = defaultdict(lambda: defaultdict(int))
    for var_pows, c in terms.items():
        for v, p in var_pows:
            by_var[v][p] = c
    assert len(by_var) <= 1, 'Univariate polynomials only'
    name, cs = next(iter(by_var.items()))
    cs[0] = terms.get((), 0)
    expr = roots_1st(cs) if deg == 1 else roots_2nd(cs)
    return Compare(Name(name), [Eq()], [expr])
def main(argv):
    """Command-line entry point: ``simplify EXPR`` or ``solve EQUATION``.

    ``simplify`` prints the expanded polynomial. ``solve`` rewrites
    ``lhs == rhs`` as ``lhs - rhs`` and prints the roots of that expression
    set to zero; a bare expression is solved as ``expr == 0``. Output is
    prefixed with ``==>``. Exits with a message on bad usage or on a
    comparison other than a single ``==``.
    """
    if len(argv) < 3 or argv[1] not in ('simplify', 'solve'):
        sys.exit('Usage: python cas.py simplify|solve "<expression>"')
    # mode='eval' rejects statements (e.g. assignments) with a SyntaxError.
    tree = parse(argv[2], mode='eval').body
    if isinstance(tree, Compare):
        if len(tree.ops) != 1 or not isinstance(tree.ops[0], Eq):
            sys.exit('Only a single == comparison is supported')
        tree = BinOp(tree.left, Sub(), tree.comparators[0])
    mv = eval(tree)
    print('==>', unparse(roots(mv) if argv[1] == 'solve' else parse_mv(mv)))


if __name__ == '__main__':
    main(sys.argv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment