Skip to content

Instantly share code, notes, and snippets.

@kahrl
Created June 26, 2014 05:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kahrl/fff951e37db3f3723ecd to your computer and use it in GitHub Desktop.
Save kahrl/fff951e37db3f3723ecd to your computer and use it in GitHub Desktop.
import readline
class SimplifierError(Exception):
    """Raised for lexing, parsing or algebra errors in the simplifier.

    The message object passed in is kept unchanged on ``value``; ``str()`` of
    the error is the ``repr`` of that value.
    """

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return "%r" % (self.value,)
# Magnitude below which a scalar (or one of its real/imaginary parts) is treated as zero.
TOL = 1e-10


def complex_to_str(z, digits=12):
    """Format a complex scalar compactly for display.

    Real and imaginary parts smaller than TOL in magnitude are treated as zero,
    and an imaginary part within TOL of +1/-1 is written as a bare ``i``/``-i``.
    The result is ``"a"`` for (numerically) real values, ``"bi"`` for purely
    imaginary ones and ``"(a+bi)"`` otherwise, e.g. ``2``, ``-i``, ``2.5i``,
    ``(1-2i)``.

    :param z: complex (or real) number to format.
    :param digits: significant digits printed per component.  Plain ``%g``
        (6 digits) would show 1048576 as ``1.04858e+06``; 12 digits keep such
        integer coefficients exact while still hiding float noise like
        ``0.1 + 0.2``.
    :returns: the formatted string.
    """
    real_fmt = "%%.%dg" % digits
    imag_fmt = "%%+.%dg" % digits
    # Snap a negligible real part to +0.0 so it never prints as "-0" or "1e-17".
    real = 0.0 if abs(z.real) < TOL else z.real
    if abs(z.imag) < TOL:
        return real_fmt % real
    if abs(z.imag - 1.0) < TOL:
        imagpart = "+i"
    elif abs(z.imag + 1.0) < TOL:
        imagpart = "-i"
    else:
        imagpart = imag_fmt % z.imag + "i"
    if abs(z.real) < TOL:
        # Purely imaginary: an explicit leading "+" is dropped.
        return imagpart[1:] if imagpart[0] == "+" else imagpart
    return "(" + real_fmt % real + imagpart + ")"
class Product(object):
    """One term: a complex scalar times an ordered string of operator symbols.

    ``product`` is a string of single-letter names, each optionally followed by
    ``'`` to mark its adjoint (e.g. ``"T'S"``).  The empty string means the
    term is a pure scalar.  Factors are concatenated in order, never sorted, so
    ``"ST"`` and ``"TS"`` are different terms.
    """

    def __init__(self, scalar, product):
        self.scalar = complex(scalar)
        self.product = product

    def __str__(self):
        # A pure scalar prints as a number; otherwise a unit coefficient is elided.
        if not self.product:
            return complex_to_str(self.scalar)
        if abs(self.scalar) < TOL:
            return "0"
        if abs(self.scalar - 1.0) < TOL:
            return self.product
        if abs(self.scalar + 1.0) < TOL:
            return "-" + self.product
        return complex_to_str(self.scalar) + self.product

    def plus(self, other):
        # Only like terms (identical operator strings) can be combined.
        if self.product != other.product:
            raise SimplifierError("Cannot add different products: %s, %s" % (self.product, other.product))
        return Product(self.scalar + other.scalar, self.product)

    def minus(self, other):
        return self.plus(other.neg())

    def neg(self):
        return Product(-self.scalar, self.product)

    def times(self, other):
        # Scalars multiply; operator strings concatenate in left-to-right order.
        combined = self.product + other.product
        return Product(self.scalar * other.scalar, combined)

    def adjoint(self):
        # (c AB)' = conj(c) B'A': walk the factors right to left.  Each prime
        # seen toggles whether the next letter gets a prime in the result, so
        # T' becomes T and T becomes T'.
        pieces = []
        prime_next = True
        for ch in reversed(self.product):
            if ch == "'":
                prime_next = not prime_next
                continue
            pieces.append(ch + "'" if prime_next else ch)
            prime_next = True
        return Product(self.scalar.conjugate(), "".join(pieces))
class SumOfProducts(object):
    """A sum of Product terms, e.g. ``S + iT`` or ``4T'S``.

    Operator strings are never reordered (``ST`` and ``TS`` are different
    terms), so this is a polynomial in non-commuting symbols with complex
    coefficients.  Terms whose coefficient magnitude is at most TOL are
    dropped; an otherwise empty sum holds the single term ``Product(0, "")``.
    """

    def __init__(self, addends):
        # addends may be any iterable of Product; keep only non-negligible terms.
        self.addends = [a for a in addends if abs(a.scalar) > TOL]
        if not self.addends:
            self.addends.append(Product(0, ""))

    def __str__(self):
        # Join terms with " + " / " - ", folding a term's leading sign into the
        # separator so that S plus -T prints as "S - T".
        pieces = []
        for addend in self.addends:
            text = str(addend)
            if not pieces:
                pieces.append(text)
            elif text[0] in "+-":
                pieces.append(" " + text[0] + " " + text[1:])
            else:
                pieces.append(" + " + text)
        return "".join(pieces)

    @staticmethod
    def _merge(products):
        # Combine like terms (same operator string) by adding their scalars.
        # Each term keeps the position of its first occurrence.  An explicit
        # list is used instead of dict.values() so the term order is also
        # deterministic on Python 2, where dict order is arbitrary.
        position = {}
        merged = []
        for p in products:
            key = p.product
            if key in position:
                merged[position[key]] = merged[position[key]].plus(p)
            else:
                position[key] = len(merged)
                merged.append(p)
        return SumOfProducts(merged)

    def plus(self, other):
        return self._merge(self.addends + other.addends)

    def minus(self, other):
        return self.plus(other.neg())

    def neg(self):
        return SumOfProducts([addend.neg() for addend in self.addends])

    def times(self, other):
        # Distribute: every term of self times every term of other, in that order.
        return self._merge(a1.times(a2)
                           for a1 in self.addends
                           for a2 in other.addends)

    def pow(self, other):
        """Raise this sum to a nonnegative integer power.

        ``other`` must be a pure scalar (a single term with an empty operator
        string).  Its value is rounded to the nearest integer, not truncated,
        before the TOL check, so float noise such as 2.9999999999999996 is
        accepted as 3.  Negative, fractional or complex exponents are rejected.

        :raises SimplifierError: if ``other`` is not a nonnegative integer scalar.
        """
        if len(other.addends) == 1 and other.addends[0].product == "":
            exponent_complex = other.addends[0].scalar
            exponent = int(round(exponent_complex.real))
            if exponent >= 0 and abs(exponent_complex - exponent) < TOL:
                result = SumOfProducts([Product(1, "")])
                for _ in range(exponent):
                    result = result.times(self)
                return result
        raise SimplifierError("Exponent is not a nonnegative integer: " + str(other))

    def adjoint(self):
        # The adjoint of a sum is the sum of the adjoints.
        return SumOfProducts([addend.adjoint() for addend in self.addends])
if True:
    # Import-time self-test of the algebra classes.  It runs every time the
    # module is loaded.  These are `assert` statements, so they are skipped
    # under `python -O`.  The names bound here (i0..i3, s0..s3, s00..s33, s)
    # are ordinary module globals.
    # Powers of the imaginary unit must print in their simplest form.
    i0=SumOfProducts([Product(1,"")])
    i1=SumOfProducts([Product(complex(0,1),"")])
    i2=SumOfProducts([Product(complex(-1,0),"")])
    i3=SumOfProducts([Product(complex(0,-1),"")])
    assert(str(i0) == "1")
    assert(str(i1) == "i")
    assert(str(i2) == "-1")
    assert(str(i3) == "-i")
    # S + i^k T for k = 0..3: unit coefficients are elided and a negative
    # coefficient becomes a " - " separator.
    s0 = SumOfProducts([Product(1,"S"),Product(1,"T")])
    s1 = SumOfProducts([Product(1,"S"),Product(complex(0,1),"T")])
    s2 = SumOfProducts([Product(1,"S"),Product(complex(-1,0),"T")])
    s3 = SumOfProducts([Product(1,"S"),Product(complex(0,-1),"T")])
    assert(str(s0) == "S + T")
    assert(str(s1) == "S + iT")
    assert(str(s2) == "S - T")
    assert(str(s3) == "S - iT")
    # sum over k of i^k (S + i^k T)'(S + i^k T): every term except T'S cancels,
    # leaving 4T'S.  This exercises adjoint, times, plus and zero-term dropping.
    s00 = i0.times(s0.adjoint().times(s0))
    s11 = i1.times(s1.adjoint().times(s1))
    s22 = i2.times(s2.adjoint().times(s2))
    s33 = i3.times(s3.adjoint().times(s3))
    s = s00.plus(s11).plus(s22).plus(s33)
    assert(str(s) == "4T'S")
    # Integer powers: x^0 is 1, and products repeat in order (not commuted).
    assert(str(s.pow(SumOfProducts([Product(0,"")]))) == "1")
    assert(str(s.pow(SumOfProducts([Product(1,"")]))) == "4T'S")
    assert(str(s.pow(SumOfProducts([Product(2,"")]))) == "16T'ST'S")
# Token names used by the ply lexer and parser.
tokens = (
    'NAME',     # single letter; the letter i is the imaginary unit (see p_atom_name)
    'NUMBER',   # unsigned decimal literal, converted to complex by t_NUMBER
    'PLUS',
    'MINUS',
    'TIMES',
    'POW',
    'ADJOINT',  # postfix ' marking an adjoint
    'LPAREN',
    'RPAREN',
)

# Regular expressions for the simple (string-defined) tokens.
t_NAME    = r'[a-zA-Z]'
t_PLUS    = r'\+'
t_MINUS   = r'-'
t_TIMES   = r'\*'
t_POW     = r'\^'
t_ADJOINT = r"'"
t_LPAREN  = r'\('
t_RPAREN  = r'\)'
def t_NUMBER(t):
    r'\d+(\.\d+)?'
    # The docstring above is the token's regex and must stay the first statement.
    # The matched text is replaced by its complex value.
    text = t.value
    try:
        t.value = complex(text)
    except ValueError:
        raise SimplifierError("Invalid complex number " + str(text))
    return t
# Characters skipped between tokens: spaces and tabs.
t_ignore = " \t"


def t_newline(t):
    r'\n+'
    # Advance the lexer's line counter.  Returning None discards the token.
    newlines = t.value.count("\n")
    t.lexer.lineno += newlines


def t_error(t):
    # Called for input that no token rule matches; report its first character.
    bad_char = t.value[0]
    raise SimplifierError("Illegal character '%s'" % bad_char)
# Build the lexer
import ply.lex as lex
lexer = lex.lex()
# Test it out: a sample expression for eyeballing the token stream.  Change
# `if False` to `if True` below to dump the tokens when debugging the lexer.
testdata = '''(S+T)'(S+T) + i(S+iT)'(S+iT) + i^2(S+i^2T)'(S+i^2T) + i^3(S+i^3T)'(S+i^3T)'''
if False:
    lexer.input(testdata)
    for tok in lexer:
        # The call form prints the same as the old `print tok` statement on
        # Python 2, and is also valid Python 3 syntax.
        print(tok)
# Operator precedence table for ply.yacc, listed from lowest to highest
# binding.  UPLUS, UMINUS and ADJACENT are not real tokens (they are absent
# from `tokens`); the grammar rules reference them only through %prec.
precedence = (
    ('left', 'PLUS', 'MINUS'),
    ('left', 'TIMES', 'POW'),
    ('right', 'UPLUS', 'UMINUS'),
    ('left', 'ADJACENT'),
    ('left', 'ADJOINT'),
)
def p_start_expression(t):
    'start : expression'
    # Top-level rule: the whole input line is one expression, and its
    # simplified form is printed.  Keep this as the first p_ function: per
    # ply's documented default, the first rule defined is the start symbol.
    print(str(t[1]))
def p_expression_term(t):
    'expression : term'
    # A lone term (a chain of juxtaposed factors) is already an expression.
    t[0] = t[1]
def p_expression_binop(t):
    '''expression : expression PLUS expression
                  | expression MINUS expression
                  | expression TIMES expression'''
    # Dispatch on the operator's source text to the matching SumOfProducts
    # method.  The operand order is kept, since products do not commute.
    method_name = {'+': 'plus', '-': 'minus', '*': 'times'}.get(t[2])
    if method_name is not None:
        t[0] = getattr(t[1], method_name)(t[3])
def p_expression_uplus(t):
    'expression : PLUS expression %prec UPLUS'
    # Unary plus is a no-op.
    t[0] = t[2]
def p_expression_uminus(t):
    'expression : MINUS expression %prec UMINUS'
    # Unary minus negates every addend of the operand.
    t[0] = t[2].neg()
def p_term_atom(t):
    'term : atom'
    # A single atom (number, name, or parenthesised expression) is a term.
    t[0] = t[1]
def p_term_adjacent(t):
    'term : term term %prec ADJACENT'
    # Juxtaposition (e.g. `2S`, `ST`, `i(S+T)`) means multiplication.  The
    # left operand stays on the left because products do not commute.
    t[0] = t[1].times(t[2])
def p_term_adjoint(t):
    'term : term ADJOINT'
    # Postfix ' takes the adjoint of the term so far: scalars are conjugated,
    # factor order is reversed, and each factor's prime is toggled.
    t[0] = t[1].adjoint()
def p_term_pow(t):
    'term : term POW atom'
    # The exponent is a single atom, so `S^2T` parses as (S^2)T.
    # SumOfProducts.pow raises SimplifierError unless the exponent is a
    # nonnegative integer scalar.
    t[0] = t[1].pow(t[3])
def p_atom_group(t):
    'atom : LPAREN expression RPAREN'
    # Parentheses only group; the value is the inner expression.
    t[0] = t[2]
def p_atom_number(t):
    'atom : NUMBER'
    # t[1] is the complex value produced by t_NUMBER.  Wrap it as a pure
    # scalar term (empty operator string).
    t[0] = SumOfProducts([Product(t[1], "")])
def p_atom_name(t):
    'atom : NAME'
    # Every letter except i is an operator symbol with coefficient 1.
    # The letter i is reserved for the imaginary unit.
    name = t[1]
    if name != 'i':
        t[0] = SumOfProducts([Product(1.0, name)])
    else:
        t[0] = SumOfProducts([Product(complex(0,1), "")])
def p_error(t):
    # Receives the offending token, or a falsy value (None) when the input
    # ended before the expression was complete.
    if not t:
        message = "Syntax error"
    else:
        message = "Syntax error at '%s'" % t.value
    raise SimplifierError(message)
import ply.yacc as yacc
# Build the parser from the p_* grammar rules defined above.
yacc.yacc()

if __name__ == "__main__":
    # raw_input exists only on Python 2; Python 3 renamed it to input.
    try:
        read_line = raw_input
    except NameError:
        read_line = input
    while True:
        try:
            s = read_line('calc > ')
        except EOFError:
            break
        # A blank line is not an expression: prompt again instead of
        # reporting "Syntax error".
        if not s.strip():
            continue
        try:
            yacc.parse(s)
        except SimplifierError as e:
            print(e.value)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment