Last active
March 1, 2017 21:21
-
-
Save leerobert/e8a65492b6e70e92997510dc0f1debf3 to your computer and use it in GitHub Desktop.
Checking structural equality of SymPy expressions built with evaluate=False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sympy import sympify, default_sort_key
from sympy.core.operations import LatticeOp
def struct_eq(expr1, expr2):
    '''
    Test whether two SymPy expressions are structurally equal, ignoring
    argument order only where that order carries no meaning.

    Expressions built with ``evaluate=False`` keep their arguments in the
    order they were written, so ``==`` can reject expressions that differ
    only by a reordering of terms. This function recurses through ``args``
    down to atoms and compares the arguments of each node:

    * positionally, for operations whose argument order matters
      (e.g. ``Pow``, relationals, non-commutative ``Mul``);
    * as multisets, for ``Add``, commutative ``Mul`` and ``LatticeOp``
      subclasses (such as ``And`` / ``Or``).

    Parameters
    ----------
    expr1, expr2 : sympy.Basic
        The expressions to compare.

    Returns
    -------
    bool
        True if the two expressions have the same structure.
    '''
    if expr1 == expr2:  # passes the usual test: cheap fast path
        return True
    if expr1.func != expr2.func:  # different head operation
        return False
    if expr1.is_Atom or expr2.is_Atom:  # base case: atoms compare directly
        return expr1 == expr2
    args1, args2 = expr1.args, expr2.args
    if len(args1) != len(args2):
        return False
    if _has_unordered_args(expr1) and _has_unordered_args(expr2):
        return _match_unordered(args1, args2)
    # Argument order is significant here (x**2 is not 2**x), so compare
    # the arguments pairwise in their stored order.
    return all(struct_eq(arg1, arg2) for arg1, arg2 in zip(args1, args2))


def _has_unordered_args(expr):
    '''Return True if the order of ``expr.args`` has no semantic meaning.'''
    if expr.is_Add or isinstance(expr, LatticeOp):
        return True
    # Factors may only be reordered when every one of them commutes;
    # for non-commutative symbols A*B and B*A are different products.
    return bool(expr.is_Mul and expr.is_commutative)


def _match_unordered(args1, args2):
    '''Return True if args1 and args2 pair off one-to-one under struct_eq.

    Greedy matching suffices because struct_eq behaves as an equivalence
    relation. Unlike sorting both sides by a sort key, this does not depend
    on structurally equal subexpressions receiving identical sort keys.
    '''
    remaining = list(args2)
    for arg1 in args1:
        for index, arg2 in enumerate(remaining):
            if struct_eq(arg1, arg2):
                del remaining[index]
                break
        else:  # no partner left for arg1
            return False
    return not remaining
if __name__ == '__main__':
    # The same terms written in a different order. Built with
    # evaluate=False, plain == may reject them; struct_eq should not.
    e1 = sympify('2*x + 3*x - 4', evaluate=False)
    e2 = sympify('3*x + 2*x - 4', evaluate=False)
    # Previously the result was computed and discarded; show it.
    print('==        :', e1 == e2)
    print('struct_eq :', struct_eq(e1, e2))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment