Skip to content

Instantly share code, notes, and snippets.

@jayfoad
Created November 11, 2021 11:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jayfoad/94f4c68fa3a9aa908db79dbd7e9df80d to your computer and use it in GitHub Desktop.
Save jayfoad/94f4c68fa3a9aa908db79dbd7e9df80d to your computer and use it in GitHub Desktop.
#!/usr/bin/python3
import collections
import subprocess
import sys
# args: number of i4 parameters (%0 .. %{args-1}) every generated function
#       takes; it also fixes the truth-table width at 2**args bits.
# size: number of enumeration rounds; each round appends one instruction,
#       so programs with 0 .. size instructions are generated.
args, size = 3, 3
def f_and(x, y):
    """Return the bitwise AND of x and y.

    Paired with the IR opcode 'and' in `ops`, so applying it to two operands'
    truth-table bitmasks yields the truth table of the `and` instruction.
    """
    combined = x & y
    return combined
def f_or(x, y):
    """Return the bitwise OR of x and y.

    Paired with the IR opcode 'or' in `ops`, so applying it to two operands'
    truth-table bitmasks yields the truth table of the `or` instruction.
    """
    combined = x | y
    return combined
def f_xor(x, y):
    """Return the bitwise XOR of x and y.

    Paired with the IR opcode 'xor' in `ops`, so applying it to two operands'
    truth-table bitmasks yields the truth table of the `xor` instruction.
    """
    combined = x ^ y
    return combined
# Binary IR opcodes to try, each paired with the Python function that
# computes the result's truth table from its operands' truth tables.
ops = [
    ('and', f_and),
    ('or', f_or),
    ('xor', f_xor),
]

# All-ones bitmask with one bit per truth-table row (2**args rows).
mask = (1 << (1 << args)) - 1

# A "program" is a list of (ir_text, truth_table) pairs. The first `args`
# entries stand for the function parameters %0 .. %{args-1}: they carry no IR
# text and their truth tables are the input columns (85, 51, 15 for args == 3).
# x[n] lists every program that has exactly n instructions after those entries.
inputs = [('', mask // ((1 << (1 << v)) + 1)) for v in range(args)]
x = [[inputs]]
for step in range(size):
    dest = args + step  # value number defined by the instruction added this round
    grown = []
    # First, every bitwise "not" (xor with -1) of each existing value...
    for prog in x[step]:
        for src, (_, table) in enumerate(prog):
            grown.append(prog + [(f' %{dest} = xor i4 %{src}, -1', mask & ~table)])
    # ...then every binary op applied to each unordered pair of distinct values.
    for prog in x[step]:
        for hi in range(len(prog)):
            for lo in range(hi):
                for opcode, fn in ops:
                    text = f' %{dest} = {opcode} i4 %{hi}, %{lo}'
                    grown.append(prog + [(text, fn(prog[hi][1], prog[lo][1]))])
    x.append(grown)
# Bucket every enumerated program by the truth table of its last value (the
# value the generated function returns), keeping only each entry's IR text.
# Programs sharing a bucket compute the same bitwise function of the inputs.
y = collections.defaultdict(list)
for level in x:
    for prog in level:
        result_table = prog[-1][1]
        y[result_table].append([text for text, _ in prog])
def _run_opt(ir_text):
    """Optimize an LLVM IR module with `opt -S -O3` and return the output IR.

    The exit status is checked, so a failing opt raises
    subprocess.CalledProcessError. Without the check, opt would hand back
    empty output, and the min() over function lengths below would fail with
    an unrelated error.
    """
    proc = subprocess.run(['opt', '-S', '-O3'], input=ir_text,
                          stdout=subprocess.PIPE, check=True,
                          encoding=sys.getdefaultencoding())
    return proc.stdout


def _split_functions(ir_text):
    """Return each function definition in ir_text as a list of its lines.

    A definition starts at a line beginning with 'define' and ends at the next
    line beginning with '}'; both of those lines are included. All lines
    outside a definition are dropped.
    """
    functions = []
    inside = False
    for line in ir_text.splitlines():
        if line.startswith('define'):
            inside = True
            functions.append([])
        if inside:
            functions[-1].append(line)
        if line.startswith('}'):
            inside = False
    return functions


# Parameter list shared by every generated function: i4 %0, ..., i4 %{args-1}.
signature = ', '.join(f'i4 %{a}' for a in range(args))
for z in y.values():
    if len(z) < 2:
        # Only one program reaches this truth table: nothing to compare.
        continue
    # Emit each equivalent program as its own function @f0, @f1, ...
    # The explicit 'bb:' entry label keeps the block out of the unnamed-value
    # numbering, so the first instruction can be %{args} as enumerated.
    inlines = [[f'define i4 @f{n}({signature}) {{', 'bb:']
               + prog[args:]  # skip the IR-less placeholder entries for the parameters
               + [f' ret i4 %{len(prog) - 1}', '}']
               for n, prog in enumerate(z)]
    instr = ''.join(line + '\n' for func in inlines for line in func)
    outlines = _split_functions(_run_opt(instr))
    lens = [len(func) for func in outlines]
    shortest = min(lens)
    # Only line counts are compared. Equivalent functions that optimize to
    # different bodies of the same length are therefore not reported.
    if shortest != max(lens):
        # We found a bunch of equivalent functions that did not all get
        # optimized to the same size. Print the first shortest one followed
        # by all the longer ones.
        print(';;;;;;;;')
        print('\n'.join(next(func for func in outlines if len(func) == shortest)))
        for func in outlines:
            if len(func) > shortest:
                print('\n'.join(func))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment