Created
June 26, 2014 05:43
-
-
Save kahrl/fff951e37db3f3723ecd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import readline | |
class SimplifierError(Exception):
    """Error raised for lexing, parsing or simplification failures.

    The payload passed to the constructor is kept on ``value``; converting
    the exception to a string yields the repr of that payload.
    """

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return "%r" % (self.value,)
TOL = 1e-10  # magnitudes below this are treated as zero throughout the file


def complex_to_str(z):
    """Render the complex number *z* compactly for display.

    Components whose magnitude is below ``TOL`` are treated as zero.  A unit
    imaginary part is written as a bare ``i``.  A number with both parts
    non-zero is parenthesised so it can prefix an operator word safely.

    >>> complex_to_str(complex(0, -1))
    '-i'
    >>> complex_to_str(complex(1, 2))
    '(1+2i)'
    """
    # Snap near-zero components to exactly 0.0 so that values such as 1e-12
    # or -0.0 print as "0" rather than "1e-12" / "-0".  This applies TOL to
    # the real part the same way it is applied to the imaginary part.
    real = 0.0 if abs(z.real) < TOL else z.real
    imag = 0.0 if abs(z.imag) < TOL else z.imag
    if imag == 0.0:
        return "%g" % real
    if abs(imag - 1.0) < TOL:
        imagpart = "+i"
    elif abs(imag + 1.0) < TOL:
        imagpart = "-i"
    else:
        imagpart = "%+gi" % imag
    if real == 0.0:
        # Pure imaginary: drop the redundant leading "+".
        return imagpart[1:] if imagpart[0] == "+" else imagpart
    return "(%g%s)" % (real, imagpart)
class Product(object):
    """A complex scalar times a word of operators, e.g. ``2iT'S``.

    ``scalar`` is always stored as a Python ``complex``.  ``product`` is a
    string of single-letter operator names in which a trailing ``'`` marks
    the adjoint of the preceding operator.  The empty string is the empty
    word, so a Product whose word is ``""`` is a plain scalar.
    """

    def __init__(self, scalar, product):
        self.scalar = complex(scalar)
        self.product = product

    def __str__(self):
        word = self.product
        coeff = self.scalar
        if word == "":
            return complex_to_str(coeff)
        if abs(coeff) < TOL:
            return "0"
        if abs(coeff - 1.0) < TOL:
            # Unit coefficient is implicit.
            return word
        if abs(coeff + 1.0) < TOL:
            return "-" + word
        return complex_to_str(coeff) + word

    def plus(self, other):
        """Add *other*; only products with the identical word can be added."""
        if self.product != other.product:
            raise SimplifierError(
                "Cannot add different products: %s, %s"
                % (self.product, other.product))
        return Product(self.scalar + other.scalar, self.product)

    def minus(self, other):
        """Subtract *other* (same word required, as for ``plus``)."""
        negated = other.neg()
        return self.plus(negated)

    def neg(self):
        """Return this product with its scalar negated."""
        return Product(-self.scalar, self.product)

    def times(self, other):
        """Multiply scalars and concatenate words (non-commutative)."""
        coeff = self.scalar * other.scalar
        return Product(coeff, self.product + other.product)

    def adjoint(self):
        """Return the adjoint: reversed word, each letter's dagger toggled.

        The word is scanned right to left.  A run of ``'`` marks toggles
        whether the next letter seen is emitted with a dagger; the scalar is
        conjugated.
        """
        pieces = []
        daggered = True
        for ch in reversed(self.product):
            if ch == "'":
                daggered = not daggered
                continue
            pieces.append(ch + "'" if daggered else ch)
            daggered = True
        return Product(self.scalar.conjugate(), "".join(pieces))
class SumOfProducts(object):
    """A linear combination of ``Product`` terms, such as ``S + iT``.

    Terms whose scalar magnitude is not greater than ``TOL`` are dropped on
    construction.  A sum left with no terms holds the single term
    ``Product(0, "")`` and therefore prints as ``0``.  The arithmetic
    methods return a new ``SumOfProducts`` and do not modify their operands.
    """

    def __init__(self, addends):
        self.addends = [a for a in addends if abs(a.scalar) > TOL]
        if len(self.addends) == 0:
            self.addends.append(Product(0, ""))

    def __str__(self):
        # Join the terms, turning the leading sign of each later term into a
        # spaced binary operator ("S - iT" rather than "S + -iT").
        pieces = []
        for product in self.addends:
            productstr = str(product)
            if not pieces:
                pieces.append(productstr)
            elif productstr[0] == "+":
                pieces.append(" + " + productstr[1:])
            elif productstr[0] == "-":
                pieces.append(" - " + productstr[1:])
            else:
                pieces.append(" + " + productstr)
        return "".join(pieces)

    @staticmethod
    def _combine(products):
        """Return the sum of *products*, merging terms with the same word."""
        result = dict()
        for prod in products:
            if prod.product in result:
                result[prod.product] = result[prod.product].plus(prod)
            else:
                result[prod.product] = prod
        return SumOfProducts(result.values())

    def plus(self, other):
        """Return self + other, collecting like terms."""
        return self._combine(self.addends + other.addends)

    def minus(self, other):
        """Return self - other."""
        return self.plus(other.neg())

    def neg(self):
        """Return -self."""
        return SumOfProducts([addend.neg() for addend in self.addends])

    def times(self, other):
        """Return self * other by distributing over both sums (left to right)."""
        return self._combine(a1.times(a2)
                             for a1 in self.addends
                             for a2 in other.addends)

    def pow(self, other):
        """Raise this sum to a non-negative integer power.

        *other* must be a pure scalar (a single term with the empty word)
        whose value lies within ``TOL`` of a non-negative integer.  Any power
        to the 0th is ``1``.

        Raises:
            SimplifierError: if the exponent is not such a scalar.
        """
        if len(other.addends) == 1 and other.addends[0].product == "":
            exponent_complex = other.addends[0].scalar
            # Round to the nearest integer.  Truncation (the previous int())
            # turned e.g. 2.9999999999999996 into 2 and then rejected it,
            # even though it is within TOL of 3.
            exponent = int(round(exponent_complex.real))
            if exponent >= 0 and abs(exponent_complex - exponent) < TOL:
                result = SumOfProducts([Product(1, "")])
                for _ in range(exponent):
                    result = result.times(self)
                return result
        raise SimplifierError("Exponent is not a nonnegative integer: " + str(other))

    def adjoint(self):
        """Return the adjoint of the sum (termwise adjoint)."""
        return SumOfProducts([addend.adjoint() for addend in self.addends])
if True:
    # Import-time self-check of the algebra classes above.

    # The four powers of i must print in their simplest scalar form.
    i0 = SumOfProducts([Product(1, "")])
    assert str(i0) == "1"
    i1 = SumOfProducts([Product(complex(0, 1), "")])
    assert str(i1) == "i"
    i2 = SumOfProducts([Product(complex(-1, 0), "")])
    assert str(i2) == "-1"
    i3 = SumOfProducts([Product(complex(0, -1), "")])
    assert str(i3) == "-i"

    # S + i^k T for k = 0..3: exercises sign and unit handling in sums.
    s0 = SumOfProducts([Product(1, "S"), Product(1, "T")])
    assert str(s0) == "S + T"
    s1 = SumOfProducts([Product(1, "S"), Product(complex(0, 1), "T")])
    assert str(s1) == "S + iT"
    s2 = SumOfProducts([Product(1, "S"), Product(complex(-1, 0), "T")])
    assert str(s2) == "S - T"
    s3 = SumOfProducts([Product(1, "S"), Product(complex(0, -1), "T")])
    assert str(s3) == "S - iT"

    # sum over k of i^k (S + i^k T)'(S + i^k T) must collapse to 4T'S.
    s00 = i0.times(s0.adjoint().times(s0))
    s11 = i1.times(s1.adjoint().times(s1))
    s22 = i2.times(s2.adjoint().times(s2))
    s33 = i3.times(s3.adjoint().times(s3))
    s = s00.plus(s11).plus(s22).plus(s33)
    assert str(s) == "4T'S"

    # Non-negative integer powers, including the zeroth.
    assert str(s.pow(SumOfProducts([Product(0, "")]))) == "1"
    assert str(s.pow(SumOfProducts([Product(1, "")]))) == "4T'S"
    assert str(s.pow(SumOfProducts([Product(2, "")]))) == "16T'ST'S"
# Token names for the PLY lexer built further down; the grammar rules refer
# to tokens by these names.
tokens = (
    'NAME','NUMBER',
    'PLUS','MINUS','TIMES','POW','ADJOINT',
    'LPAREN','RPAREN'
)
# Regular expressions for the fixed-text tokens, one t_<TOKEN> per name.
t_NAME = r'[a-zA-Z]'  # exactly one letter, so "ST" lexes as NAME NAME
t_PLUS = r'\+'
t_MINUS = r'-'
t_TIMES = r'\*'
t_POW = r'\^'
t_ADJOINT = r"'"  # postfix adjoint marker, as in T' (see the term ADJOINT rule)
t_LPAREN = r'\('
t_RPAREN = r'\)'
def t_NUMBER(t):
    r'\d+(\.\d+)?'
    # The docstring above is this token's regex for PLY, so it must remain
    # the first statement.  It matches unsigned integers and decimals
    # (e.g. 2, 3.5); the matched text is replaced by a complex scalar.
    text = t.value
    try:
        t.value = complex(text)
    except ValueError:
        raise SimplifierError("Invalid complex number " + str(text))
    return t
# Characters skipped between tokens: spaces and tabs.
t_ignore = " \t"


def t_newline(t):
    r'\n+'
    # Docstring is this rule's regex.  No token is returned, so newlines are
    # discarded; only the lexer's line counter moves forward.
    newline_count = t.value.count("\n")
    t.lexer.lineno = t.lexer.lineno + newline_count


def t_error(t):
    # Lexer error hook: report the first character of t.value, i.e. the
    # character no token rule could match.
    offending = t.value[0]
    raise SimplifierError("Illegal character '%s'" % offending)
# Build the lexer | |
import ply.lex as lex | |
lexer = lex.lex()

# Sample input: the sum over k = 0..3 of i^k (S + i^k T)'(S + i^k T).
testdata = '''(S+T)'(S+T) + i(S+iT)'(S+iT) + i^2(S+i^2T)'(S+i^2T) + i^3(S+i^3T)'(S+i^3T)'''
# Debugging aid: flip to True to dump the token stream for testdata.
if False:
    lexer.input(testdata)
    for tok in lexer:
        # Parenthesised so the file compiles on Python 3 as well as 2,
        # matching the other print calls in this file.
        print(tok)
# Operator precedence for the PLY parser, from loosest to tightest binding.
# UPLUS, UMINUS and ADJACENT are not real tokens: they are only referenced
# through %prec (unary signs, and juxtaposition such as "2S" or "S'T").
# POW must bind tighter than ADJACENT.  In the "term term . POW" conflict the
# parser should shift, so that "2T^2" means 2(T^2) and not (2T)^2 = 4TT.
# ADJOINT binds tightest, so "ST'" means S(T').
precedence = (
    ('left','PLUS','MINUS'),
    ('left','TIMES'),
    ('right','UPLUS','UMINUS'),
    ('left','ADJACENT'),
    ('left','POW'),
    ('left','ADJOINT'),
)
def p_start_expression(t):
    'start : expression'
    # First grammar rule, so PLY treats 'start' as the start symbol.  The
    # docstring above is the grammar production and must stay unchanged.
    # Prints the fully simplified expression; nothing is stored in t[0].
    print(str(t[1]))
def p_expression_term(t):
    'expression : term'
    # A lone term is already an expression; pass its value through.
    t[0] = t[1]
def p_expression_binop(t):
    '''expression : expression PLUS expression
                  | expression MINUS expression
                  | expression TIMES expression'''
    # Explicit binary operators.  Each operator symbol maps to the
    # SumOfProducts method of the same meaning; precedence and
    # associativity come from the `precedence` table.
    method_name = {'+': 'plus', '-': 'minus', '*': 'times'}.get(t[2])
    if method_name is not None:
        t[0] = getattr(t[1], method_name)(t[3])
def p_expression_uplus(t):
    'expression : PLUS expression %prec UPLUS'
    # Unary plus leaves the value unchanged.  %prec UPLUS gives this rule the
    # unary precedence level instead of the binary PLUS level.
    t[0] = t[2]
def p_expression_uminus(t):
    'expression : MINUS expression %prec UMINUS'
    # Unary minus negates the operand.  %prec UMINUS gives this rule the
    # unary precedence level instead of the binary MINUS level.
    t[0] = t[2].neg()
def p_term_atom(t):
    'term : atom'
    # An atom (number, name or parenthesised expression) is a term.
    t[0] = t[1]
def p_term_adjacent(t):
    'term : term term %prec ADJACENT'
    # Juxtaposition ("2S", "S'T") is multiplication.  The rule contains no
    # terminal of its own, so %prec ADJACENT supplies its precedence.
    t[0] = t[1].times(t[2])
def p_term_adjoint(t):
    'term : term ADJOINT'
    # Postfix ': replace the term with its adjoint.  SumOfProducts.adjoint
    # conjugates scalars and reverses each operator word.
    t[0] = t[1].adjoint()
def p_term_pow(t):
    'term : term POW atom'
    # term ^ atom.  The exponent is a single atom, so "T^2S" is (T^2)S.
    # SumOfProducts.pow raises SimplifierError unless the exponent is a
    # non-negative integer scalar.
    t[0] = t[1].pow(t[3])
def p_atom_group(t):
    'atom : LPAREN expression RPAREN'
    # Parentheses only group; the value is the inner expression's.
    t[0] = t[2]
def p_atom_number(t):
    'atom : NUMBER'
    # t_NUMBER has already converted the token text to complex; wrap it as a
    # pure scalar (empty operator word).
    t[0] = SumOfProducts([Product(t[1], "")])
def p_atom_name(t):
    'atom : NAME'
    # A NAME is one letter.  The letter "i" is the imaginary unit (a scalar);
    # every other letter is the operator of that name with coefficient 1.
    letter = t[1]
    single = Product(complex(0,1), "") if letter == 'i' else Product(1.0, letter)
    t[0] = SumOfProducts([single])
def p_error(t):
    # Parser error hook.  t is the offending token, or falsy when there is
    # none (PLY passes None for errors at end of input).
    if not t:
        raise SimplifierError("Syntax error")
    raise SimplifierError("Syntax error at '%s'" % t.value)
import ply.yacc as yacc | |
yacc.yacc()

if __name__ == "__main__":
    # raw_input exists only on Python 2; fall back to input on Python 3 so
    # the REPL works under either interpreter.
    try:
        read_line = raw_input
    except NameError:
        read_line = input
    while True:
        try:
            s = read_line('calc > ')
        except EOFError:
            break
        if not s.strip():
            # Nothing to simplify.  Skip the line rather than handing empty
            # input to the parser, which would report a syntax error.
            continue
        try:
            yacc.parse(s)
        except SimplifierError as e:
            print(e.value)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment