Skip to content

Instantly share code, notes, and snippets.

@bjourne
Last active July 26, 2022 20:03
Show Gist options
  • Save bjourne/e3a5fc633e37114f73b35777db4eeb43 to your computer and use it in GitHub Desktop.
Save bjourne/e3a5fc633e37114f73b35777db4eeb43 to your computer and use it in GitHub Desktop.
Computer algebra in 99 lines of Python
# Usage:
# python cas.py simplify "(x+y)**30*(2*x)+123*z**8"
# python cas.py solve "3*x**2 == -x**2 + 5*x"
import math, sys
from ast import *
from functools import reduce
from collections import Counter, defaultdict
def parse_mv(mv):
    """Convert a polynomial back into an ``ast`` expression tree.

    ``mv`` maps monomial keys -- tuples of ``(name, exponent)`` pairs, ``()``
    for the constant term -- to coefficients. Zero coefficients are dropped,
    terms are emitted in reverse sorted order of their keys, and the zero
    polynomial becomes ``Constant(0)``.
    """
    def power(name, exp):
        # x**1 is rendered as plain x.
        if exp == 1:
            return Name(name)
        return BinOp(Name(name), Pow(), Constant(exp))

    def term(es, c, first=False):
        # Build |c| * v1**e1 * v2**e2 * ...; the sign is carried by the
        # joining +/- operator, except on the first term where it becomes a
        # unary minus on the leading factor.
        factors = [power(n, e) for n, e in es]
        # A coefficient of +-1 is implied by the variables, but a constant
        # term has no variables: without its number the factor list would
        # be empty (e.g. the "+ 1" in "x + 1").
        if abs(c) != 1 or not factors:
            factors = [Constant(abs(c))] + factors
        if first and c < 0:
            factors[0] = UnaryOp(USub(), factors[0])
        return reduce(lambda x, y: BinOp(x, Mult(), y), factors[1:], factors[0])

    terms = sorted(((es, c) for es, c in mv.items() if c), reverse=True)
    if not terms:
        return Constant(0)
    result = term(*terms[0], first=True)
    for es, c in terms[1:]:
        result = BinOp(result, Add() if c > 0 else Sub(), term(es, c))
    return result
def add(mv1, mv2):
    """Return the sum of two polynomials as a fresh ``defaultdict(int)``.

    Polynomials map monomial keys (tuples of ``(name, exponent)`` pairs,
    ``()`` for the constant term) to coefficients; a missing monomial
    counts as 0. Monomials that cancel are kept with coefficient 0.
    The inputs are not modified: ``.get`` is used because indexing a
    ``defaultdict`` would insert zero entries into the caller's dict.
    """
    # dict.fromkeys keeps mv1's key order followed by keys new in mv2.
    monomials = dict.fromkeys([*mv1, *mv2])
    return defaultdict(int, {m: mv1.get(m, 0) + mv2.get(m, 0)
                             for m in monomials})
def sub(mv1, mv2):
    """Return ``mv1 - mv2`` as a fresh ``defaultdict(int)``.

    Polynomials map monomial keys (tuples of ``(name, exponent)`` pairs,
    ``()`` for the constant term) to coefficients; a missing monomial
    counts as 0. Monomials that cancel are kept with coefficient 0.
    The inputs are not modified: ``.get`` is used because indexing a
    ``defaultdict`` would insert zero entries into the caller's dict.
    """
    # dict.fromkeys keeps mv1's key order followed by keys new in mv2.
    monomials = dict.fromkeys([*mv1, *mv2])
    return defaultdict(int, {m: mv1.get(m, 0) - mv2.get(m, 0)
                             for m in monomials})
def pow(mv1, mv2):
    """Raise polynomial ``mv1`` to the power given by polynomial ``mv2``.

    ``mv2`` must reduce to a non-negative integer constant. Zero-coefficient
    entries (left behind when terms cancel) are ignored, the zero polynomial
    counts as exponent 0, and an integral float such as ``2.0`` is accepted.
    ``p**0`` yields the constant polynomial 1.

    Raises:
        AssertionError: if the exponent is not a non-negative integer
            constant (same check style as ``roots``).
    """
    msg = 'Non-negative integer powers only'
    items = [(es, c) for es, c in mv2.items() if c]
    if not items:
        exponent = 0
    else:
        assert len(items) == 1 and items[0][0] == (), msg
        exponent = items[0][1]
    if isinstance(exponent, float) and exponent.is_integer():
        exponent = int(exponent)
    assert isinstance(exponent, int) and exponent >= 0, msg
    result = defaultdict(int, {(): 1})
    for _ in range(exponent):
        result = mul(result, mv1)
    return result
def mul(mv1, mv2):
    """Multiply two polynomials.

    Every term of ``mv1`` is paired with every term of ``mv2``. The
    coefficients are multiplied, and the exponents of each variable are
    added. Products that land on the same monomial are summed. The result
    is a fresh ``defaultdict(int)`` keyed by sorted ``(name, exponent)``
    tuples.
    """
    product = defaultdict(int)
    for left_mono, left_coef in mv1.items():
        left_exps = dict(left_mono)
        for right_mono, right_coef in mv2.items():
            exps = dict(left_exps)
            for var, exp in dict(right_mono).items():
                exps[var] = exp + exps.get(var, 0)
            product[tuple(sorted(exps.items()))] += left_coef * right_coef
    return product
def eval(tree):
    """Evaluate an ``ast`` expression tree into a polynomial.

    Supported nodes: binary ``+``, ``-``, ``*``, ``**``; variable names;
    int/float constants; unary ``+`` and ``-``. The result is a
    ``defaultdict(int)`` mapping monomial keys (tuples of
    ``(name, exponent)`` pairs, ``()`` for the constant term) to
    coefficients.

    Raises:
        ValueError: for any other node type, operator or constant value.
    """
    if isinstance(tree, BinOp):
        binops = {Add: add, Sub: sub, Mult: mul, Pow: pow}
        op = binops.get(type(tree.op))
        if op is None:
            raise ValueError(f'Unsupported operator: {type(tree.op).__name__}')
        return op(eval(tree.left), eval(tree.right))
    if isinstance(tree, Name):
        return defaultdict(int, {((tree.id, 1),): 1})
    if isinstance(tree, Constant):
        if not isinstance(tree.value, (int, float)):
            raise ValueError(f'Unsupported constant: {tree.value!r}')
        return defaultdict(int, {(): tree.value})
    if isinstance(tree, UnaryOp):
        mv = eval(tree.operand)
        if isinstance(tree.op, UAdd):
            # Unary plus is the identity; it must not negate.
            return mv
        if isinstance(tree.op, USub):
            return mul(mv, defaultdict(int, {(): -1}))
        raise ValueError(f'Unsupported operator: {type(tree.op).__name__}')
    raise ValueError(f'Unsupported expression: {type(tree).__name__}')
def roots_1st(cs):
    """Return the root of ``cs[1]*x + cs[0] == 0`` as an expression node.

    For integer coefficients the root is computed exactly. The result is
    ``Constant(n)`` when it is an integer; otherwise it is the reduced
    fraction ``p / q`` with ``q > 0``. For float coefficients, the float
    quotient is used to decide whether the root is integral.
    """
    num, den = -cs[0], cs[1]
    if isinstance(num, int) and isinstance(den, int):
        # Exact arithmetic: float division would lose digits on big ints.
        g = math.gcd(num, den)
        num, den = num // g, den // g
        if den < 0:
            num, den = -num, -den
        if den == 1:
            return Constant(num)
        return BinOp(Constant(num), Div(), Constant(den))
    q = num / den
    if q == int(q):
        return Constant(int(q))
    return BinOp(Constant(num), Div(), Constant(den))
def roots_2nd(cs):
    """Return both roots of ``cs[2]*x**2 + cs[1]*x + cs[0] == 0`` as one node.

    The expression has the shape ``(-b + pm*sqrt(d)) / (2*a)``. Here ``pm``
    is a symbolic plus-or-minus, an extra ``i`` factor appears when the
    discriminant ``d`` is negative, and the ``-b`` term is omitted when
    ``b`` is 0. ``sqrt(|d|)`` is folded into an integer literal when
    ``|d|`` is a perfect square.
    """
    d = cs[1]**2 - 4*cs[2]*cs[0]
    da = abs(d)
    if isinstance(da, int):
        # Exact integer square root: float sqrt misjudges large values
        # (e.g. 10**18 + 1 looks like a perfect square in floating point).
        root = math.isqrt(da)
        exact = root * root == da
    else:
        root = math.sqrt(da)
        exact = root == int(root)
    if exact:
        sqrt_node = Constant(int(root))
    else:
        sqrt_node = Call(Name('sqrt'), [Constant(da)], [])
    tree = BinOp(Name('pm'), Mult(), Name('i')) if d < 0 else Name('pm')
    tree = BinOp(tree, Mult(), sqrt_node)
    if cs[1]:
        tree = BinOp(Constant(-cs[1]), Add(), tree)
    return BinOp(tree, Div(), Constant(2*cs[2]))
def roots(mv):
    """Solve the univariate polynomial equation ``mv == 0``.

    Returns ``Compare(Name(var), [Eq()], [expr])``, where ``expr`` comes
    from ``roots_1st`` (degree 1) or ``roots_2nd`` (degree 2). ``mv`` is
    not modified.

    Raises:
        AssertionError: if the degree is not 1 or 2, or if more than one
            variable has a nonzero coefficient.
    """
    # add/sub keep cancelled monomials with coefficient 0. Ignore them
    # everywhere, otherwise e.g. "x**2 + y - y == 4" looks multivariate.
    terms = {k: c for k, c in mv.items() if c}
    deg = max((e for k in terms for _, e in k), default=0)
    assert 1 <= deg <= 2, 'Degree 1 or 2 only'
    by_var = defaultdict(lambda: defaultdict(int))
    for var_pows, c in terms.items():
        for v, p in var_pows:
            by_var[v][p] = c
    assert len(by_var) <= 1, 'Univariate polynomials only'
    # deg >= 1 guarantees at least one variable is present.
    name, cs = next(iter(by_var.items()))
    # .get avoids inserting () into the caller's defaultdict.
    cs[0] = terms.get((), 0)
    expr = roots_1st(cs) if deg == 1 else roots_2nd(cs)
    return Compare(Name(name), [Eq()], [expr])
def main(argv):
    """Command-line entry point: ``cas.py (simplify|solve) EXPR``.

    An ``a == b`` input is rewritten as ``a - b`` before evaluation. The
    result of ``parse_mv`` (simplify) or ``roots`` (solve) is printed after
    ``==>``. Usage, syntax and validation errors exit with a message
    instead of a traceback.
    """
    if len(argv) != 3 or argv[1] not in ('simplify', 'solve'):
        sys.exit('usage: python cas.py (simplify|solve) "EXPR"')
    try:
        # mode='eval' accepts exactly one expression; statements such as
        # "x = 1" are rejected instead of silently evaluating their value.
        tree = parse(argv[2], mode='eval').body
    except SyntaxError as err:
        sys.exit(f'syntax error: {err.msg}')
    if isinstance(tree, Compare):
        if len(tree.ops) != 1 or not isinstance(tree.ops[0], Eq):
            sys.exit('only a single "==" comparison is supported')
        tree = BinOp(tree.left, Sub(), tree.comparators[0])
    try:
        mv = eval(tree)
        result = roots(mv) if argv[1] == 'solve' else parse_mv(mv)
    except (AssertionError, ValueError) as err:
        # The evaluators report unsupported input through these exceptions.
        sys.exit(f'error: {err}')
    print('==>', unparse(result))


if __name__ == '__main__':
    main(sys.argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment