Skip to content

Instantly share code, notes, and snippets.

@markdewing
Created March 31, 2011 14:14
Show Gist options
  • Save markdewing/896417 to your computer and use it in GitHub Desktop.
Save markdewing/896417 to your computer and use it in GitHub Desktop.
Possible syntax for pattern-matching of sympy syntax trees
#Prototype implementation of pattern-matching
class AutoVarInstance(object):
    """Placeholder for a pattern variable that has not been bound yet.

    Instances are handed out by ``AutoVar.__getattr__``.  When a pattern
    containing one matches, the matched sub-expression is stored on
    ``parent`` under the attribute ``name``.
    """

    def __init__(self, parent, name):
        # parent: the AutoVar that owns this variable (receives the binding)
        # name:   attribute name the bound value will be stored under
        self.parent, self.name = parent, name
class AutoVar(object):
    """Factory for pattern variables.

    Reading any unknown attribute (``v.x``) returns an ``AutoVarInstance``
    placeholder for that name and records the name in ``self.vars``.  Once a
    successful match stores a value under the name in this object's
    ``__dict__``, normal attribute lookup finds it and ``v.x`` returns the
    bound value instead of a new placeholder.
    """

    def __init__(self):
        # Names handed out as placeholders, in access order (a name read
        # twice before being bound is recorded twice).
        self.vars = []

    def __getattr__(self, name):
        # Only reached when normal lookup fails.  Dunder names are refused so
        # protocol probes (copy/pickle hooks, hasattr-based duck typing) get
        # the usual AttributeError instead of a placeholder object.  "vars"
        # is refused so that an instance created without __init__ (e.g. by
        # copy/pickle reconstruction) cannot recurse forever on self.vars.
        if name == 'vars' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        self.vars.append(name)
        return AutoVarInstance(self, name)
class Match(object):
    """Structural pattern matcher for a single (sympy) expression.

    A pattern is passed as positional arguments: ``(ExprClass, sub1, sub2,
    ...)``.  The expression must be an instance of ``ExprClass``, and each
    sub-pattern is matched against the corresponding element of
    ``expr.args``.  A sub-pattern is one of:

    * a tuple -- matched recursively, with the same rules, against the
      corresponding argument;
    * an ``AutoVarInstance`` (``v.name`` on an ``AutoVar``) -- matches any
      argument and binds it to ``name`` on the owning ``AutoVar``;
    * anything else -- compared to the argument with ``==``.

    Example::

        v = AutoVar()
        if Match(e)(Mul, v.e1, (Pow, v.e2, S.NegativeOne)):
            ...  # e is v.e1 / v.e2
    """

    def __init__(self, expr):
        self.expr = expr

    def type(self, value):
        """Return True if ``type(self.expr)`` is an instance of ``value``.

        This tests the expression's metaclass, not the expression itself.
        """
        return isinstance(type(self.expr), value)

    def __call__(self, *args):
        """Match the wrapped expression against the pattern ``args``.

        With no arguments the call always succeeds.  With a class and one or
        more sub-patterns, the number of sub-patterns must equal the number
        of expression arguments (see ``_match_pattern`` for the binary
        special case).

        Returns True on a match and False otherwise.  Variable bindings are
        committed only when the whole pattern matches, so a failed attempt
        leaves every ``AutoVar`` untouched.
        """
        bindings = {}
        if not Match._match_pattern(self.expr, args, bindings):
            return False
        for placeholder, value in bindings.values():
            placeholder.parent.__dict__[placeholder.name] = value
        return True

    @staticmethod
    def _match_pattern(expr, pattern, bindings):
        """Match ``expr`` against ``pattern`` with no side effects on AutoVars.

        New variable bindings are recorded in ``bindings`` as
        ``{(id(owner), name): (placeholder, value)}`` for the caller to
        commit.  A variable that occurs more than once in one pattern must
        match equal values each time.
        """
        if not pattern:
            return True
        if not isinstance(expr, pattern[0]):
            return False
        subpatterns = pattern[1:]
        if not subpatterns:
            return True

        expr_args = expr.args
        # A binary pattern against an n-ary expression: use the expression's
        # as_two_terms() (head, tail) split when it provides one, so that
        # (Add, a, b) can match a sum of three or more terms.
        if (len(subpatterns) == 2 and len(expr_args) > 2
                and hasattr(expr, 'as_two_terms')):
            expr_args = expr.as_two_terms()
        # Reject arity mismatches instead of silently matching a prefix.
        if len(subpatterns) != len(expr_args):
            return False

        for sub, arg in zip(subpatterns, expr_args):
            if isinstance(sub, tuple):
                if not Match._match_pattern(arg, sub, bindings):
                    return False
            elif isinstance(sub, AutoVarInstance):
                key = (id(sub.parent), sub.name)
                if key in bindings:
                    if not (bindings[key][1] == arg):
                        return False
                else:
                    bindings[key] = (sub, arg)
            elif not (sub == arg):
                return False
        return True
from sympy import *
from lang_py import *
def expr_to_py(e):
    """Convert a sympy expression to a python syntax tree.

    The tree nodes (``py_expr``, ``py_num``, ``py_var``,
    ``py_function_call``) are presumably provided by ``lang_py``.  Patterns
    are tried in order, so the specific forms are recognized before the
    general ``+`` / ``*`` cases:

    * ``-e1 + e2``     -> ``e2 - e1``
    * ``e1 + e2``      -> ``e1 + e2``
    * ``e2**-1``       -> ``1.0 / e2``
    * ``e1 * e2**-1``  -> ``e1 / e2``
    * ``e1 * e2``      -> ``e1 * e2``
    * applied function -> call named ``str(type(e))`` with converted args
    * ``Symbol``       -> variable reference
    * ``Integer``      -> numeric literal

    Raises:
        NotImplementedError: if ``e`` matches none of the forms above (for
            example a power whose exponent is not -1).
    """
    m = Match(e)

    # Each pattern gets a fresh AutoVar: a name bound during an earlier
    # attempt would otherwise be read back as a value and compared as a
    # literal instead of acting as a placeholder.

    # subtraction: -e1 + e2  ->  e2 - e1
    v = AutoVar()
    if m(Add, (Mul, S.NegativeOne, v.e1), v.e2):
        return py_expr(py_expr.PY_OP_MINUS, expr_to_py(v.e2), expr_to_py(v.e1))
    v = AutoVar()
    if m(Add, v.e1, v.e2):
        return py_expr(py_expr.PY_OP_PLUS, expr_to_py(v.e1), expr_to_py(v.e2))

    # reciprocal; 1.0 rather than 1, presumably to avoid integer division
    # in the generated code
    v = AutoVar()
    if m(Pow, v.e2, S.NegativeOne):
        return py_expr(py_expr.PY_OP_DIVIDE, py_num(1.0), expr_to_py(v.e2))

    # division
    v = AutoVar()
    if m(Mul, v.e1, (Pow, v.e2, S.NegativeOne)):
        return py_expr(py_expr.PY_OP_DIVIDE, expr_to_py(v.e1), expr_to_py(v.e2))
    v = AutoVar()
    if m(Mul, v.e1, v.e2):
        return py_expr(py_expr.PY_OP_TIMES, expr_to_py(v.e1), expr_to_py(v.e2))

    # function call
    if m.type(FunctionClass):
        return py_function_call(str(type(e)), *[expr_to_py(a) for a in e.args])

    # leaves
    if m(Symbol):
        return py_var(str(e))
    if m(Integer):
        return py_num(e.p)

    # Fail loudly rather than returning None, which would otherwise be
    # embedded silently as an operand of the parent node.
    raise NotImplementedError(
        "expr_to_py: unsupported expression %r of type %s"
        % (e, type(e).__name__))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment