Created November 11, 2021 11:31
-
-
Save jayfoad/94f4c68fa3a9aa908db79dbd7e9df80d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import collections | |
import subprocess | |
import sys | |
# Enumeration parameters, bound together:
#   args - number of i4 arguments of every generated LLVM function
#          (also fixes the truth-table width: 2**args bits)
#   size - number of instruction-appending rounds, so programs with
#          0..size instructions are generated
args, size = 3, 3
def f_and(x, y):
    """Return the bitwise AND of x and y (evaluates LLVM `and` on truth tables)."""
    conjunction = x & y
    return conjunction
def f_or(x, y):
    """Return the bitwise OR of x and y (evaluates LLVM `or` on truth tables)."""
    disjunction = x | y
    return disjunction
def f_xor(x, y):
    """Return the bitwise XOR of x and y (evaluates LLVM `xor` on truth tables)."""
    difference = x ^ y
    return difference
# Exhaustively enumerate straight-line LLVM IR programs over `args` i4
# arguments built from not/and/or/xor. Group them by the function they
# compute (the truth table of the returned value), and run each group of
# equivalent functions through `opt -S -O3`. Print every group whose
# optimized functions do not all have the same number of lines.
#
# A program is a list of (ir_text, truth_table) pairs. The first `args`
# entries are the function arguments (with empty ir_text). Each later entry
# is one instruction whose result is %<its index in the list>.
#
# A truth table is an integer of 2**args bits, one bit per input assignment.
# Every operation is bitwise, so that single bit describes each of the four
# i4 bit lanes alike.


def _extend_programs(programs, mask, ops):
    """Return every program formed by appending one instruction to a program.

    The appended instruction is either:
    - a NOT (``xor <v>, -1``) of any existing value, or
    - an op from *ops* applied to each pair of distinct existing values,
      larger index first.

    All NOT extensions come first, then the binary ones. This ordering
    determines which equivalent function is emitted as @f0, @f1, ...
    """
    # Every program in one level has the same length, so len(prog) is the
    # number of the value the new instruction defines.
    not_extensions = [
        prog + [(' %%%i = xor i4 %%%i, -1' % (len(prog), k),
                 mask & ~prog[k][1])]
        for prog in programs
        for k in range(len(prog))
    ]
    binary_extensions = [
        prog + [(' %%%i = %s i4 %%%i, %%%i' % (len(prog), name, k, l),
                 fn(prog[k][1], prog[l][1]))]
        for prog in programs
        for k in range(len(prog))
        for l in range(k)  # l < k: each unordered pair once, no x-op-x
        for name, fn in ops
    ]
    return not_extensions + binary_extensions


def _to_llvm_function(index, lines, nargs):
    """Return the IR lines of function @f<index> for one program.

    *lines* holds one string per value. The first *nargs* are empty argument
    placeholders and the rest are instruction lines. The function returns the
    program's last value.
    """
    params = ', '.join('i4 %' + str(j) for j in range(nargs))
    header = 'define i4 @f%i(' % index + params + ') {'
    return ([header, 'bb:']
            + lines[nargs:]
            + [' ret i4 %%%i' % (len(lines) - 1), '}'])


def _run_opt(ir_text):
    """Run ``opt -S -O3`` on *ir_text* and return opt's textual IR output.

    Raises:
        OSError: (e.g. FileNotFoundError) if no ``opt`` executable is found.
        subprocess.CalledProcessError: if opt exits with a non-zero status.
            Without check=True, a failure would yield empty output. That
            would only surface later as an unrelated ValueError.
    """
    proc = subprocess.run(['opt', '-S', '-O3'],
                          input=ir_text,
                          stdout=subprocess.PIPE,
                          encoding=sys.getdefaultencoding(),
                          check=True)
    return proc.stdout


def _split_functions(ir_text):
    """Return the lines of each function definition in *ir_text*.

    The result has one list per function. A definition starts at a line
    beginning with 'define' and ends at the next line beginning with '}'.
    All other lines are ignored.
    """
    functions = []
    inside = False
    for line in ir_text.splitlines():
        if line.startswith('define'):
            inside = True
            functions.append([])
        if inside:
            functions[-1].append(line)
        if line.startswith('}'):
            inside = False
    return functions


def _report(functions):
    """Print a group whose functions do not all have the same line count.

    This marks a bunch of equivalent functions that did not all get
    optimized to the same shortest form. Prints a ';;;;;;;;' separator, then
    the first shortest function, then every longer one. Prints nothing when
    *functions* is empty or all lengths are equal.
    """
    if not functions:
        return
    lens = [len(fn) for fn in functions]
    shortest = min(lens)
    if shortest == max(lens):
        return
    print(';;;;;;;;')
    print('\n'.join(next(fn for fn in functions if len(fn) == shortest)))
    for fn in functions:
        if len(fn) > shortest:
            print('\n'.join(fn))


ops = [('and', f_and), ('or', f_or), ('xor', f_xor)]

# All-ones truth table: one bit for each of the 2**args input assignments.
mask = (1 << (1 << args)) - 1

# Level 0 is a single program that consists of just the arguments.
# Argument i gets the table mask // (2**(2**i) + 1), which repeats 2**i ones
# then 2**i zeros from the low bit up (0x55, 0x33, 0x0f when args == 3).
x = [[[('', mask // ((1 << (1 << i)) + 1)) for i in range(args)]]]
# Level i + 1 holds every program with one more instruction than level i.
for i in range(size):
    x.append(_extend_programs(x[i], mask, ops))

# Group the instruction text of every program by the truth table of its
# last (returned) value, i.e. by the function it computes.
y = collections.defaultdict(list)
for level in x:
    for prog in level:
        y[prog[-1][1]].append([s for s, _ in prog])

for group in y.values():
    if len(group) < 2:
        continue  # a lone function has nothing to be compared against
    ir_text = ''.join(line + '\n'
                      for index, lines in enumerate(group)
                      for line in _to_llvm_function(index, lines, args))
    _report(_split_functions(_run_opt(ir_text)))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment