@zwegner
Created July 8, 2020 21:03
Assembly generator script in Python, for making a fast wc -w
import collections
import contextlib
import sys
# Register class, for GPRs and vector registers
_Reg = collections.namedtuple('Reg', 't v')
class Reg(_Reg):
    def __str__(self):
        names = ['ax', 'cx', 'dx', 'bx', 'sp', 'bp', 'si', 'di']
        if self.t == 'r' and self.v < 8:
            return 'r' + names[self.v]
        return '%s%s' % (self.t, self.v)
GPR = lambda i: Reg('r', i)
XMM = lambda i: Reg('xmm', i)
YMM = lambda i: Reg('ymm', i)
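# For example, str(GPR(1)) gives 'rcx', str(GPR(8)) gives 'r8', and str(YMM(6)) gives 'ymm6'.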
INST_ID = 0
INSTS = []
# Instruction creation. This creates a unique ID for hashability/dependency tracking, and
# appends the instruction to a global list
class Inst:
    def __init__(self, mnem, *args, depend=None):
        global INST_ID
        INST_ID += 1
        self.id = INST_ID
        self.mnem = mnem
        self.args = args
        self.deps = {depend} if depend is not None else set()
        INSTS.append(self)
    def __eq__(self, other):
        return self.id == other.id
    def __hash__(self):
        return self.id
    def __str__(self):
        return self.mnem + ' ' + ', '.join(['%#x' % a if isinstance(a, int) else str(a) for a in self.args])
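# For example, str(Inst('vpaddb', YMM(0), YMM(1), YMM(2))) is 'vpaddb ymm0, ymm1, ymm2';
# integer operands are rendered as hex immediates.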
@contextlib.contextmanager
def capture():
    global INSTS
    old_insts = INSTS
    INSTS = []
    yield INSTS
    INSTS = old_insts
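# Usage: "with capture() as insts:" redirects instruction emission into a fresh list
# (bound to insts) for the duration of the block, then restores the previous list.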
# Create instruction shortcuts
for mnem in ['prefetcht0', 'prefetcht1', 'prefetcht2', 'prefetchnta',
        'vmovdqu', 'vpaddb', 'vpaddq', 'vpalignr', 'vpandn',
        'vpbroadcastb', 'vpcmpeqb', 'vpcmpgtb', 'vperm2i128', 'vpor', 'vpxor', 'vpsadbw', 'vpsubb',
        'vmovq', 'mov', 'lea']:
    def bind(mnem):
        globals()[mnem] = lambda *a, **k: Inst(mnem, *a, **k)
    bind(mnem)
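# After this, e.g. vpaddb(dst, a, b) is shorthand for Inst('vpaddb', dst, a, b); the bind()
# wrapper makes each lambda capture its own mnemonic rather than the loop variable.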
def vzero(reg):
    vpxor(reg, reg, reg)
# Pseudo-nop: this inserts a nop that affects scheduling, by yielding the instruction slot
# in the block-interleaving scheduler we use. This isn't actually emitted as an instruction.
def pseudo_nop():
    INSTS.append(None)
# Knobs
use_sub = 1
use_index = 1
use_next_base = 1
scale = 7
#next_base_offset = 0x1000
prefetch_interleave = [2, 2]
prefetch_32 = [0, 0]
prefetch_offset = [30*32, 6*32]
prefetch = [prefetchnta, prefetcht0]
prefetch_reverse = 0
prefetch_len = 2
next_base_offset = prefetch_offset[0]
pf_params = list(zip(prefetch_interleave, prefetch, prefetch_offset, prefetch_32))
if prefetch_reverse:
pf_params = pf_params[::-1]
pf_params = pf_params[:prefetch_len]
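# How prefetch_interleave is used below: 0 = emit all of that tier's prefetches up front at
# the top of the loop, 1 = emit one inside each unrolled block, 2 = spread them evenly
# through the scheduled loop body.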
unroll = 8
lockstep = 1
schedule_break = lambda b, i, t: not t & 1
schedule_break = lambda b, i, t: b & 1 and not t & 1
schedule_break = lambda b, i, t: 0
#schedule_break = lambda b, i, t: b & 1
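# schedule_break(b, i, t) is consulted by the scheduler after each instruction slot it fills,
# with b = block index, i = instructions taken from that block this pass, t = total emitted
# so far; returning true restarts the round-robin pass. The last uncommented assignment wins.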
# Registers
# Constants
zero = YMM(1)
c0 = YMM(2)
c1 = YMM(3)
c2 = YMM(4)
# Total, subtotal
total = YMM(0)
subtotal = YMM(5)
# Last iteration's input as a vector register. This is carried from each unrolled block to the next
phi_last = last = YMM(6)
last_dep = None
# GPRs to make indexing use fewer encoding bytes
index = GPR(1) if use_index else None
next_base = GPR(3) if use_next_base else None
blocks = []
# Create a pointer using scale/index/base/displacement, as controlled by various knobs
def ptr(offset):
    base = 'rdi'
    if use_next_base and offset > abs(offset - next_base_offset):
        base = next_base
        offset -= next_base_offset
    d = offset >> scale
    if not use_index or d <= 0:
        return '[%s%+#03x]' % (base, offset)
    # Round d down to the largest power-of-two SIB scale (1/2/4/8) that fits
    d = max(s for s in [1, 2, 4, 8] if s <= d)
    return '[%s+%s*%s%+#03x]' % (base, d, index, offset - d*(1<<scale))
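# With the default knobs above (scale=7, so rcx holds 0x80 and rbx points 0x3c0 past rdi),
# this produces forms like ptr(0x40) -> '[rdi+0x40]', ptr(0x100) -> '[rdi+2*rcx+0x0]', and
# ptr(0x3c0) -> '[rbx+0x0]' (examples worked out by hand from the code above).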
# Base dependencies--so different blocks that use the same registers are sequenced
base_deps = {}
# Create blocks of instructions, one for each iteration of the (rolled) loop.
# We collect each block of instructions, and schedule them later.
# One loop kernel takes two registers (marginally), and since we have ~7 registers
# worth of constants/counters outside the main loop, we can interleave four copies
# of the loop in the remaining 9 registers we get in AVX2.
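# With unroll=8, blocks i and i+4 reuse the same register pair (ymm8..ymm15), and base_deps
# (keyed by register number) makes the reusing block wait on the pair's previous writer.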
for i in range(unroll):
    # Input pointer offset
    offset = 0x20 * i
    # Our two registers for this block, with some aliases
    reg = 8+(i % 4)*2
    d = YMM(reg)
    e = mask = shifted = YMM(reg+1)
    # First sub-iteration: compute directly into the subtotal
    if i == 0 and not use_sub:
        mask = subtotal
    # Insert prefetches for both cache tiers (if they're interleaved for that level)
    for [pf_interleave, pf, pf_offset, pf_32] in pf_params:
        if pf_interleave == 1:
            # XXX inserting extra prefetches can help? This does two per cache line.
            # Maybe helps the scheduler? i.e. the cpu's scheduler. Even the
            # pseudo-nop inserted here otherwise, to keep our scheduler in sync, doesn't help
            if pf_32 or not i & 1:
                pf('BYTE PTR %s' % ptr(offset + pf_offset),
                        depend=base_deps.get(reg) if 1 else None)
            elif lockstep:
                pseudo_nop()
    # Loop kernel
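    # The kernel below builds a whitespace mask: vpaddb/vpcmpgtb flag bytes in 0x09..0x0d
    # (tab..CR, via the c0=0x72/c1=0x7a add-then-signed-compare trick), vpcmpeqb flags 0x20
    # (space), and vpor merges them into d. vperm2i128/vpalignr then shift the mask by one
    # byte, pulling in the top byte of the previous chunk's mask, so vpandn marks word
    # starts: previous byte whitespace, current byte not. Each word start is 0xff (-1), so
    # vpsubb on the subtotal adds one per word.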
    vmovdqu(d, 'YMMWORD PTR %s' % ptr(offset), depend=base_deps.get(reg))
    vpaddb(e, d, c0)
    vpcmpgtb(e, e, c1)
    vpcmpeqb(d, d, c2)
    vpor(d, e, d)
    last_dep = vperm2i128(shifted, d, last, 0x03, depend=last_dep)
    vpalignr(shifted, d, shifted, 0xf)
    last = d
    base_deps[reg] = vpandn(mask, d, shifted)
    if use_sub:
        base_deps[reg] = vpsubb(subtotal, subtotal, mask)
    elif i > 0:
        base_deps[reg] = vpaddb(subtotal, subtotal, mask)
    elif lockstep:
        pseudo_nop()
    # Last of the group of four blocks: move the compare result into the register
    # the next block is expecting
    if i & 3 == 3:
        base_deps[reg] = vmovdqu(phi_last, last)
        last = phi_last
    elif lockstep:
        pseudo_nop()
    # Grab all current instructions for this block
    blocks.append(INSTS)
    INSTS = []
# Schedule blocks by interleaving their instructions
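# The scheduler round-robins over the blocks: each pass takes the next instruction from every
# block whose dependencies have already been emitted (pseudo-nops just consume a slot), until
# all blocks are drained. schedule_break() can cut a pass short to bias the interleaving.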
def schedule():
    while True:
        block_idx = [0] * len(blocks)
        for [b, block] in enumerate(blocks):
            if block:
                inst = block[0]
                if not inst or all(dep in INSTS for dep in inst.deps):
                    if inst:
                        INSTS.append(inst)
                    block.pop(0)
                    block_idx[b] += 1
                    if schedule_break(b, block_idx[b], len(INSTS)):
                        break
        if not any(blocks):
            break
# Prologue
INSTS = []
vzero(total)
Inst('cmp rdi,rsi')
Inst('jae L2')
vzero(zero)
vzero(last)
if use_index:
    mov(index, 1<<scale)
if use_next_base:
    lea(next_base, '[rdi+0x%x]' % next_base_offset)
# Single byte broadcasted constants--stored as .byte data after the code and broadcast from memory
BYTE_CONSTS = {}
for [i, [y, c]] in enumerate([[c0, 0x72], [c1, 0x7a], [c2, 0x20]]):
    name = 'c%s' % i
    BYTE_CONSTS[name] = c
    vpbroadcastb(y, '[rip+%s]' % name)
Inst('.align 4')
Inst('L1:')
# Loop beginning: set up address registers, subtotal, prefetches
if use_sub:
    vzero(subtotal)
# Insert prefetches for each tier if they're not interleaved
for [pf_interleave, pf, pf_offset, pf_32] in pf_params:
    if not pf_interleave:
        for offset in range(0, 32*unroll, 32 if pf_32 else 64):
            pf('BYTE PTR %s' % ptr(offset + pf_offset))
# Schedule the main loop blocks
with capture() as loop_insts:
    schedule()
# Prefetch interleaving
prefetch_groups = []
for [pf_interleave, pf, pf_offset, pf_32] in pf_params:
    with capture() as prefetches:
        if pf_interleave == 2:
            for offset in range(0, 32*unroll, 32 if pf_32 else 64):
                pf('BYTE PTR %s' % ptr(offset + pf_offset))
    prefetch_groups.append(prefetches)
for prefetches in prefetch_groups:
    for [i, pf] in enumerate(prefetches):
        i = (i * len(loop_insts)) // len(prefetches)
        loop_insts[i:i] = [pf]
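# Each tier's prefetches are spliced into the scheduled loop body at evenly spaced positions,
# rather than bunched together at one point in the instruction stream.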
INSTS.extend(loop_insts)
# Loop end: horizontal sum of the subtotal from bytes to qwords, add into the total
if use_next_base:
    lea(next_base, '[%s+0x%x]' % (next_base, unroll*32))
if not use_sub:
    vpsubb(subtotal, zero, subtotal)
vpsadbw(subtotal, zero, subtotal)
vpaddq(total, subtotal, total)
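# vpsadbw against the zero register sums each 8-byte group of the subtotal into a qword,
# so the per-byte counts become four qword lanes that vpaddq folds into the running total.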
# Output assembly
with open(sys.argv[1], 'w') as f:
    f.write('''
.intel_syntax noprefix
.global _tokenize_zasm
_tokenize_zasm:
    push rbp
    mov rbp,rsp
''')
    for inst in INSTS:
        print(' ', inst, file=f)
    f.write('''
    lea rdi,[rdi+{offset}]
    cmp rdi,rsi
    jb L1
L2:
    vextracti128 xmm1,{total},0x1
    vpaddq {total},{total},ymm1
    vpshufd xmm1,xmm0,0x4e
    vpaddq {total},{total},ymm1
    vmovq rax,xmm0
    pop rbp
    vzeroupper
    ret
'''.format(offset=unroll*32, total=total))
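    # The epilogue above advances the input pointer by one unrolled stride (unroll*32 bytes),
    # loops while below the end pointer in rsi, then reduces the four qword lanes of the
    # total (ymm0) down to a single word count returned in rax.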
    for [name, value] in BYTE_CONSTS.items():
        f.write('%s: .byte %s\n' % (name, value))