@robertknight
Created December 29, 2023 18:20
Annotated AVX-512 gemm kernel
# At entry params are:
#
# tile_ptr (rdi)
# tile_row_stride (rsi)
# a (rdx = ptr, rcx = len)
# b (r8 = ptr, r9 = len)
# depth (stack)
# alpha (xmm0)
# beta (xmm1)
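#
# Roughly equivalent scalar code (a sketch for orientation only; the names and
# the packed layouts of `a` and `b` are inferred from the access pattern in the
# loop below, not taken from the Rust source):
#
#   // MR == 6 output rows, NR == 32 output columns (2 ZMM registers of 16 f32)
#   let mut tmp = [[0.0f32; NR]; MR];
#   for k in 0..depth {
#       for i in 0..MR {
#           for j in 0..NR {
#               tmp[i][j] += a[k * MR + i] * b[k * NR + j];
#           }
#       }
#   }
#   for i in 0..MR {
#       for j in 0..NR {
#           let out = tile_ptr.add(i * tile_row_stride + j);
#           *out = alpha * tmp[i][j] + beta * *out;
#       }
#   }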
.section __TEXT,__text,regular,pure_instructions
.p2align 4, 0x90
wasnn::gemm::kernels::x64::Avx512Kernel::kernel_avx_512:
Lfunc_begin758:
.cfi_startproc
push rbp
.cfi_def_cfa_offset 16
.cfi_offset rbp, -16
mov rbp, rsp
.cfi_def_cfa_register rbp
mov rax, qword ptr [rbp + 16] # Load `depth` into rax
lea r10, [rax + rax] # r10 = depth * 2
lea r10, [r10 + 2*r10] # Set r10 = `depth * MR` (where MR == 6)
cmp r10, rcx # Compare `depth * MR` with `a.len()`
ja LBB758_15 # Panic if `a.len() < depth * MR`
mov rcx, rax # Set rcx = `depth * NR` (where NR == 32, i.e. 2 ZMM registers of 16 f32)
shl rcx, 5
cmp rcx, r9 # Compare `depth * NR` with `b.len()`
ja LBB758_16 # Panic if `b.len() < depth * NR`
test rax, rax # Check if there are any loop iterations
je LBB758_3
add rdx, 20 # Bias `a` so each iteration's 6 elements are at [rdx - 20] .. [rdx]
add r8, 64 # Bias `b` so each iteration's two vectors are at [r8 - 64] and [r8]
# Clear the registers that hold `tmp`. (Zeroing the XMM register also zeroes the
# rest of the ZMM register.) The registers that hold `b_rows` don't need to be
# cleared because they are fully overwritten before use; clearing them would be
# dead stores.
vxorps xmm2, xmm2, xmm2
vxorps xmm3, xmm3, xmm3
vxorps xmm4, xmm4, xmm4
vxorps xmm5, xmm5, xmm5
vxorps xmm6, xmm6, xmm6
vxorps xmm7, xmm7, xmm7
vxorps xmm8, xmm8, xmm8
vxorps xmm9, xmm9, xmm9
vxorps xmm10, xmm10, xmm10
vxorps xmm11, xmm11, xmm11
vxorps xmm12, xmm12, xmm12
vxorps xmm13, xmm13, xmm13
.p2align 4, 0x90
LBB758_8:
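# Top of the depth (k) loop. Each iteration consumes MR (6) elements of `a`
# and NR (32) elements of `b`.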
# Load `b_rows[i]`
vmovups zmm14, zmmword ptr [r8 - 64]
vmovups zmm15, zmmword ptr [r8]
# tmp[i][j] += broadcast(a[i]) * b_rows[j]
vbroadcastss zmm16, dword ptr [rdx - 20]
vfmadd231ps zmm13, zmm16, zmm14
vfmadd231ps zmm12, zmm15, zmm16
vbroadcastss zmm16, dword ptr [rdx - 16]
vfmadd231ps zmm11, zmm16, zmm14
vfmadd231ps zmm10, zmm15, zmm16
vbroadcastss zmm16, dword ptr [rdx - 12]
vfmadd231ps zmm9, zmm16, zmm14
vfmadd231ps zmm8, zmm15, zmm16
vbroadcastss zmm16, dword ptr [rdx - 8]
vfmadd231ps zmm7, zmm16, zmm14
vfmadd231ps zmm6, zmm15, zmm16
vbroadcastss zmm16, dword ptr [rdx - 4]
vfmadd231ps zmm5, zmm16, zmm14
vfmadd231ps zmm4, zmm15, zmm16
vbroadcastss zmm16, dword ptr [rdx]
vfmadd231ps zmm3, zmm16, zmm14
vfmadd231ps zmm2, zmm16, zmm15
add rdx, 24 # Advance `a` by MR (6) f32 elements
sub r8, -128 # Advance `b` by NR (32) f32 elements; `sub -128` fits in an imm8, `add 128` would need an imm32
dec rax # Decrement the remaining depth
jne LBB758_8 # Jump to top of depth loop if not final iteration
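# The epilogue picks a write-back path based on `alpha` and `beta`:
#   alpha == 1 && beta == 0 -> store `tmp` directly (fall through from LBB758_5)
#   alpha == 1 && beta == 1 -> accumulate `tmp` into the tile (LBB758_9)
#   anything else, or NaN   -> tile = alpha * tmp + beta * tile (LBB758_12)
#
# Test if `alpha == 1` (the `jnp` sends the unordered/NaN case to LBB758_9)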
vucomiss xmm0, dword ptr [rip + LCPI758_0]
jne LBB758_9
jnp LBB758_5
jmp LBB758_9
LBB758_3:
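# depth == 0: no loop iterations, just clear the `tmp` accumulators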
vxorps xmm2, xmm2, xmm2
vxorps xmm3, xmm3, xmm3
vxorps xmm4, xmm4, xmm4
vxorps xmm5, xmm5, xmm5
vxorps xmm6, xmm6, xmm6
vxorps xmm7, xmm7, xmm7
vxorps xmm8, xmm8, xmm8
vxorps xmm9, xmm9, xmm9
vxorps xmm10, xmm10, xmm10
vxorps xmm11, xmm11, xmm11
vxorps xmm12, xmm12, xmm12
vxorps xmm13, xmm13, xmm13
# Test if `alpha == 1`
vucomiss xmm0, dword ptr [rip + LCPI758_0]
jne LBB758_9
jp LBB758_9
LBB758_5:
# Test if `beta == 0`
vxorps xmm14, xmm14, xmm14
vucomiss xmm1, xmm14
jne LBB758_9
jp LBB758_9
# alpha == 1 && beta == 0: store `tmp[i][j]` directly to `tile_ptr`
vmovups zmmword ptr [rdi], zmm13
vmovups zmmword ptr [rdi + 64], zmm12
vmovups zmmword ptr [rdi + 4*rsi], zmm11
vmovups zmmword ptr [rdi + 4*rsi + 64], zmm10
vmovups zmmword ptr [rdi + 8*rsi], zmm9
vmovups zmmword ptr [rdi + 8*rsi + 64], zmm8
lea rax, [rsi + 2*rsi]
vmovups zmmword ptr [rdi + 4*rax], zmm7
vmovups zmmword ptr [rdi + 4*rax + 64], zmm6
lea rax, [rsi + 4*rsi]
shl rsi, 4
vmovups zmmword ptr [rdi + rsi], zmm5
vmovups zmmword ptr [rdi + rsi + 64], zmm4
vmovups zmmword ptr [rdi + 4*rax], zmm3
vmovups zmmword ptr [rdi + 4*rax + 64], zmm2
pop rbp
# See https://community.intel.com/t5/Intel-ISA-Extensions/What-is-the-status-of-VZEROUPPER-use/m-p/1098375
vzeroupper
ret
LBB758_9:
# Check if `alpha == 1 && beta == 1`; if so, fall through to the accumulate path
vucomiss xmm0, dword ptr [rip + LCPI758_0]
jne LBB758_12
jp LBB758_12
vucomiss xmm1, dword ptr [rip + LCPI758_0]
jne LBB758_12
jp LBB758_12
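# alpha == 1 && beta == 1: accumulate, tile[i][j] += tmp[i][j]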
vaddps zmm0, zmm13, zmmword ptr [rdi]
vmovups zmmword ptr [rdi], zmm0
vaddps zmm0, zmm12, zmmword ptr [rdi + 64]
vmovups zmmword ptr [rdi + 64], zmm0
vaddps zmm0, zmm11, zmmword ptr [rdi + 4*rsi]
vmovups zmmword ptr [rdi + 4*rsi], zmm0
vaddps zmm0, zmm10, zmmword ptr [rdi + 4*rsi + 64]
vmovups zmmword ptr [rdi + 4*rsi + 64], zmm0
vaddps zmm0, zmm9, zmmword ptr [rdi + 8*rsi]
vmovups zmmword ptr [rdi + 8*rsi], zmm0
vaddps zmm0, zmm8, zmmword ptr [rdi + 8*rsi + 64]
vmovups zmmword ptr [rdi + 8*rsi + 64], zmm0
lea rax, [rsi + 2*rsi]
vaddps zmm0, zmm7, zmmword ptr [rdi + 4*rax]
vmovups zmmword ptr [rdi + 4*rax], zmm0
vaddps zmm0, zmm6, zmmword ptr [rdi + 4*rax + 64]
vmovups zmmword ptr [rdi + 4*rax + 64], zmm0
lea rax, [rsi + 4*rsi]
shl rsi, 4
vaddps zmm0, zmm5, zmmword ptr [rdi + rsi]
vmovups zmmword ptr [rdi + rsi], zmm0
vaddps zmm0, zmm4, zmmword ptr [rdi + rsi + 64]
vmovups zmmword ptr [rdi + rsi + 64], zmm0
vaddps zmm0, zmm3, zmmword ptr [rdi + 4*rax]
vmovups zmmword ptr [rdi + 4*rax], zmm0
vaddps zmm0, zmm2, zmmword ptr [rdi + 4*rax + 64]
vmovups zmmword ptr [rdi + 4*rax + 64], zmm0
pop rbp
vzeroupper
ret
LBB758_12:
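# General case: tile[i][j] = alpha * tmp[i][j] + beta * tile[i][j]
# Broadcast the scalar `alpha` and `beta` to all 16 lanes first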
vbroadcastss zmm0, xmm0
vbroadcastss zmm1, xmm1
vmulps zmm14, zmm1, zmmword ptr [rdi]
vfmadd213ps zmm13, zmm0, zmm14
vmovups zmmword ptr [rdi], zmm13
vmulps zmm13, zmm1, zmmword ptr [rdi + 64]
vfmadd213ps zmm12, zmm0, zmm13
vmovups zmmword ptr [rdi + 64], zmm12
vmulps zmm12, zmm1, zmmword ptr [rdi + 4*rsi]
vfmadd213ps zmm11, zmm0, zmm12
vmovups zmmword ptr [rdi + 4*rsi], zmm11
vmulps zmm11, zmm1, zmmword ptr [rdi + 4*rsi + 64]
vfmadd213ps zmm10, zmm0, zmm11
vmovups zmmword ptr [rdi + 4*rsi + 64], zmm10
vmulps zmm10, zmm1, zmmword ptr [rdi + 8*rsi]
vfmadd213ps zmm9, zmm0, zmm10
vmovups zmmword ptr [rdi + 8*rsi], zmm9
vmulps zmm9, zmm1, zmmword ptr [rdi + 8*rsi + 64]
vfmadd213ps zmm8, zmm0, zmm9
vmovups zmmword ptr [rdi + 8*rsi + 64], zmm8
lea rax, [rsi + 2*rsi]
vmulps zmm8, zmm1, zmmword ptr [rdi + 4*rax]
vfmadd213ps zmm7, zmm0, zmm8
vmovups zmmword ptr [rdi + 4*rax], zmm7
vmulps zmm7, zmm1, zmmword ptr [rdi + 4*rax + 64]
vfmadd213ps zmm6, zmm0, zmm7
vmovups zmmword ptr [rdi + 4*rax + 64], zmm6
lea rax, [rsi + 4*rsi]
shl rsi, 4
vmulps zmm6, zmm1, zmmword ptr [rdi + rsi]
vfmadd213ps zmm5, zmm0, zmm6
vmovups zmmword ptr [rdi + rsi], zmm5
vmulps zmm5, zmm1, zmmword ptr [rdi + rsi + 64]
vfmadd213ps zmm4, zmm0, zmm5
vmovups zmmword ptr [rdi + rsi + 64], zmm4
vmulps zmm4, zmm1, zmmword ptr [rdi + 4*rax]
vfmadd213ps zmm3, zmm0, zmm4
vmovups zmmword ptr [rdi + 4*rax], zmm3
vmulps zmm1, zmm1, zmmword ptr [rdi + 4*rax + 64]
vfmadd213ps zmm2, zmm0, zmm1
vmovups zmmword ptr [rdi + 4*rax + 64], zmm2
pop rbp
vzeroupper
ret
LBB758_15:
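# Reached when `a.len() < depth * MR`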
lea rdi, [rip + l___unnamed_545]
lea rdx, [rip + l___unnamed_551]
mov esi, 39
call core::panicking::panic
LBB758_16:
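# Reached when `b.len() < depth * NR`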
lea rdi, [rip + l___unnamed_547]
lea rdx, [rip + l___unnamed_552]
mov esi, 39
call core::panicking::panic