Created
October 23, 2019 13:03
-
-
Save HarryR/c3eee3a6ed009fd4f9772e98c2a24010 to your computer and use it in GitHub Desktop.
Parse simple(ish) algorithms from whitepapers and apply basic optimisations to transform them into a consistent form.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
from pypeg2 import word, attr, maybe_some, blank, endl, parse, optional | |
from collections import defaultdict | |
class ExprBase(str):
    """Base class shared by every node of the expression grammar.

    ``analyze`` relies on ``isinstance(expr, ExprBase)`` to tell parsed
    grammar nodes apart from plain string leaves (names and literals).
    """
class Expression(ExprBase):
    """A leading operand followed by zero or more ``OpExpr`` continuations.

    The grammar is attached at module level once ``Powerable`` and
    ``OpExpr`` exist, because the grammar classes are mutually recursive.
    """

    def __str__(self):
        pieces = [str(self.inside)]
        pieces.extend(map(str, self.rhs))
        return ''.join(pieces)
class Group(ExprBase):
    """A parenthesised ``Expression``.

    The grammar is attached at module level after ``Expression`` is defined.
    """

    def __str__(self):
        return ''.join(('(', str(self.inside), ')'))
class Powerable(ExprBase):
    """A word or ``Group``, optionally followed by ``^`` and an exponent.

    ``power`` holds the exponent digits (e.g. ``'2'``); when no ``^`` suffix
    was parsed it is falsy -- presumably ``None``, which is what ``analyze``
    tests for.
    """
    grammar = (
        attr('inside', [word, Group]),
        attr('power', optional('^', re.compile(r'[1-9][0-9]*'))),
    )

    def __str__(self):
        text = str(self.inside)
        return text + '^' + self.power if self.power else text
def is_num(x):
    """Return True if *x* is an optionally negative base-10 integer literal.

    ``str.isdecimal`` is used rather than ``str.isdigit``: every string that
    passes here can be handed to ``int()``, whereas ``isdigit`` also accepts
    characters such as superscript ``'²'`` that ``int()`` rejects.

    Args:
        x: the string to test.

    Returns:
        True for strings like ``'12'`` or ``'-3'``; False for ``''``, ``'-'``
        and anything containing a non-decimal character.
    """
    digits = x[1:] if x.startswith('-') else x
    return digits.isdecimal()
class OpExpr(ExprBase):
    """One ``<operator> <operand>`` continuation of an ``Expression``.

    ``op`` is one of ``·``, ``+``, ``-``, ``*``, ``−`` or ``/``; ``rhs`` is a
    ``Powerable`` or a non-zero, optionally negative, integer literal.
    """
    grammar = (
        attr('op', re.compile(r'[·\+\-\*−/]')),
        attr('rhs', [Powerable, re.compile(r'-?[1-9][0-9]*')]),
    )

    def __str__(self):
        return ' ' + str(self.op) + ' ' + str(self.rhs)
# These grammars are attached after the class definitions because the rules
# are mutually recursive: a Group wraps an Expression, and an Expression
# starts with a Powerable (which may itself be a Group) followed by OpExprs.
Group.grammar = ('(', attr('inside', Expression), ')')
Expression.grammar = (
    attr('inside', [Powerable, re.compile(r'-?[1-9][0-9]*')]),
    attr('rhs', maybe_some(OpExpr)),
)
class Assignment(ExprBase):
    """A single ``output = expression`` or ``output ← expression`` line.

    It is always rendered with ``=``, whichever arrow was parsed.
    """
    grammar = (
        attr('output', word), blank, ['=', '←'], blank,
        attr('expr', Expression), endl,
    )

    def __str__(self):
        return str(self.output) + ' = ' + str(self.expr)
# Binary ops whose operands may be swapped when forming an in-place op.
_COMMUTATIVE_OPS = ('add', 'mul')


def _rename_all(row, old, new):
    """Return a copy of *row* with *old* replaced by *new* as output and args."""
    head = new if row[0] == old else row[0]
    return [head, row[1]] + [new if arg == old else arg for arg in row[2:]]


def _to_ssa(stmts, tmpvars):
    """Give every non-final write of a variable a fresh ``tmpN`` name.

    Reads are redirected to whichever name currently holds the variable's
    value, including reads inside the statement that overwrites it (the
    ``t3`` operand of ``t3 ← t3 · t4`` is the previous ``t3``).  The last
    write of each variable keeps its original name, so outputs survive.
    """
    last_write = {row[0]: i for i, row in enumerate(stmts)}
    live = {}  # variable -> name currently holding its value
    result = []
    for i, row in enumerate(stmts):
        row = list(row)
        row[2:] = [live.get(arg, arg) for arg in row[2:]]
        var = row[0]
        if i == last_write[var]:
            live.pop(var, None)
        else:
            fresh = f'tmp{len(tmpvars)}'
            tmpvars.append(fresh)
            live[var] = fresh
            row[0] = fresh
        result.append(row)
    return result


def _dedupe_expressions(stmts):
    """Turn a repeated ``(op, *args)`` computation into a copy of the first.

    This is valid once the program is in SSA form: every name is written at
    most once, so equal ``(op, *args)`` tuples always yield equal values.
    """
    first = {}
    result = []
    for row in stmts:
        if row[1] != '=':
            key = tuple(row[1:])
            if key in first:
                row = [row[0], '=', first[key]]
            else:
                first[key] = row[0]
        result.append(row)
    return result


def _propagate_copies(stmts, outputs):
    """Delete ``dst = src`` copies by renaming *src* to *dst* everywhere.

    Only done when *src* is not an output and has exactly one definition,
    located before the copy; *dst* is written only by the copy; and *dst* is
    not read between that definition and the copy.  A read there could only
    be of an input named *dst*, which the rename would clobber.  Every read
    of *src* -- before or after the copy -- is renamed.  Copies are handled
    one per pass, starting from the end of the program.
    """
    while True:
        writes = defaultdict(list)
        for i, row in enumerate(stmts):
            writes[row[0]].append(i)
        for copy_idx in range(len(stmts) - 1, -1, -1):
            row = stmts[copy_idx]
            if row[1] != '=':
                continue
            dst, src = row[0], row[2]
            if src in outputs or len(writes.get(src, ())) != 1 or len(writes[dst]) != 1:
                continue
            def_idx = writes[src][0]
            if def_idx >= copy_idx:
                continue
            if any(dst in stmts[j][2:] for j in range(def_idx + 1, copy_idx)):
                continue
            break
        else:
            return stmts
        stmts = [_rename_all(r, src, dst) for j, r in enumerate(stmts) if j != copy_idx]


def _rename_from(stmts, start, old, new):
    """Return *stmts* with statement *start* writing *new* instead of *old*.

    Later reads of *old* are redirected to *new* up to and including the
    next statement that writes *old* again; after that, *old* names the
    newer value and is left alone.
    """
    result = [list(row) for row in stmts]
    result[start][0] = new
    for row in result[start + 1:]:
        row[2:] = [new if arg == old else arg for arg in row[2:]]
        if row[0] == old:
            break
    return result


def _reuse_variables(stmts, outputs):
    """Cut the number of distinct names by writing results into dead variables.

    A statement may write into a variable *v* when *v* has no reads and no
    writes after it.  Inputs (never written) and outputs are never chosen
    as *v*.  Statements writing an output, or with a ``_``-prefixed op such as
    ``_invert_const``, keep their names.  One rename is tried per pass.  If a
    rename does not reduce the number of distinct written names, it is
    discarded and that statement is not tried again, which keeps the loop
    from shuffling names forever.
    """
    written = {row[0] for row in stmts}
    ignored = set()
    while True:
        # Last index reading each reusable variable (dict keeps first-read order)
        last_read = {}
        for i, row in enumerate(stmts):
            for var in row[2:]:
                if var in written and var not in outputs:
                    last_read[var] = i
        last_write = {row[0]: i for i, row in enumerate(stmts)}
        target = None
        for i, row in enumerate(stmts):
            if i in ignored or row[1].startswith('_') or row[0] in outputs:
                continue
            free = [v for v, r in last_read.items()
                    if r <= i and last_write.get(v, -1) <= i]
            if free:
                target = (i, row[0], free[0])
                break
        if target is None:
            return stmts
        i, old, new = target
        candidate = _rename_from(stmts, i, old, new)
        if len({r[0] for r in candidate}) < len({r[0] for r in stmts}):
            stmts = candidate
        else:
            ignored.add(i)


def _mark_inplace(stmts):
    """Rewrite statements that overwrite one of their own operands as ``*_inplace``.

    ``(x, op, x, y)`` becomes ``(x, op_inplace, y)``.  For commutative ops,
    ``(x, op, y, x)`` also becomes ``(x, op_inplace, y)``.  ``(x, op, x)``
    becomes ``(x, op_inplace)``.  Non-commutative ops with *x* as the second
    operand (``x = y - x``) are left unchanged, as are ``=`` copies.
    """
    result = []
    for row in stmts:
        row = list(row)
        op = row[1]
        if op != '=':
            if len(row) == 4:
                if row[2] == row[0]:
                    row = [row[0], op + '_inplace', row[3]]
                elif row[3] == row[0] and op in _COMMUTATIVE_OPS:
                    row = [row[0], op + '_inplace', row[2]]
            elif len(row) == 3 and row[2] == row[0]:
                row = [row[0], op + '_inplace']
        result.append(row)
    return result


def postprocess(stmts, outputs, tmpvars):
    """Optimise a flat statement list such as the one produced by ``analyze``.

    Each statement is a sequence ``(output, op, *args)``; ``op == '='`` marks
    a plain copy.  The passes, in order:

    1. SSA conversion: every non-final write of a variable is renamed to a
       fresh ``tmpN`` (appended to *tmpvars*).
    2. Expression deduplication: repeated ``(op, *args)`` become copies.
    3. Copy propagation: removable copies are deleted by renaming.
    4. Variable reuse: results are written into variables that are dead.
    5. In-place marking: ``x = x op y`` becomes ``(x, 'op_inplace', y)``.

    Args:
        stmts: sequence of statements; it and its rows are not modified.
        outputs: variable names that must keep their names and values.
        tmpvars: list of temporary names already in use; extended in place.

    Returns:
        A new list of statements, each a list.
    """
    stmts = _to_ssa(stmts, tmpvars)
    stmts = _dedupe_expressions(stmts)
    stmts = _propagate_copies(stmts, outputs)
    stmts = _reuse_variables(stmts, outputs)
    return _mark_inplace(stmts)
def analyze(lines, outputs, *, verbose=False):
    """Lower parsed ``Assignment`` lines to flat statements, optimise and print them.

    Every sub-expression is flattened into a statement ``(output, op, *args)``
    writing a fresh ``tmpN`` temporary, e.g. ``['tmp3', 'mul', 'X1', 'Z2']``.
    Binary operators follow the usual precedence: ``^`` binds tightest, then
    ``*``/``·``/``/``, then ``+``/``-``/``−``.  Operators of equal precedence
    associate to the left.  An integer literal operand selects a specialised
    form (``mul_<power of two>``, ``add_const``, ``sub_const``, ``mul_const``,
    ``inverse``, ``_invert_const``) only where operand order is preserved.
    The flat program is then run through ``postprocess`` and printed, one
    numbered statement per line.

    Args:
        lines: parsed ``Assignment`` objects, in program order.
        outputs: names of the result variables (forwarded to ``postprocess``).
        verbose: if true, also print the statements before ``postprocess``.

    Returns:
        The optimised statement list, which is also printed.

    Raises:
        RuntimeError: if an element of *lines* is not an ``Assignment``.
    """
    print('---')
    processed = []
    tmpvars = []
    ops = {
        '·': 'mul',
        '*': 'mul',
        '-': 'sub',
        '−': 'sub',  # Allows for easier copy-pasta from Latex PDFs
        '+': 'add',
        '/': 'div',
    }
    # Operators that bind tighter than 'add'/'sub' when folding an Expression
    multiplicative = ('mul', 'div')

    def fresh(prefix='tmp'):
        # Numbered by the running count, so names never repeat
        name = f'{prefix}{len(tmpvars)}'
        tmpvars.append(name)
        return name

    def emit(op, *args):
        out = fresh()
        processed.append([out, op, *args])
        return out

    def binop(op, lhs, rhs):
        """Emit ``lhs <op> rhs``; return the name holding the result."""
        lhs, rhs = (int(v) if isinstance(v, str) and is_num(v) else v
                    for v in (lhs, rhs))
        lhs_const = isinstance(lhs, int)
        rhs_const = isinstance(rhs, int)
        if lhs_const == rhs_const:
            # No literal (or two literals): nothing to specialise
            return emit(op, lhs, rhs)
        const, var = (lhs, rhs) if lhs_const else (rhs, lhs)
        if op == 'div':
            if rhs_const:
                # var / const  ->  var * (1/const); var / 1 is just var
                if const == 1:
                    return var
                out = fresh()
                inv = fresh('_const')
                processed.append([inv, '_invert_const', const])
                processed.append([out, 'mul_const', var, inv])
                return out
            # const / var  ->  const * inverse(var)
            inv = emit('inverse', var)
            return inv if const == 1 else binop('mul', const, inv)
        if op == 'sub' and lhs_const:
            # sub_const(var, const) would mean var - const; keep ordered form
            return emit(op, lhs, rhs)
        if op == 'mul' and const > 0 and const & (const - 1) == 0:
            # Exact power of 2, allows optimisation of doubling-inplace
            return emit(f'mul_{const}', var)
        return emit(f'{op}_const', var, const)

    def raise_power(base, exponent):
        if exponent == '2':
            return emit('square', base)
        if exponent == '3':
            return emit('mul', emit('square', base), base)
        return emit('power', base, int(exponent))

    def iter_expr(expr):
        """Emit statements for *expr*; return the name/literal holding its value."""
        if not isinstance(expr, ExprBase):
            # Strings (names and integer literals) pass through unchanged
            return expr
        if isinstance(expr, Powerable):
            base = iter_expr(expr.inside)
            if expr.power is None:
                # Powerable without a power just acts like a normal group
                return base
            return raise_power(base, expr.power)
        # Expression or Group: a leading operand then zero or more OpExpr
        # continuations (a Group has none).  'term' accumulates the current
        # run of multiplicative ops; 'total' the additive ops folded so far.
        total, total_op = None, None
        term = iter_expr(expr.inside)
        for item in getattr(expr, 'rhs', None) or ():
            op = ops[item.op]
            if op in multiplicative:
                term = binop(op, term, iter_expr(item.rhs))
            else:
                total = term if total is None else binop(total_op, total, term)
                total_op, term = op, iter_expr(item.rhs)
        return term if total is None else binop(total_op, total, term)

    for line in lines:
        if not isinstance(line, Assignment):
            raise RuntimeError("Unknown type")
        rhs = iter_expr(line.expr)
        if rhs == '1':
            processed.append((line.output, 'set_one'))
        elif rhs == '0':
            processed.append((line.output, 'set_zero'))
        else:
            processed.append((line.output, '=', rhs))
    print('-----')
    if verbose:
        # Display before postprocessing
        for i, row in enumerate(processed):
            print(i, row)
        print('-----')
    processed = postprocess(processed, outputs, tmpvars)
    for i, row in enumerate(processed):
        print(i, row)
    return processed
if __name__ == "__main__":
    import sys

    # Sample programs.  Only one is analysed per run: pass its index as the
    # first command-line argument; by default the last one is used.
    examples = [
        """
        A = (X2-X1)^2
        B = X1*A
        C = X2*A
        D = (Y2-Y1)^2
        X3 = D-B-C
        Y3 = (Y2-Y1)*(B-X3)-Y1*(C-B)
        Z3 = Z1*(X2-X1)
        """,
        """
        A = 1/Z1
        AA = A^2
        X3 = X1*AA
        Y3 = Y1*AA*A
        Z3 = 1
        """,
        """
        XX = X1^2
        YY = Y1^2
        ZZ = Z1^2
        YYYY = YY^2
        M = 3*XX+a*ZZ^2
        MM = M^2
        E = 6*((X1+YY)^2-XX-YYYY)-MM
        EE = E^2
        T = 16*YYYY
        U = (M+E)^2-MM-EE-T
        X3 = 4*(X1*EE-4*YY*U)
        Y3 = 8*Y1*(U*(T-U)-E*EE)
        Z3 = (Z1+E)^2-ZZ-EE
        """,
        # 'M^2' was previously written 'M2', which parses as an unknown input
        """
        S = 4*X1*Y1^2
        M = 3*X1^2+a*Z1^4
        T = M^2-2*S
        X3 = T
        Y3 = M*(S-T)-8*Y1^4
        Z3 = 2*Y1*Z1
        """,
        """
        t0 ← X1 · X2
        t1 ← Y1 · Y2
        t2 ← Z1 · Z2
        t3 ← X1 + Y1
        t4 ← X2 + Y2
        t3 ← t3 · t4
        t4 ← t0 + t1
        t3 ← t3 − t4
        t4 ← X1 + Z1
        t5 ← X2 + Z2
        t4 ← t4 · t5
        t5 ← t0 + t2
        t4 ← t4 − t5
        t5 ← Y1 + Z1
        X3 ← Y2 + Z2
        t5 ← t5 · X3
        X3 ← t1 + t2
        t5 ← t5 − X3
        Z3 ← a · t4
        X3 ← b3 · t2
        Z3 ← X3 + Z3
        X3 ← t1 − Z3
        Z3 ← t1 + Z3
        Y3 ← X3 · Z3
        t1 ← t0 + t0
        t1 ← t1 + t0
        t2 ← a · t2
        t4 ← b3 · t4
        t1 ← t1 + t2
        t2 ← t0 − t2
        t2 ← a · t2
        t4 ← t4 + t2
        t0 ← t1 · t4
        Y3 ← Y3 + t0
        t0 ← t5 · t4
        X3 ← t3 · X3
        X3 ← X3 − t0
        t0 ← t3 · t1
        Z3 ← t5 · Z3
        Z3 ← Z3 + t0""",
        """
        U1 = X1*Z2^2
        U2 = X2*Z1^2
        S1 = Y1*Z2^3
        S2 = Y2*Z1^3
        P = U2-U1
        R = S2-S1
        X3 = R^2-(U1+U2)*P^2
        Y3 = (R*(-2*R^2+3*P^2*(U1+U2))-P^3*(S1+S2))/2
        Z3 = Z1*Z2*P
        """,
        """
        A = Z1^2
        B = Z2^2
        C = (Z1+Z2)^2-A-B
        D = X1*Z2
        E = X2*Z1
        F = Y1*B
        G = Y2*A
        H = D-E
        I = 2*(F-G)
        II = I^2
        J = C*H
        K = 4*J*H
        X3 = 2*II-(D+E)*K
        JJ = J^2
        Y3 = ((J+I)^2-JJ-II)*(D*K-X3)-F*K^2
        Z3 = 2*JJ
        """,
    ]
    code = examples[int(sys.argv[1])] if len(sys.argv) > 1 else examples[-1]

    stmts = []
    for line in code.split("\n"):
        line = line.strip()
        # Skip blank lines and '#' comment lines
        if not line or line.startswith('#'):
            continue
        result = parse(line, Assignment)
        stmts.append(result)
        print(result)
    analyze(stmts, ['X3', 'Y3', 'Z3'])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment