# Parse simple(ish) algorithms from whitepapers and apply basic optimisations to transform them into a consistent state.
import re
from pypeg2 import word, attr, maybe_some, blank, endl, parse, optional
from collections import defaultdict
class ExprBase(str):
    """Common ``str``-derived base class for every grammar node in this module.

    It adds no behaviour of its own. It exists so that code walking a parse
    tree can tell parsed nodes apart from plain strings with a single
    ``isinstance(node, ExprBase)`` check.
    """
class Expression(ExprBase):
    """A leading operand (``inside``) followed by operator/operand pairs (``rhs``).

    The ``grammar`` attribute is assigned after the class body, further down
    the module.
    """

    def __str__(self):
        # Render the first operand, then every trailing piece in order.
        head = str(self.inside)
        tail = ''.join(map(str, self.rhs))
        return head + tail
class Group(ExprBase):
    """A parenthesised sub-expression, stored in ``inside``.

    The ``grammar`` attribute is assigned after the class body, further down
    the module.
    """

    def __str__(self):
        # Wrap the rendered contents in a pair of round brackets.
        return ''.join(('(', str(self.inside), ')'))
class Powerable(ExprBase):
    """A word or bracketed group, optionally followed by ``^`` and an exponent.

    ``inside`` holds the base; ``power`` holds the exponent text matched by
    ``[1-9][0-9]*`` when one was present.
    """

    grammar = (
        attr('inside', [word, Group]),
        attr('power', optional('^', re.compile(r'[1-9][0-9]*'))),
    )

    def __str__(self):
        base = str(self.inside)
        if not self.power:
            # No exponent: the node renders as its base alone.
            return base
        return base + ('^' + self.power)
def is_num(x):
    """Return True if the string *x* is a base-10 integer literal ``int()`` accepts.

    A single optional leading ``-`` is allowed. The empty string and a lone
    ``-`` are rejected.

    ``str.isdecimal`` is used instead of ``str.isdigit`` because the latter
    also accepts characters such as superscript digits (``'²'``), for which
    ``int()`` raises ``ValueError``.
    """
    # x[:1] is '' for an empty string, so no separate length check is needed.
    digits = x[1:] if x[:1] == '-' else x
    return digits.isdecimal()
class OpExpr(ExprBase):
    """One binary operator (``op``) together with the operand after it (``rhs``).

    The operand is either a ``Powerable`` or an integer literal matched by
    ``-?[1-9][0-9]*``.
    """

    grammar = (
        attr('op', re.compile(r'[·\+\-\*−/]')),
        attr('rhs', [Powerable, re.compile(r'-?[1-9][0-9]*')]),
    )

    def __str__(self):
        # Rendered with a space on either side of the operator.
        return f' {self.op} {self.rhs}'
# These grammars are attached after the class definitions because the node
# types refer to one another: Group contains an Expression, while Expression
# is built from Powerable (which may itself be a Group) and OpExpr.
Group.grammar = ('(', attr('inside', Expression), ')')
Expression.grammar = (
    attr('inside', [Powerable, re.compile(r'-?[1-9][0-9]*')]),
    attr('rhs', maybe_some(OpExpr)),
)
class Assignment(ExprBase):
    """A single statement of the form ``name = expression``.

    Either ``=`` or ``←`` is accepted as the assignment symbol; the
    string form always uses ``=``.
    """

    grammar = (
        attr('output', word),
        blank,
        ['=', '←'],
        blank,
        attr('expr', Expression),
        endl,
    )

    def __str__(self):
        return f'{self.output} = {self.expr}'
# Optimise a flat list of statements.
#
# Each statement is indexed as row[0] (the variable written), row[1] (the
# operation name, or '=' for a plain copy) and row[2:] (the operands read).
# Visible passes, in order: renaming of re-assigned variables (SSA form),
# expression deduplication, removal of needless copies, re-use of variables
# that are no longer read, and tagging of '_inplace' operations.
#
# stmts   -- sequence of statement rows (tuples or lists)
# outputs -- variable names that the re-use pass must not rename or hand out
# tmpvars -- list whose length is used to number fresh 'tmpN' names
#
# Returns the rewritten statement list.
#
# NOTE(review): this listing carries no indentation, and `result`/`results`
# are rebuilt in every pass without any visible append, so loop nesting and
# the exit condition of each `while True` cannot be confirmed from here.
def postprocess(stmts, outputs, tmpvars):
# Convert variables to SSA form first
while True:
count = 0
# Variable name -> list of statement indices. The line that fills this map
# is not visible; presumably the indices at which the name is assigned
# -- TODO confirm.
assignments = defaultdict(list)
for i, row in enumerate(stmts):
result = list()
replace_with = None
skipall = False
for i, row in enumerate(stmts):
row = list(row)
if not skipall:
if replace_with is None:
# The name has recorded indices and this statement is not the last of them
if len(assignments[row[0]]) > 0 and i != assignments[row[0]][-1]:
# Fresh temporary name, numbered by the current length of tmpvars
tmpvar = f'tmp{len(tmpvars)}'
# (original name, replacement name, second recorded index of the original)
replace_with = (row[0], tmpvar, assignments[row[0]][1])
row[0] = tmpvar
count += 1
if row[0] == replace_with[0]:
skipall = True
# Operands that read the original name now read the replacement
row[2:] = [replace_with[1] if _ == replace_with[0] else _ for _ in row[2:]]
stmts = result
if count == 0:
# Then perform expression deduplication
# we can freely change variable names here via a rewrite mapping
exprcache = dict()
cacherewrite = dict()
results = list()
for i, row in enumerate(stmts):
# Cache key: the operation name followed by its operands
rhs = tuple(row[1:])
if row[1] != '=':
if rhs in exprcache:
# If statement has already been computed
# Change to an assignment... this will be optimised out later
row = [row[0], '=', exprcache[rhs]]
# Remember which variable holds the value of this expression
exprcache[rhs] = row[0]
stmts = results
# Iterate through statements in reverse, removing needless copies
while True:
# Count number of times each variable is assigned
assignments = defaultdict(int)
for row in stmts:
assignments[row[0]] += 1
# Despite the name, this holds variables that are *assigned* exactly once
used_once = set([a for a, b in assignments.items() if b == 1])
count = 0
result = list()
replace_with = None
for row in stmts[::-1]:
row = list(row)
if replace_with is None:
if row[1] == '=':
if row[2] in used_once:
# (copy destination, copy source)
replace_with = (row[0], row[2])
# The statement that writes the copy source is retargeted at the destination
if row[0] == replace_with[1]:
row[0] = replace_with[0]
count += 1
# Operands that read the copy source now read the destination
row[2:] = [replace_with[0] if _ == replace_with[1] else _ for _ in row[2:]]
stmts = result[::-1]
if count == 0:
# Determine which variables are used as inputs
reads = defaultdict(int)
for row in stmts:
for var in row[2:]:
reads[var] += 1
# Operands that are read but have no recorded assignment
never_written = set([a for a, b in reads.items() if assignments[a] == 0 or a not in assignments])
ignored = list()
while True:
# Find variables which can be overwritten later in the program
count = 0
# Number of distinct written variables before this pass
count_before = len(set([row[0] for row in stmts]))
# Variable -> index of the last statement that reads it; inputs and
# outputs are left out so they are never offered for re-use
last_read = defaultdict(int)
for i, row in enumerate(stmts):
for var in row[2:]:
if var not in never_written and var not in outputs:
last_read[var] = i
# Then perform variable re-assignment, re-using variables
# TODO: verify we don't overwrite output variables
result = list()
replace_with = None
for i, row in enumerate(stmts):
# Variables whose last read is at or before the current statement
free_vars = ([a for a, b in last_read.items() if b <= i])
if replace_with is None:
# Operations whose name starts with '_' are skipped, as are output
# variables and statement indices listed in `ignored`
if row[1][0] != '_' and len(free_vars) and row[0] not in outputs and i not in ignored:
# (original name, re-used name, statement index)
replace_with = (row[0], free_vars[0], i)
row[0] = free_vars[0]
count += 1
# Operands that read the original name now read the re-used name
row = row[:2] + [replace_with[1] if _ == replace_with[0] else _ for _ in row[2:]]
stmts = result
# Number of distinct written variables after this pass
count_after = len(set([row[0] for row in stmts]))
if count_before == count_after:
# Stop if we fail to reduce the var count
# Avoids getting stuck in loops shuffling variable names around
if replace_with is not None:
if count == 0:
# Turn statements into '_inplace'
result = list()
for row in stmts:
if row[1] != '=':
# Two-operand row: when the destination is also an operand, drop it from
# the operand list and mark the operation as in-place
if len(row) == 4:
if row[0] in row[2:]:
row = row[:2] + [_ for _ in row[2:] if _ != row[0]]
row[1] = row[1] + '_inplace'
# One-operand row: when the destination equals the operand, mark the
# operation as in-place and drop the operand
elif len(row) == 3:
if row[0] == row[2]:
row[1] = row[1] + '_inplace'
row = row[:2]
stmts = result
return stmts
# Lower parsed statements into flat rows, print them, run postprocess() over
# them and print the optimised rows.
#
# lines   -- iterable of parsed statements; the visible code handles
#            Assignment instances and raises RuntimeError("Unknown type")
#            (presumably for anything else -- the branch keyword is not visible)
# outputs -- passed through unchanged to postprocess()
#
# Nothing is returned in the visible code; both listings are only printed.
#
# NOTE(review): this listing carries no indentation, and the `ops` dict and
# the `pending` list literal are cut short, so branch structure below cannot
# be confirmed from here.
def analyze(lines, outputs):
# Flat statement rows emitted so far
processed = list()
# Only len(tmpvars) is read in the visible code, to number 'tmpN' and
# '_constN' names. NOTE(review): no append to this list is visible --
# confirm that generated names cannot collide.
tmpvars = list()
# Lower one expression node, emitting rows into `processed`, and return the
# name (or plain string) that holds its value.
def iter_expr(expr):
if not isinstance(expr, ExprBase):
# Strings will passthru
return expr
# Operator character -> operation name used in emitted rows
ops = {
'·': 'mul',
'*': 'mul',
'-': 'sub',
'−': 'sub', # Allows for easier copy-pasta from Latex PDFs
'+': 'add',
'/': 'div'
if isinstance(expr, (Expression, Group)):
lhs = iter_expr(expr.inside)
if isinstance(expr, Powerable):
if expr.power is not None:
power_arg = iter_expr(expr.inside)
tmpvar = f'tmp{len(tmpvars)}'
# Exponent 2 is emitted as a single 'square' row
if expr.power == '2':
processed.append((tmpvar, 'square', power_arg))
# Exponent 3 is emitted as 'square' followed by a 'mul' with the base
elif expr.power == '3':
processed.append((tmpvar, 'square', power_arg))
prev_tmpvar = tmpvar
tmpvar = f'tmp{len(tmpvars)}'
processed.append((tmpvar, 'mul', prev_tmpvar, power_arg))
# Generic 'power' row with the exponent as an int (presumably the fallback
# for other exponents -- TODO confirm)
processed.append((tmpvar, 'power', power_arg, int(expr.power)))
lhs = tmpvar
# Powerable, without a power, just acts like a normal group
lhs = iter_expr(expr.inside)
if hasattr(expr, 'rhs'):
for item in expr.rhs:
tmpvar = f'tmp{len(tmpvars)}'
if isinstance(item, Powerable):
rhs_arg = iter_expr(item)
elif isinstance(item, Group):
rhs_arg = iter_expr(item.inside)
elif isinstance(item, OpExpr):
rhs_arg = item.rhs
# NOTE(review): the contents of this list literal are not visible; the code
# below reads pending[1] as the operation name and pending[2:] as operands.
pending = [
# Operands that are integer text become real ints
pending[2:] = [int(_) if is_num(_) else _ for _ in pending[2:]]
# Number of int operands (used below both as a truth value and compared to 1)
has_int_args = len([_ for _ in pending[2:] if isinstance(_, int)])
if has_int_args:
int_arg = [_ for _ in pending[2:] if isinstance(_, int)][0]
nonint_arg = [_ for _ in pending[2:] if not isinstance(_, int)][0]
if pending[1] == 'div':
# An integer operand of 1 turns the division into an 'inverse' of the other operand
if int_arg == 1:
pending[1] = 'inverse'
pending[2] = nonint_arg
pending = pending[:3]
# Emit an '_invert_const' row for the integer and rewrite the division as
# 'mul_const' by that constant (presumably the branch for other integers
# -- TODO confirm)
tmpvar2 = f'_const{len(tmpvars)}'
processed.append([tmpvar2, '_invert_const', int_arg])
pending = [tmpvar, 'mul_const', nonint_arg, tmpvar2]
has_int_args = False
if pending[1] in ['mul','add','sub']:
if has_int_args:
# Positive integer with exactly one bit set in its binary form
if pending[1] == 'mul' and int_arg > 0 and bin(int_arg)[2:].count('1') == 1:
# Exact power of 2, allows optimisation of doubling-inplace
pending[1] = f'{pending[1]}_{int_arg}'
pending[2] = nonint_arg
pending = pending[:3]
elif has_int_args == 1:
# One argument is a short integer, further optimisation
pending[1] += '_const'
pending[2] = nonint_arg
pending[3] = int_arg
lhs = tmpvar
return lhs
for line in lines:
if isinstance(line, Assignment):
expr = line.expr
rhs = iter_expr(expr)
# Literal '1' and '0' right-hand sides get dedicated operations; the
# remaining case is emitted as a plain '=' copy
if rhs == '1':
processed.append((line.output, 'set_one'))
elif rhs == '0':
processed.append((line.output, 'set_zero'))
processed.append((line.output, '=', rhs))
raise RuntimeError("Unknown type")
# Display before preprocessing
for i, row in enumerate(processed):
print(i, row)
processed = postprocess(processed, outputs, tmpvars)
# Display again after postprocessing
for i, row in enumerate(processed):
print(i, row)
# Script entry point: split a hard-coded formula listing into lines, parse
# each one as an Assignment and hand the statements to analyze() with X3, Y3
# and Z3 as the output variables.
# NOTE(review): `code` is assigned several times, so presumably only the last
# listing is processed -- confirm that every triple-quoted literal is
# terminated as intended.
if __name__ == "__main__":
code = """
A = (X2-X1)^2
B = X1*A
C = X2*A
D = (Y2-Y1)^2
X3 = D-B-C
Y3 = (Y2-Y1)*(B-X3)-Y1*(C-B)
Z3 = Z1*(X2-X1)
code = """
A = 1/Z1
AA = A^2
X3 = X1*AA
Y3 = Y1*AA*A
Z3 = 1
code = """
XX = X1^2
YY = Y1^2
ZZ = Z1^2
M = 3*XX+a*ZZ^2
MM = M^2
E = 6*((X1+YY)^2-XX-YYYY)-MM
EE = E^2
T = 16*YYYY
U = (M+E)^2-MM-EE-T
X3 = 4*(X1*EE-4*YY*U)
Y3 = 8*Y1*(U*(T-U)-E*EE)
Z3 = (Z1+E)^2-ZZ-EE
code = """
S = 4*X1*Y1^2
M = 3*X1^2+a*Z1^4
T = M2-2*S
X3 = T
Y3 = M*(S-T)-8*Y1^4
Z3 = 2*Y1*Z1
code = """
t0 ← X1 · X2
t1 ← Y1 · Y2
t2 ← Z1 · Z2
t3 ← X1 + Y1
t4 ← X2 + Y2
t3 ← t3 · t4
t4 ← t0 + t1
t3 ← t3 − t4
t4 ← X1 + Z1
t5 ← X2 + Z2
t4 ← t4 · t5
t5 ← t0 + t2
t4 ← t4 − t5
t5 ← Y1 + Z1
X3 ← Y2 + Z2
t5 ← t5 · X3
X3 ← t1 + t2
t5 ← t5 − X3
Z3 ← a · t4
X3 ← b3 · t2
Z3 ← X3 + Z3
X3 ← t1 − Z3
Z3 ← t1 + Z3
Y3 ← X3 · Z3
t1 ← t0 + t0
t1 ← t1 + t0
t2 ← a · t2
t4 ← b3 · t4
t1 ← t1 + t2
t2 ← t0 − t2
t2 ← a · t2
t4 ← t4 + t2
t0 ← t1 · t4
Y3 ← Y3 + t0
t0 ← t5 · t4
X3 ← t3 · X3
X3 ← X3 − t0
t0 ← t3 · t1
Z3 ← t5 · Z3
Z3 ← Z3 + t0"""
code = """
U1 = X1*Z2^2
U2 = X2*Z1^2
S1 = Y1*Z2^3
S2 = Y2*Z1^3
P = U2-U1
R = S2-S1
X3 = R^2-(U1+U2)*P^2
Y3 = (R*(-2*R^2+3*P^2*(U1+U2))-P^3*(S1+S2))/2
Z3 = Z1*Z2*P
code = """
A = Z1^2
B = Z2^2
C = (Z1+Z2)^2-A-B
D = X1*Z2
E = X2*Z1
F = Y1*B
G = Y2*A
H = D-E
I = 2*(F-G)
II = I^2
J = C*H
K = 4*J*H
X3 = 2*II-(D+E)*K
JJ = J^2
Y3 = ((J+I)^2-JJ-II)*(D*K-X3)-F*K^2
Z3 = 2*JJ
# Parsed statements to be handed to analyze()
stmts = list()
for line in code.split("\n"):
line = line.strip()
# Empty lines and lines starting with '#' are tested for here; presumably
# they are skipped, but the skipping statement is not visible -- TODO confirm
if not len(line):
if line[0] == '#':
# Parse one line of text as an Assignment node.
# NOTE(review): no statement adding `result` to `stmts` is visible here.
result = parse(line, Assignment)
analyze(stmts, ['X3', 'Y3', 'Z3'])
