Skip to content

Instantly share code, notes, and snippets.

@bwesterb
Created October 24, 2020 13:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bwesterb/6d7b584c26cc1b716cd3cb18e70b7cde to your computer and use it in GitHub Desktop.
Save bwesterb/6d7b584c26cc1b716cd3cb18e70b7cde to your computer and use it in GitHub Desktop.
import cairo
import math
import sys
class State:
    """Draw a butterfly / swap network over 256 slots into an SVG file.

    Each horizontal row of the drawing is one slot (0..255). ``idxs[p]``
    records which original coefficient index currently sits in slot ``p``.
    Swap stages permute ``idxs``; butterfly stages are checked against it.
    """

    def __init__(self, name, groupSize=1024):
        """Create ``<name>.svg`` and reset the layout state.

        :param name: base name of the SVG file (".svg" is appended); also
            used as the prefix of error messages.
        :param groupSize: rows whose index is a multiple of ``groupSize`` get
            a red grid line instead of a grey one (presumably to mark vector
            register boundaries -- confirm with the callers' values 8 / 16).
        """
        self.surface = cairo.SVGSurface(name + ".svg", 1040, 1040)
        self.name = name
        self.ctx = cairo.Context(self.surface)
        # One row is 4 SVG units; the 2-row margin keeps row 0 and the level
        # captions (drawn at y = -0.5) on the canvas.
        self.ctx.scale(4, 4)
        self.ctx.translate(2, 2)
        self.x = 0.                   # current horizontal drawing position
        self.groupSize = groupSize
        self.idxs = list(range(256))  # slot -> original coefficient index

    def finish(self):
        """Finish the cairo surface so the SVG file is completely written."""
        self.surface.finish()

    @staticmethod
    def _shuffle_pairs(idx, pairs):
        """Return the slot swaps performed by ``shuffle(idx, pairs)``.

        Slots are grouped into 16-slot blocks. Each block is cut into runs of
        ``w = 2**idx`` slots. For every block pair ``(a, b)``, the j-th odd
        run of block ``a`` is swapped element-wise with the j-th even run of
        block ``b``. Each swap is returned as ``(low, high)``.

        :raises ValueError: if ``idx`` is not 0, 1, 2 or 3.
        """
        if idx not in (0, 1, 2, 3):
            raise ValueError("shuffle idx must be 0..3, got %r" % (idx,))
        width = 1 << idx
        bfs = []
        for a, b in pairs:
            for j in range(8 // width):
                base = 2 * width * j
                for i in range(width):
                    bfs.append((a * 16 + width + i + base, b * 16 + i + base))
        return [(min(p, q), max(p, q)) for p, q in bfs]

    def shuffle(self, idx, pairs):
        """Apply and draw the interleaving swap stage described in
        ``_shuffle_pairs``.

        :param idx: run-width exponent, 0..3.
        :param pairs: ``(a, b)`` pairs of 16-slot block numbers.
        :raises ValueError: if ``idx`` is out of range. This check used to be
            an ``assert``, which is silently skipped under ``python -O``.
        """
        self.swaps(self._shuffle_pairs(idx, pairs))

    @staticmethod
    def _bitflip_pairs(idx):
        """Return the swaps ``(i, j)`` with ``i < j`` that exchange bit
        ``idx`` and bit 4 of the slot number.

        :raises ValueError: if ``idx`` is not in 0..4.
        """
        if not 0 <= idx <= 4:
            raise ValueError("bitflip idx must be 0..4, got %r" % (idx,))
        bfs = []
        for i in range(256):
            lo = (i >> idx) & 1
            hi = (i >> 4) & 1
            j = (i & ~((1 << idx) | (1 << 4))) | (lo << 4) | (hi << idx)
            if i < j:  # record each exchange once
                bfs.append((i, j))
        return bfs

    def bitflip(self, idx):
        """Apply and draw the permutation that swaps bit ``idx`` with bit 4
        of every slot number.

        :raises ValueError: if ``idx`` is not in 0..4.
        """
        self.swaps(self._bitflip_pairs(idx))

    def butterflyX4(self, pairs, level):
        """Draw butterflies joining every slot of 16-slot block ``a`` with the
        same position in block ``b``, for each ``(a, b)`` in ``pairs``."""
        bfs = [(16 * a + i, 16 * b + i) for a, b in pairs for i in range(16)]
        self.butterflies(bfs, level)

    def butterflies(self, bfs, level):
        """Draw a butterfly stage captioned "level N" and verify its pairs.

        :param bfs: sequence of ``(a, b)`` slot pairs.
        :param level: butterfly level (1 pairs slots 128 apart in ``ref``).
        """
        self.ctx.set_font_size(1)
        self.ctx.set_source_rgb(0, 0, 0)
        self.ctx.move_to(self.x, -0.5)
        self.ctx.show_text("level %s" % level)
        self.drawLines(bfs, False, level)

    def swaps(self, bfs):
        """Draw a swap stage, apply it to ``idxs`` and print the resulting
        slot contents as 8-bit binary labels."""
        self.drawLines(bfs, True)
        # Swaps are applied in list order, so a slot may move more than once.
        for i, j in bfs:
            self.idxs[i], self.idxs[j] = self.idxs[j], self.idxs[i]
        self.drawGridLines(self.x, self.x + 2.0)
        self.x -= 1
        self.ctx.set_font_size(.5)
        self.ctx.set_source_rgb(.5, .5, 1.)
        for row, original in enumerate(self.idxs):
            self.ctx.move_to(self.x, row - .15)
            self.ctx.show_text(format(original, "08b"))
        self.x += 2.5

    def drawGridLines(self, oldX, newX):
        """Draw one horizontal line per slot from ``oldX`` to ``newX``.

        The line is red when the row index is a multiple of ``groupSize``
        and grey otherwise.
        """
        for i in range(256):
            if i % self.groupSize == 0:
                self.ctx.set_source_rgb(.7, 0, 0)
            else:
                self.ctx.set_source_rgb(.7, .7, .7)
            self.ctx.move_to(oldX, i)
            self.ctx.line_to(newX, i)
            self.ctx.stroke()

    def drawLines(self, bfs, swap=False, level=None):
        """Draw one stage of vertical connectors, one per slot pair.

        Butterfly connectors get dots at both ends and zeta labels. Swap
        connectors get crosses. For butterfly stages every pair is then
        checked against ``idxs``: the two coefficients must differ only in
        bit ``8 - level``, with the set bit in slot ``b``. On a mismatch the
        process exits with status 1.

        :param bfs: sequence of ``(a, b)`` slot pairs (iterated twice).
        :param swap: draw swap markers instead of butterfly dots and labels,
            and skip the check.
        :param level: butterfly level; required when ``swap`` is false.
        :raises ValueError: if a butterfly stage is drawn without a level.
        """
        if not swap and level is None:
            raise ValueError("drawLines needs a level for butterfly stages")
        self.ctx.set_line_width(0.1)
        # Greedy layout: each connector goes one column (half a unit) past
        # the highest column already used by any row it spans. This keeps
        # overlapping connectors from being drawn on top of each other.
        occupied = [0] * 256
        actions = []
        for a, b in bfs:
            a, b = min(a, b), max(a, b)
            placement = max(occupied[a:b + 1])
            for v in range(a, b + 1):
                occupied[v] = placement + 1
            actions.append((self.x + placement / 2, a, b))
        oldX = self.x
        self.x += (max(occupied) / 2 + 1)
        self.drawGridLines(oldX, self.x)
        self.ctx.set_source_rgb(0, 0, 0)
        for x, a, b in actions:
            self.ctx.move_to(x, a)
            self.ctx.line_to(x, b)
            self.ctx.stroke()
            if swap:
                # An "x" marker at both ends.
                for y in (a, b):
                    for dx in (-0.3, 0.3):
                        self.ctx.move_to(x + dx, y - 0.3)
                        self.ctx.line_to(x - dx, y + 0.3)
                        self.ctx.stroke()
            else:
                # A filled dot at both ends.
                for y in (a, b):
                    self.ctx.arc(x, y, 0.2, 0, 2 * math.pi)
                    self.ctx.fill()
        prevZeta = None
        actions.sort(key=lambda action: action[1])
        self.ctx.set_font_size(0.5)
        if not swap:
            for x, a, b in actions:
                # Label = 2**(level-1) + the top (level-1) bits of the
                # original index in slot a. Consecutive equal labels (in
                # slot order) are printed only once.
                zeta = (self.idxs[a] >> (9 - level)) + (1 << (level - 1))
                if zeta != prevZeta:
                    prevZeta = zeta
                    self.ctx.move_to(x - 0.25, (a + b) / 2 + 0.3)
                    self.ctx.text_path(str(zeta))
                    # White halo first, then black fill, for legibility.
                    self.ctx.set_source_rgb(1, 1, 1)
                    self.ctx.set_line_width(0.2)
                    self.ctx.stroke_preserve()
                    self.ctx.set_source_rgb(0, 0, 0)
                    self.ctx.set_line_width(0.1)
                    self.ctx.fill()
            bit = 1 << (8 - level)
            for a, b in bfs:
                a2 = self.idxs[a]
                b2 = self.idxs[b]
                if b2 != (a2 & (255 ^ bit)) | bit:
                    print("%s: Wrong butterfly on level %s: CT(%s, %s)" % (
                        self.name, level, format(a2, "08b"), format(b2, "08b")))
                    # Non-zero status so callers/CI see the failure; a bare
                    # sys.exit() reported success.
                    sys.exit(1)
        print()
def ref():
    """Draw the plain 8-level butterfly network to ref.svg.

    Level ``n`` pairs every slot with the one ``256 >> n`` slots further down,
    walking the blocks of size ``2 * distance`` from top to bottom.
    """
    s = State("ref")
    for level in range(1, 9):
        distance = 256 >> level
        bfs = [(start + k, start + k + distance)
               for start in range(0, 256, 2 * distance)
               for k in range(distance)]
        s.butterflies(bfs, level)
    s.finish()
def dilavx2():
    """Draw the dilavx2 schedule to dilavx2.svg.

    Levels 1-6 are plain butterflies (distance 128 down to 4). Levels 7 and 8
    act on 8-slot groups and are surrounded by in-group swap stages.
    """
    s = State("dilavx2", 8)
    for level in range(1, 7):
        distance = 256 >> level
        s.butterflies([(start + k, start + k + distance)
                       for start in range(0, 256, 2 * distance)
                       for k in range(distance)], level)

    groups = range(0, 256, 8)

    def in_groups(offsets):
        # Apply the in-group offset pairs to every 8-slot group, group-major.
        return [(g + p, g + q) for g in groups for p, q in offsets]

    half_apart = tuple((k, k + 4) for k in range(4))
    s.swaps(in_groups(((2, 4), (3, 5))))
    s.butterflies(in_groups(half_apart), 7)
    s.swaps(in_groups(((1, 4), (3, 6))))
    s.butterflies(in_groups(half_apart), 8)
    # Order matters: slots 3 and 4 are touched by more than one swap.
    s.swaps(in_groups(((3, 6), (1, 4), (3, 5), (2, 4))))
    s.finish()
def kybavx2():
    """Draw the 7-level kybavx2 schedule to kybavx2.svg.

    Levels 1-3 are plain butterflies. From level 4 on, the 16-slot blocks are
    rearranged with ``shuffle`` before each butterfly stage.
    """
    s = State("kybavx2", 16)

    def strided(distance):
        # (start + i, start + i + distance), ordered by i first and then by
        # block start; the drawing layout depends on this order.
        return [(start + i, start + i + distance)
                for i in range(distance)
                for start in range(0, 256, 2 * distance)]

    for level in (1, 2, 3):
        s.butterflies(strided(256 >> level), level)

    # Pairs of 16-slot block numbers used by shuffle / butterflyX4.
    two_apart = [(q + r, q + r + 2) for q in range(0, 16, 4) for r in (0, 1)]
    neighbours = [(k, k + 1) for k in range(0, 16, 2)]

    s.shuffle(3, two_apart)
    s.butterflies(strided(16), 4)
    s.shuffle(2, neighbours)
    s.butterflyX4(two_apart, 5)
    s.shuffle(1, two_apart)
    s.butterflyX4(neighbours, 6)
    s.shuffle(0, neighbours)
    s.butterflyX4(two_apart, 7)
    s.finish()
# Render every diagram, in the same order as before.
for _render in (ref, dilavx2, kybavx2):
    _render()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment