Skip to content

Instantly share code, notes, and snippets.

@worldbeater
Last active April 29, 2024 00:10
Show Gist options
  • Save worldbeater/da12e268babb0f1e088a7df1a98307a9 to your computer and use it in GitHub Desktop.
Save worldbeater/da12e268babb0f1e088a7df1a98307a9 to your computer and use it in GitHub Desktop.
A simple term rewriting system (TRS) built on Python's structural pattern matching (`match`/`case`); see https://peps.python.org/pep-0636 and https://inst.eecs.berkeley.edu/~cs294-260/sp24/2024-01-22-term-rewriting
def topdown(rule, expr):
match rule(expr):
case (spec, *args):
return (spec, *(topdown(rule, arg) for arg in args))
case expr:
return expr
def rewrite(rule, expr, max_steps=None):
    """Apply ``rule`` with repeated ``topdown`` passes until nothing changes.

    Each iteration is one full ``topdown(rule, ...)`` pass; iteration stops
    at the first pass whose result compares equal (``==``) to its input,
    and that fixed point is returned.

    Args:
        rule: callable mapping a term to a (possibly) rewritten term.
        expr: the term to rewrite.
        max_steps: optional limit on the number of passes that are allowed
            to change the term. ``None`` (the default) means unlimited, so a
            rule set that never converges loops forever.

    Returns:
        The first term ``t`` reached for which ``topdown(rule, t) == t``.

    Raises:
        RuntimeError: if ``max_steps`` is set and the term is still
            changing after that many passes.
    """
    steps = 0
    while (new_expr := topdown(rule, expr)) != expr:
        steps += 1
        # Guard against rule sets that oscillate or grow without bound.
        if max_steps is not None and steps > max_steps:
            raise RuntimeError(
                f"no fixed point reached after {max_steps} rewrite steps"
            )
        expr = new_expr
    return new_expr
def derive(expr):
match expr:
case ('∂', ('add', u, v), var):
return ('add', ('∂', u, var), ('∂', v, var))
case ('∂', ('pow', var0, int(n)), var) if var == var0:
return ('mul', n, ('pow', var, n - 1))
case ('∂', ('div', 1, var0), var) if var == var0:
return ('∂', ('pow', var, -1), var)
case expr:
return expr
def simplify(expr):
match expr:
case ('add', a, ('unm', b)) | ('add', ('unm', b), a):
return ('sub', a, b)
case ('mul', -1, var) | ('mul', var, -1):
return ('unm', var)
case ('pow', var, 1):
return var
case expr:
return expr
# LaTeX templates per operator. ``tex`` fills each template with the
# rendered operands via ``%`` formatting, one operand per ``%s``, in order.
# NOTE: LaTeX braces only group; they do not draw parentheses, so a compound
# operand (e.g. a sum inside 'mul', or the right-hand side of 'sub') renders
# without visible brackets — TODO consider precedence-aware parenthesisation.
OPS = {
    'add': '{%s}+{%s}',
    'sub': '{%s}-{%s}',
    'div': '\\frac{%s}{%s}',
    'mul': '{%s}{%s}',
    'pow': '{%s}^{%s}',
    '∂': '\\frac{\\partial{\\left(%s\\right)}}{\\partial{%s}}',
    # Unary minus, emitted by ``simplify`` for ('mul', -1, x). Without this
    # entry ``tex`` raised KeyError on any surviving 'unm' node. The operand
    # is parenthesised because it may itself be compound.
    'unm': '-\\left(%s\\right)',
}
def tex(tree):
match tree:
case (spec, *args):
return OPS[spec] % tuple(map(tex, args))
case atom:
return atom
# Demo: differentiate x^2 + 1/x with respect to x, then simplify the result.
expr = ('∂', ('add', ('pow', 'x', 2), ('div', 1, 'x')), 'x')
derv = rewrite(derive, expr)    # differentiation rules, run to a fixed point
simp = rewrite(simplify, derv)  # clean-up rules, run to a fixed point

# One line each: the original term, its raw derivative, and the simplified
# derivative.
print(expr, derv, simp, sep='\n')

# The same equation as LaTeX markup.
print(tex(expr), '=', tex(simp))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment