Skip to content

Instantly share code, notes, and snippets.

@HarryR
Created October 23, 2019 13:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save HarryR/c3eee3a6ed009fd4f9772e98c2a24010 to your computer and use it in GitHub Desktop.
Save HarryR/c3eee3a6ed009fd4f9772e98c2a24010 to your computer and use it in GitHub Desktop.
Parse simple(ish) algorithms from whitepapers and apply basic optimisations to transform them into a consistent form.
import re
from pypeg2 import word, attr, maybe_some, blank, endl, parse, optional
from collections import defaultdict
class ExprBase(str):
    """Common ``str``-derived base for every parsed grammar node in this file.

    It adds no behaviour of its own; other code uses
    ``isinstance(x, ExprBase)`` to tell grammar nodes from plain strings.
    """
class Expression(ExprBase):
    """A leading operand followed by a sequence of trailing items.

    ``__str__`` reads two attributes, ``inside`` and ``rhs``; the class body
    itself does not define a grammar.
    """

    def __str__(self):
        """Return ``inside`` rendered as text with every ``rhs`` item appended."""
        trailing = ''.join(map(str, self.rhs))
        return str(self.inside) + trailing
class Group(ExprBase):
    """A parenthesised sub-expression stored in the ``inside`` attribute."""

    def __str__(self):
        """Return the inner expression wrapped in round brackets."""
        return ''.join(('(', str(self.inside), ')'))
class Powerable(ExprBase):
    """A word or ``Group`` that may carry a ``^N`` suffix (N a positive integer)."""

    grammar = (
        attr('inside', [word, Group]),
        attr('power', optional('^', re.compile(r'[1-9][0-9]*'))),
    )

    def __str__(self):
        """Return the operand text, followed by ``^power`` when ``power`` is truthy."""
        base = str(self.inside)
        if not self.power:
            return base
        return base + '^' + self.power
def is_num(x):
    """Return True when the string *x* is an integer literal that ``int(x)`` accepts.

    A single optional leading ASCII ``-`` is allowed; the remainder must be
    one or more decimal digits. The empty string and a lone ``-`` give False.

    ``str.isdecimal`` is used rather than ``str.isdigit`` because ``isdigit``
    is also true for characters such as superscript digits (``'²'``), which
    ``int()`` rejects. Callers convert with ``int()`` right after this check,
    so the predicate must not accept anything ``int()`` would refuse.
    """
    # x[:1] is safe on the empty string, unlike x[0].
    digits = x[1:] if x[:1] == '-' else x
    return digits.isdecimal()
class OpExpr(ExprBase):
    """One binary operator together with its right-hand operand.

    ``op`` is a single operator character; ``rhs`` is either a ``Powerable``
    or the text of a non-zero integer literal.
    """

    grammar = (
        attr('op', re.compile(r'[·\+\-\*−/]')),
        attr('rhs', [Powerable, re.compile(r'-?[1-9][0-9]*')]),
    )

    def __str__(self):
        """Return the operator and operand, each preceded by a single space."""
        return ' {} {}'.format(self.op, self.rhs)
# These grammars are attached after the class bodies because the rules are
# mutually recursive: Group contains an Expression, while Expression is built
# from Powerable and OpExpr, and Powerable may in turn contain a Group.
Group.grammar = ('(', attr('inside', Expression), ')')
Expression.grammar = (
    attr('inside', [Powerable, re.compile(r'-?[1-9][0-9]*')]),
    attr('rhs', maybe_some(OpExpr)),
)
class Assignment(ExprBase):
    """A single statement of the form ``name = expression`` or ``name ← expression``."""

    grammar = (
        attr('output', word),
        blank,
        ['=', '←'],
        blank,
        attr('expr', Expression),
        endl,
    )

    def __str__(self):
        """Return the statement normalised to ``output = expr``."""
        return '{} = {}'.format(self.output, self.expr)
def postprocess(stmts, outputs, tmpvars):
    """Rewrite a flat statement list through several clean-up passes.

    Each statement is a sequence ``row`` where ``row[0]`` is the destination
    name, ``row[1]`` is the operation name (``'='`` for a plain copy) and
    ``row[2:]`` are the operands.

    Passes, in order:
      1. rename earlier assignments of multiply-assigned names to fresh temps;
      2. turn repeated (operation, operands) rows into copies of the first one;
      3. remove copy rows by renaming the row that produced the copied value;
      4. reuse names whose last read has already happened as destinations;
      5. mark rows whose destination is also an operand as ``*_inplace``.

    Args:
        stmts: iterable of statement rows (tuples or lists).
        outputs: container of names that pass 4 never renames and never
            offers for reuse.
        tmpvars: list of temporary names already handed out; pass 1 appends
            the new names it creates to it.

    Returns:
        The rewritten list of statement rows (lists).
    """
    # Convert variables to SSA form first
    # Each sweep renames at most one earlier assignment; repeat until a sweep
    # renames nothing.
    while True:
        count = 0
        # Map each destination name to the row indices that assign it.
        assignments = defaultdict(list)
        for i, row in enumerate(stmts):
            assignments[row[0]].append(i)
        result = list()
        replace_with = None  # (old name, new temp name, index of 2nd assignment)
        skipall = False      # set once the old name is assigned again
        for i, row in enumerate(stmts):
            row = list(row)
            if not skipall:
                if replace_with is None:
                    # First row whose destination is assigned again later on.
                    if len(assignments[row[0]]) > 0 and i != assignments[row[0]][-1]:
                        tmpvar = f'tmp{len(tmpvars)}'
                        tmpvars.append(tmpvar)
                        replace_with = (row[0], tmpvar, assignments[row[0]][1])
                        row[0] = tmpvar
                        count += 1
                else:
                    if row[0] == replace_with[0]:
                        # The old name is re-assigned here; rows from this one
                        # onwards are copied through unchanged.
                        skipall = True
                    else:
                        # Redirect reads of the old name to the new temp.
                        row[2:] = [replace_with[1] if _ == replace_with[0] else _ for _ in row[2:]]
            result.append(row)
        stmts = result
        if count == 0:
            break
    #"""
    # Then perform expression deduplication
    # we can freely change variable names here via a rewrite mapping
    exprcache = dict()  # (operation, operands...) -> first destination computing it
    cacherewrite = dict()  # NOTE(review): never read in this function
    results = list()
    for i, row in enumerate(stmts):
        rhs = tuple(row[1:])
        if row[1] != '=':
            if rhs in exprcache:
                # If statement has already been computed
                # Change to an assignment... this will be optimised out later
                row = [row[0], '=', exprcache[rhs]]
            else:
                exprcache[rhs] = row[0]
        results.append(row)
    stmts = results
    #"""
    # Iterate through statements in reverse, removing needless copies
    while True:
        # Count number of times each variable is assigned
        assignments = defaultdict(int)
        for row in stmts:
            assignments[row[0]] += 1
        # Names assigned exactly once.
        used_once = set([a for a, b in assignments.items() if b == 1])
        count = 0
        result = list()
        replace_with = None  # (copy destination, copied source)
        for row in stmts[::-1]:
            row = list(row)
            if replace_with is None:
                if row[1] == '=':
                    if row[2] in used_once:
                        # Drop this copy row; rows earlier in the program are
                        # rewritten below to use the destination name directly.
                        replace_with = (row[0], row[2])
                        continue
            else:
                if row[0] == replace_with[1]:
                    row[0] = replace_with[0]
                    count += 1
                row[2:] = [replace_with[0] if _ == replace_with[1] else _ for _ in row[2:]]
            result.append(row)
        # Rows were collected back-to-front; restore program order.
        stmts = result[::-1]
        if count == 0:
            break
    # Determine which variables are used as inputs
    reads = defaultdict(int)
    for row in stmts:
        for var in row[2:]:
            reads[var] += 1
    # `assignments` is the per-name count left over from the final sweep above.
    never_written = set([a for a, b in reads.items() if assignments[a] == 0 or a not in assignments])
    # Row indices whose renaming did not reduce the number of distinct names.
    ignored = list()
    while True:
        # Find variables which can be overwritten later in the program
        count = 0
        count_before = len(set([row[0] for row in stmts]))
        # Index of the last row reading each name; inputs and outputs excluded.
        last_read = defaultdict(int)
        for i, row in enumerate(stmts):
            for var in row[2:]:
                if var not in never_written and var not in outputs:
                    last_read[var] = i
        # Then perform variable re-assignment, re-using variables
        # TODO: verify we don't overwrite output variables
        result = list()
        replace_with = None  # (old destination, reused name, row index)
        for i, row in enumerate(stmts):
            # Names whose last read is at or before this row.
            free_vars = ([a for a, b in last_read.items() if b <= i])
            if replace_with is None:
                # Operations whose name starts with '_' are never renamed.
                if row[1][0] != '_' and len(free_vars) and row[0] not in outputs and i not in ignored:
                    replace_with = (row[0], free_vars[0], i)
                    row[0] = free_vars[0]
                    count += 1
            else:
                row = row[:2] + [replace_with[1] if _ == replace_with[0] else _ for _ in row[2:]]
            result.append(row)
        stmts = result
        count_after = len(set([row[0] for row in stmts]))
        if count_before == count_after:
            # Stop if we fail to reduce the var count
            # Avoids getting stuck in loops shuffling variable names around
            if replace_with is not None:
                ignored.append(replace_with[2])
        if count == 0:
            break
    # Turn statements into '_inplace'
    result = list()
    for row in stmts:
        if row[1] != '=':
            if len(row) == 4:
                if row[0] in row[2:]:
                    # Destination is also an operand: keep only the other operand(s).
                    row = row[:2] + [_ for _ in row[2:] if _ != row[0]]
                    row[1] = row[1] + '_inplace'
            elif len(row) == 3:
                if row[0] == row[2]:
                    row[1] = row[1] + '_inplace'
                    row = row[:2]
        result.append(row)
    stmts = result
    return stmts
def analyze(lines, outputs):
    """Flatten parsed assignments into operation rows, optimise and print them.

    Every ``Assignment`` in *lines* is lowered to rows of the form
    ``(destination, operation, operand...)``. The rows are passed through
    ``postprocess`` and then printed with their index. Nothing is returned.

    Args:
        lines: iterable of parsed ``Assignment`` objects.
        outputs: names forwarded to ``postprocess`` as the output variables.

    Raises:
        RuntimeError: if an element of *lines* is not an ``Assignment``.
    """
    print('---')
    processed = list()  # rows emitted so far, in program order
    tmpvars = list()    # every temporary / constant name handed out so far
    def iter_expr(expr):
        """Lower *expr*, appending rows to ``processed``; return the name holding its value."""
        if not isinstance(expr, ExprBase):
            # Strings will passthru
            return expr
        # Operator character -> operation name.
        ops = {
            '·': 'mul',
            '*': 'mul',
            '-': 'sub',
            '−': 'sub', # Allows for easier copy-pasta from Latex PDFs
            '+': 'add',
            '/': 'div'
        }
        if isinstance(expr, (Expression, Group)):
            lhs = iter_expr(expr.inside)
        if isinstance(expr, Powerable):
            if expr.power is not None:
                power_arg = iter_expr(expr.inside)
                tmpvar = f'tmp{len(tmpvars)}'
                tmpvars.append(tmpvar)
                if expr.power == '2':
                    processed.append((tmpvar, 'square', power_arg))
                elif expr.power == '3':
                    # Cube = square, then multiply by the base once more.
                    processed.append((tmpvar, 'square', power_arg))
                    prev_tmpvar = tmpvar
                    tmpvar = f'tmp{len(tmpvars)}'
                    tmpvars.append(tmpvar)
                    processed.append((tmpvar, 'mul', prev_tmpvar, power_arg))
                else:
                    processed.append((tmpvar, 'power', power_arg, int(expr.power)))
                lhs = tmpvar
            else:
                # Powerable, without a power, just acts like a normal group
                lhs = iter_expr(expr.inside)
        if hasattr(expr, 'rhs'):
            # Fold the trailing operator/operand pairs left to right.
            for item in expr.rhs:
                tmpvar = f'tmp{len(tmpvars)}'
                tmpvars.append(tmpvar)
                if isinstance(item, Powerable):
                    rhs_arg = iter_expr(item)
                elif isinstance(item, Group):
                    rhs_arg = iter_expr(item.inside)
                elif isinstance(item, OpExpr):
                    rhs_arg = item.rhs
                pending = [
                    tmpvar,
                    ops[item.op],
                    iter_expr(lhs),
                    iter_expr(rhs_arg)]
                # Operands that are integer literals become Python ints.
                pending[2:] = [int(_) if is_num(_) else _ for _ in pending[2:]]
                has_int_args = len([_ for _ in pending[2:] if isinstance(_, int)])
                if has_int_args:
                    # NOTE(review): the constant is picked regardless of which
                    # side it was written on, so `c - x` and `x - c` (likewise
                    # `c / x` and `x / c`) lower to the same row — confirm this
                    # is intended. If both operands are ints, the
                    # `nonint_arg` lookup raises IndexError.
                    int_arg = [_ for _ in pending[2:] if isinstance(_, int)][0]
                    nonint_arg = [_ for _ in pending[2:] if not isinstance(_, int)][0]
                    if pending[1] == 'div':
                        if int_arg == 1:
                            pending[1] = 'inverse'
                            pending[2] = nonint_arg
                            pending = pending[:3]
                        else:
                            # Division by a constant becomes multiplication by
                            # a separately emitted inverted constant.
                            tmpvar2 = f'_const{len(tmpvars)}'
                            tmpvars.append(tmpvar2)
                            processed.append([tmpvar2, '_invert_const', int_arg])
                            pending = [tmpvar, 'mul_const', nonint_arg, tmpvar2]
                            has_int_args = False
                if pending[1] in ['mul','add','sub']:
                    if has_int_args:
                        if pending[1] == 'mul' and int_arg > 0 and bin(int_arg)[2:].count('1') == 1:
                            # Exact power of 2, allows optimisation of doubling-inplace
                            pending[1] = f'{pending[1]}_{int_arg}'
                            pending[2] = nonint_arg
                            pending = pending[:3]
                        elif has_int_args == 1:
                            # One argument is a short integer, further optimisation
                            pending[1] += '_const'
                            pending[2] = nonint_arg
                            pending[3] = int_arg
                processed.append(pending)
                lhs = tmpvar
        return lhs
    for line in lines:
        if isinstance(line, Assignment):
            expr = line.expr
            rhs = iter_expr(expr)
            # The literals 1 and 0 get dedicated operations instead of a copy.
            if rhs == '1':
                processed.append((line.output, 'set_one'))
            elif rhs == '0':
                processed.append((line.output, 'set_zero'))
            else:
                processed.append((line.output, '=', rhs))
        else:
            raise RuntimeError("Unknown type")
    print('-----')
    # Display before preprocessing
    # The bare string below is disabled debug output; it is an expression
    # statement with no effect, kept flush-left exactly as in the source.
    """
for i, row in enumerate(processed):
print(i, row)
print('-----')
"""
    processed = postprocess(processed, outputs, tmpvars)
    #"""
    for i, row in enumerate(processed):
        print(i, row)
    #"""
# Script entry point: parse a sample algorithm line by line, echo every parsed
# assignment, then run analyze() with X3, Y3 and Z3 as the output variables.
if __name__ == "__main__":
    # NOTE: `code` is assigned several times in a row; each assignment
    # overwrites the previous one, so only the LAST sample is actually parsed.
    # The earlier samples have no effect at runtime.
    code = """
A = (X2-X1)^2
B = X1*A
C = X2*A
D = (Y2-Y1)^2
X3 = D-B-C
Y3 = (Y2-Y1)*(B-X3)-Y1*(C-B)
Z3 = Z1*(X2-X1)
"""
    # Overwritten by the next assignment.
    code = """
A = 1/Z1
AA = A^2
X3 = X1*AA
Y3 = Y1*AA*A
Z3 = 1
"""
    # Overwritten by the next assignment.
    code = """
XX = X1^2
YY = Y1^2
ZZ = Z1^2
YYYY = YY^2
M = 3*XX+a*ZZ^2
MM = M^2
E = 6*((X1+YY)^2-XX-YYYY)-MM
EE = E^2
T = 16*YYYY
U = (M+E)^2-MM-EE-T
X3 = 4*(X1*EE-4*YY*U)
Y3 = 8*Y1*(U*(T-U)-E*EE)
Z3 = (Z1+E)^2-ZZ-EE
"""
    # Overwritten by the next assignment.
    code = """
S = 4*X1*Y1^2
M = 3*X1^2+a*Z1^4
T = M2-2*S
X3 = T
Y3 = M*(S-T)-8*Y1^4
Z3 = 2*Y1*Z1
"""
    # Overwritten by the next assignment. This sample uses the `←`, `·` and
    # `−` characters, which the grammar accepts alongside `=`, `*` and `-`.
    code = """
t0 ← X1 · X2
t1 ← Y1 · Y2
t2 ← Z1 · Z2
t3 ← X1 + Y1
t4 ← X2 + Y2
t3 ← t3 · t4
t4 ← t0 + t1
t3 ← t3 − t4
t4 ← X1 + Z1
t5 ← X2 + Z2
t4 ← t4 · t5
t5 ← t0 + t2
t4 ← t4 − t5
t5 ← Y1 + Z1
X3 ← Y2 + Z2
t5 ← t5 · X3
X3 ← t1 + t2
t5 ← t5 − X3
Z3 ← a · t4
X3 ← b3 · t2
Z3 ← X3 + Z3
X3 ← t1 − Z3
Z3 ← t1 + Z3
Y3 ← X3 · Z3
t1 ← t0 + t0
t1 ← t1 + t0
t2 ← a · t2
t4 ← b3 · t4
t1 ← t1 + t2
t2 ← t0 − t2
t2 ← a · t2
t4 ← t4 + t2
t0 ← t1 · t4
Y3 ← Y3 + t0
t0 ← t5 · t4
X3 ← t3 · X3
X3 ← X3 − t0
t0 ← t3 · t1
Z3 ← t5 · Z3
Z3 ← Z3 + t0"""
    # Overwritten by the next assignment.
    code = """
U1 = X1*Z2^2
U2 = X2*Z1^2
S1 = Y1*Z2^3
S2 = Y2*Z1^3
P = U2-U1
R = S2-S1
X3 = R^2-(U1+U2)*P^2
Y3 = (R*(-2*R^2+3*P^2*(U1+U2))-P^3*(S1+S2))/2
Z3 = Z1*Z2*P
"""
    # Final assignment: this is the sample that gets parsed and analysed.
    code = """
A = Z1^2
B = Z2^2
C = (Z1+Z2)^2-A-B
D = X1*Z2
E = X2*Z1
F = Y1*B
G = Y2*A
H = D-E
I = 2*(F-G)
II = I^2
J = C*H
K = 4*J*H
X3 = 2*II-(D+E)*K
JJ = J^2
Y3 = ((J+I)^2-JJ-II)*(D*K-X3)-F*K^2
Z3 = 2*JJ
"""
    stmts = list()
    for line in code.split("\n"):
        line = line.strip()
        # Skip blank lines and lines starting with '#'.
        if not len(line):
            continue
        if line[0] == '#':
            continue
        result = parse(line, Assignment)
        stmts.append(result)
        print(result)
    analyze(stmts, ['X3', 'Y3', 'Z3'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment