@rygorous
Created August 9, 2019 23:08
Histogram code with all the tricks :) Needs NASM + VC++
@echo off
setlocal
cd %~dp0
call vcvars amd64
..\..\bin\win32\nasm -f win64 -g -o histo_asm.obj histo_asm.nas || exit /b 1
cl /Zi /O2 /nologo histotest.cpp histo_asm.obj || exit /b 1
; NOTE: all just a quick sketch
; also I'm using xmm/ymm/zmm0-5 and then 16-31,
; because the Win64 ABI expects me to preserve xmm6-xmm15
; and I couldn't be bothered to set up a stack frame and save them :)
;
; Meant to be assembled with NASM
%macro IACA_START 0
mov ebx,111
db 0x64,0x67,0x90
%endmacro
%macro IACA_END 0
mov ebx,222
db 0x64,0x67,0x90
%endmacro
section .data
vone dd 1
vtwo dd 2
v31 dd 31
vneg1 dd -1
vmaskb dd 255
vbase dd 0*256,1*256,2*256,3*256,4*256,5*256,6*256,7*256
dd 8*256,9*256,10*256,11*256,12*256,13*256,14*256,15*256
section .text
; ---- Win64 ABI: arguments in rcx,rdx,r8
global histo_asm_scalar4_core
histo_asm_scalar4_core:
mov r9, r8 ; original count
shr r8, 2 ; trip count
jz .tail
push rbx
; Scalar (4x)
; This is super-simple: do a single 32b (4B) scalar load
; and keep extracting bytes, shifting, and incrementing the
; respective histogram bin.
;
; Keep 4 separate histograms: each 256-element histogram
; is 1k, so this is 4kB total; memory disambiguation for
; store->load forwarding only checks the bottom 12 bits
; of addresses on older Intel uArchs, so this is the largest
; number of separate histograms we can keep that will not
; result in false dependencies on these architectures.
; In either case, 4 histograms gets essentially all of the
; benefit of having more histograms.
;
; I tried using "movzx eax, bh" and only shifting once
; for every pair of bytes, but that was slightly more
; expensive on my SKX workstation. Similar with using
; a "movzx eax, bl" / "movzx eax, bh" / "bswap ebx" /
; "movzx eax, bh" / "movzx eax, bl" combo.
;
; Haven't investigated this further.
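;
; Rough C equivalent of one iteration (a sketch; it mirrors the
; histo_cpp_4x routine in the test harness further down):
;   uint32_t x = read32(p); p += 4;
;   counts[0*256 + (x & 0xff)]++; x >>= 8;
;   counts[1*256 + (x & 0xff)]++; x >>= 8;
;   counts[2*256 + (x & 0xff)]++; x >>= 8;
;   counts[3*256 + x]++;
; where counts is the 4*256-entry sliced array passed in rcx.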
.inner:
mov ebx, [rdx]
add rdx, 4
movzx eax, bl
shr ebx, 8
add dword [rcx + rax*4 + 0*1024], 1
movzx eax, bl
shr ebx, 8
add dword [rcx + rax*4 + 1*1024], 1
movzx eax, bl
add dword [rcx + rax*4 + 2*1024], 1
shr ebx, 8
add dword [rcx + rbx*4 + 3*1024], 1
dec r8
jnz .inner
pop rbx
.tail:
and r9, 3 ; masked count
jz .done
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.done:
ret
; ----
global histo_asm_scalar8_core
histo_asm_scalar8_core:
mov r9, r8 ; original count
shr r8, 3 ; trip count
jz .tail
push rbx
; Scalar (8x)
; Using a 64b (8B) load this time, but still 4 histogram
; slices.
.inner:
mov rbx, [rdx]
add rdx, 8
movzx eax, bl
shr rbx, 8
add dword [rcx + rax*4 + 0*1024], 1
movzx eax, bl
shr rbx, 8
add dword [rcx + rax*4 + 1*1024], 1
movzx eax, bl
shr rbx, 8
add dword [rcx + rax*4 + 2*1024], 1
movzx eax, bl
shr rbx, 8
add dword [rcx + rax*4 + 3*1024], 1
movzx eax, bl
shr ebx, 8
add dword [rcx + rax*4 + 0*1024], 1
movzx eax, bl
shr ebx, 8
add dword [rcx + rax*4 + 1*1024], 1
movzx eax, bl
add dword [rcx + rax*4 + 2*1024], 1
shr ebx, 8
add dword [rcx + rbx*4 + 3*1024], 1
dec r8
jnz .inner
pop rbx
.tail:
and r9, 7 ; masked count
jz .done
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.done:
ret
; ----
global histo_asm_scalar8_var_core
histo_asm_scalar8_var_core:
push rsi
mov rsi, rdx
sub r8, 16 ; for readahead
jb .tail
push rbx
; Scalar (8x)
; Using two 32b (4B) loads, offset by one iteration;
; still 4 histogram slices.
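; The loads run one iteration ahead (software pipelining): ebx/edx
; are preloaded before the loop, and each is refilled mid-loop as
; soon as its bytes have been consumed, so the byte extraction
; doesn't wait on a load issued in the same iteration.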
mov ebx, [rsi]
mov edx, [rsi + 4]
.inner:
add rsi, 8
movzx eax, bl
shr ebx, 8
add dword [rcx + rax*4 + 0*1024], 1
movzx eax, bl
shr ebx, 8
add dword [rcx + rax*4 + 1*1024], 1
movzx eax, bl
add dword [rcx + rax*4 + 2*1024], 1
shr ebx, 8
add dword [rcx + rbx*4 + 3*1024], 1
mov ebx, [rsi]
movzx eax, dl
shr edx, 8
add dword [rcx + rax*4 + 0*1024], 1
movzx eax, dl
shr edx, 8
add dword [rcx + rax*4 + 1*1024], 1
movzx eax, dl
add dword [rcx + rax*4 + 2*1024], 1
shr edx, 8
add dword [rcx + rdx*4 + 3*1024], 1
mov edx, [rsi + 4]
sub r8, 8
ja .inner
pop rbx
.tail:
add r8, 16 ; restore count
jz .done
.taillp:
movzx eax, byte [rsi]
inc rsi
inc dword [rcx + rax*4]
dec r8
jnz .taillp
.done:
pop rsi
ret
; ----
; Increments value in histogram bucket by 1
%macro INC1 2 ; slice, bucket
inc dword [rcx + (%1 % NUM_SLICES)*BYTES_PER_SLICE + %2*4]
%endmacro
; Increments value in histogram bucket by 2
%macro INC2 2 ; slice, bucket
add dword [rcx + (%1 % NUM_SLICES)*BYTES_PER_SLICE + %2*4], 2
%endmacro
; Updates sum for 8 bytes
%macro SUM8 5 ; reg64 reg32 reg8L reg8H INC
movzx eax, %3
%5 0, rax
movzx eax, %4
shr %1, 16
%5 1, rax
movzx eax, %3
%5 2, rax
movzx eax, %4
shr %1, 16
%5 3, rax
movzx eax, %3
%5 4, rax
movzx eax, %4
shr %2, 16
%5 5, rax
movzx eax, %3
%5 6, rax
movzx eax, %4
%5 7, rax
%endmacro
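; For reference, SUM8 rbx, ebx, bl, bh, INC1 expands to code that walks
; the 8 bytes of rbx low to high: bytes are read in bl/bh pairs, with the
; 16-bit right shift interleaved between reading bh and using it; the
; final shift uses the 32-bit register (%2) since only 32 payload bits
; remain. Byte k updates slice (k % NUM_SLICES); INC1/INC2 index off rcx,
; the histogram pointer.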
%define NUM_SLICES 4
%define BYTES_PER_SLICE 1024
global histo_asm_scalar8_var2_core
histo_asm_scalar8_var2_core:
push rsi
push r10
push rbx
lea rsi, [rdx + r8 - 24]
neg r8
add r8, 24 ; for readahead
jge .tail
; Scalar (8x)
; Using a single 64b (8B) load, but reading ahead by one iteration
; unrolled 2x to avoid moves
; still 4 histogram slices
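; Index setup above: rsi = data + count - 24 and r8 = 24 - count, so
; [rsi + r8] always addresses the next unprocessed byte. The loop runs
; while r8 is negative, which guarantees that the read-ahead loads at
; [rsi + r8 + 8] and [rsi + r8 + 16] stay inside the buffer.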
mov rbx, [rsi + r8]
align 16
.inner:
mov rdx, [rsi + r8 + 8]
SUM8 rbx, ebx, bl, bh, INC1
mov rbx, [rsi + r8 + 16]
SUM8 rdx, edx, dl, dh, INC1
add r8, 16
jl .inner
.tail:
add rsi, 24
sub r8, 24 ; restore count
jz .done ; nothing left (count was 0)
.taillp:
movzx eax, byte [rsi + r8]
inc dword [rcx + rax*4]
inc r8
jnz .taillp
.done:
pop rbx
pop r10
pop rsi
ret
%undef NUM_SLICES
%undef BYTES_PER_SLICE
; ----
%define NUM_SLICES 8
%define BYTES_PER_SLICE 1024
global histo_asm_scalar8_var3_core
histo_asm_scalar8_var3_core:
push rsi
push r10
push rbx
lea rsi, [rdx + r8 - 24]
neg r8
add r8, 24 ; for readahead
jge .tail
; Scalar (8x)
; Using a single 64b (8B) load, but reading ahead by one iteration
; unrolled 2x to avoid moves
; this one uses 8 histogram slices
mov rbx, [rsi + r8]
align 16
.inner:
mov rdx, [rsi + r8 + 8]
SUM8 rbx, ebx, bl, bh, INC1
mov rbx, [rsi + r8 + 16]
SUM8 rdx, edx, dl, dh, INC1
add r8, 16
jl .inner
.tail:
add rsi, 24
sub r8, 24 ; restore count
jz .done ; nothing left (count was 0)
.taillp:
movzx eax, byte [rsi + r8]
inc dword [rcx + rax*4]
inc r8
jnz .taillp
.done:
pop rbx
pop r10
pop rsi
ret
%undef NUM_SLICES
%undef BYTES_PER_SLICE
; ----
%define NUM_SLICES 8
%define BYTES_PER_SLICE (260*4)
global histo_asm_scalar8_var4_core
histo_asm_scalar8_var4_core:
push rsi
push r10
push rbx
lea rsi, [rdx + r8 - 24]
neg r8
add r8, 24 ; for readahead
jge .tail
; Scalar (8x)
; Using a single 64b (8B) load, but reading ahead by one iteration
; unrolled 2x to avoid moves
; this one uses 8 histogram slices, and histograms have 260 slots
; not 256 to avoid 4k aliasing on runs of the same symbol
; (aliasing on descending runs seems less likely to occur in practice)
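; The arithmetic: with 8 slices of 256*4 = 1024 bytes, the same bucket
; in slice k and slice k+4 is exactly 4096 bytes apart, so a run of one
; symbol keeps touching addresses whose low 12 bits match (4K aliasing
; stalls on the loads). At 260*4 = 1040 bytes per slice, no two slices
; are a multiple of 4096 apart (4*1040 = 4160), so the low 12 bits of
; the 8 addresses for a given bucket always differ.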
mov rbx, [rsi + r8]
align 16
.inner:
mov rdx, [rsi + r8 + 8]
SUM8 rbx, ebx, bl, bh, INC1
mov rbx, [rsi + r8 + 16]
SUM8 rdx, edx, dl, dh, INC1
add r8, 16
jl .inner
.tail:
add rsi, 24
sub r8, 24 ; restore count
jz .done ; nothing left (count was 0)
.taillp:
movzx eax, byte [rsi + r8]
inc dword [rcx + rax*4]
inc r8
jnz .taillp
.done:
pop rbx
pop r10
pop rsi
ret
%undef NUM_SLICES
%undef BYTES_PER_SLICE
; ----
%define NUM_SLICES 4
%define BYTES_PER_SLICE (256*4)
global histo_asm_scalar8_var5_core
histo_asm_scalar8_var5_core:
push rsi
push r10
push rbx
lea rsi, [rdx + r8 - 24]
neg r8
add r8, 24 ; for readahead
jge .tail
; Scalar (8x)
; Using a single 64b (8B) load, but reading ahead by one iteration
; unrolled 2x to avoid moves
; back to 4 slices, but with a path to detect and handle long runs,
; since they are frequent for the original use case of this loop
mov rbx, [rsi + r8]
align 16
.inner:
mov rdx, [rsi + r8 + 8]
cmp rbx, rdx ; if next 64b match current 64b, can save work! (Also, check for runs.)
je .add_double
SUM8 rbx, ebx, bl, bh, INC1
mov rbx, [rsi + r8 + 16]
SUM8 rdx, edx, dl, dh, INC1
add r8, 16
jl .inner
.tail:
add rsi, 24
sub r8, 24 ; restore count
jz .done ; nothing left (count was 0)
.taillp:
movzx eax, byte [rsi + r8]
inc dword [rcx + rax*4]
inc r8
jnz .taillp
.done:
pop rbx
pop r10
pop rsi
ret
align 16
.add_double:
; if we get here, we have a pair where the first and second 8 bytes
; in a group are identical; this is most likely a run of identical bytes,
; but even if not, we can still save a good deal of work
; check for a run of all the same byte:
; x ^ (x << 8) XORs every byte (but byte 0) with its predecessor;
; if that is <256, all 8 bytes of x are equal, and since we already
; know the two 8-byte halves match, all 16 bytes are the same value.
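; e.g. if all 8 bytes of x are 0xAB: x << 8 = 0xABABABABABABAB00, so
; x ^ (x << 8) = 0xAB = 171 < 256. Conveniently, when the check passes,
; rdx holds exactly the repeated byte value, which is why it is used
; directly as the bucket index below.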
shl rdx, 8
xor rdx, rbx
cmp rdx, 256
jnb .not_all_same
; add 16 to single bucket (just always use first slice)
add dword [rcx + rdx*4], 16
mov rbx, [rsi + r8 + 16]
add r8, 16
jl .inner
jmp .tail
.not_all_same:
; add 2 to each bucket since we confirmed we have the same 8 bytes twice
SUM8 rbx, ebx, bl, bh, INC2
mov rbx, [rsi + r8 + 16]
add r8, 16
jl .inner
jmp .tail
%undef NUM_SLICES
%undef BYTES_PER_SLICE
; ----
global histo_asm_sse4_core
histo_asm_sse4_core:
mov r9, r8 ; original count
shr r8, 4 ; trip count
jz .tail
; SSE4
; Same idea as before, still 4 histogram slices,
; but use a 16-byte SSE load and PEXTRB.
;
; This should not make much of a difference compared to the
; scalar versions above (PEXTRB is 2 uops on the Intel uArchs
; I looked at, same as the movzx/shift pairs), and indeed, in
; my tests anyway, it doesn't.
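; The %rep below unrolls 16 PEXTRB/increment pairs, cycling through
; the 4 slices via (i & 3) so consecutive bytes hit different slices.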
.inner:
movdqu xmm0, [rdx]
add rdx, 16
%assign i 0
%rep 16
pextrb eax, xmm0, i
add dword [rcx + rax*4 + (i&3)*256*4], 1
%assign i i+1
%endrep
dec r8
jnz .inner
.tail:
and r9, 15 ; masked count
jz .done
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.done:
ret
; ----
global histo_asm_avx256_8x_core1
histo_asm_avx256_8x_core1:
vzeroupper
mov r9, r8 ; original count
shr r8, 5 ; trip count
jz .tail
vpbroadcastd ymm3, [rel vone]
vmovdqu ymm4, [rel vbase]
vpbroadcastd ymm5, [rel vmaskb]
; "AVX-256", variant 1.
;
; This is AVX-512, but "only" using 256-bit wide vectors;
; since we're not using "heavy" instructions, 256-bit ops stay
; in the highest frequency license level on SKX.
; (The reason to use AVX-512 at all is that we need both
; gathers and scatters.)
;
; This avoids collisions between the 8 lanes by using
; 8 separate histograms. This would be bad on older
; uArchs with 4k store->load aliasing, but those
; certainly don't support AVX-512, so it's not a
; concern.
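;
; Per byte position, the update is essentially (sketch):
;   idx  = (dwords >> shift) & 0xff  ; one byte per 32-bit lane
;   idx += lane*256                  ; vbase: each lane gets its own slice
;   bins = gather(histo, idx); bins += 1; scatter(histo, idx, bins)
; The kxnorb before every gather/scatter reloads the mask with all-ones,
; because VPGATHERDD/VPSCATTERDD clear their mask register as they
; complete.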
.inner:
; We treat these 32 bytes as 8 DWords
vmovdqu ymm0, [rdx]
add rdx, 32
; Extract bottom byte of every lane
vpandd ymm1, ymm0, ymm5
; Add base so the lanes all go into different histograms
; (so they never conflict)
vpaddd ymm1, ymm1, ymm4
; Dependency breaker
vpxord ymm2, ymm2, ymm2
; Gather the histo bin values
kxnorb k1, k1, k1
vpgatherdd ymm2{k1}, [rcx + ymm1*4]
; Increment
vpaddd ymm2, ymm2, ymm3
; Scatter updated values
kxnorb k1, k1, k1
vpscatterdd [rcx + ymm1*4]{k1}, ymm2
; Bits [15:8] of every lane
vpsrld ymm1, ymm0, 8
vpandd ymm1, ymm1, ymm5
vpaddd ymm1, ymm1, ymm4
vpxord ymm2, ymm2, ymm2
kxnorb k1, k1, k1
vpgatherdd ymm2{k1}, [rcx + ymm1*4]
vpaddd ymm2, ymm2, ymm3
kxnorb k1, k1, k1
vpscatterdd [rcx + ymm1*4]{k1}, ymm2
; Bits [23:16] of every lane
vpsrld ymm1, ymm0, 16
vpandd ymm1, ymm1, ymm5
vpaddd ymm1, ymm1, ymm4
vpxord ymm2, ymm2, ymm2
kxnorb k1, k1, k1
vpgatherdd ymm2{k1}, [rcx + ymm1*4]
vpaddd ymm2, ymm2, ymm3
kxnorb k1, k1, k1
vpscatterdd [rcx + ymm1*4]{k1}, ymm2
; Bits [31:24] of every lane
vpsrld ymm1, ymm0, 24
vpaddd ymm1, ymm1, ymm4
vpxord ymm2, ymm2, ymm2
kxnorb k1, k1, k1
vpgatherdd ymm2{k1}, [rcx + ymm1*4]
vpaddd ymm2, ymm2, ymm3
kxnorb k1, k1, k1
vpscatterdd [rcx + ymm1*4]{k1}, ymm2
dec r8
jnz .inner
.tail:
and r9, 31 ; masked count
jz .done
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.done:
vzeroupper
ret
; ----
global histo_asm_avx256_8x_core2
histo_asm_avx256_8x_core2:
vzeroupper
mov r9, r8 ; original count
shr r8, 3 ; trip count
jz .tail
vpbroadcastd ymm3, [rel vone]
vmovdqu ymm4, [rel vbase]
; "AVX-256", variant 2
;
; Only do a single gather/scatter per iteration, and
; use vpmovzxbd to upconvert the 8 bytes.
;
; This is definitely worse than variant 1.
.inner:
vpmovzxbd ymm0, [rdx]
add rdx, 8
kxnorb k1, k1, k1
vpaddd ymm0, ymm0, ymm4
vpxord ymm2, ymm2, ymm2
vpgatherdd ymm2{k1}, [rcx + ymm0*4]
vpaddd ymm2, ymm2, ymm3
kxnorb k1, k1, k1
vpscatterdd [rcx + ymm0*4]{k1}, ymm2
dec r8
jnz .inner
.tail:
and r9, 7 ; masked count
jz .done
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.done:
vzeroupper
ret
; ----
global histo_asm_avx256_8x_core3
histo_asm_avx256_8x_core3:
vzeroupper
mov r9, r8 ; original count
shr r8, 5 ; trip count
jz .tail
vpbroadcastd ymm3, [rel vone]
vpbroadcastd ymm17, [rel vtwo]
vmovdqu ymm4, [rel vbase]
vpbroadcastd ymm5, [rel vmaskb]
; "AVX-256", variant 3
;
; This is similar to variant 1, but manually
; checks whether adjacent batches of 8 collide;
; if so, the first "hit" in the batch updates the
; histogram count by 2, and the second is disabled.
; The idea here is to avoid store->load forwarding
; cases by hand.
;
; This is sometimes marginally better than
; variant 1.
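;
; Concretely: k2 marks lanes where byte position n+1 equals byte
; position n within the same dword. For those lanes the first
; gather/scatter adds 2 (vpblendmd picks 2 instead of 1), and the
; second pass runs under the inverted mask (knotb), skipping the
; lanes that were already counted.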
.inner:
vmovdqu ymm0, [rdx]
add rdx, 32
vpandd ymm1, ymm0, ymm5
vpsrld ymm16, ymm0, 8
vpandd ymm16, ymm16, ymm5
vpcmpd k2, ymm1, ymm16, 0 ; second iter matches first iter?
vpblendmd ymm18{k2}, ymm3, ymm17 ; second_matches ? 2 : 1
vpaddd ymm1, ymm1, ymm4
vpxord ymm2, ymm2, ymm2
kxnorb k1, k1, k1
vpgatherdd ymm2{k1}, [rcx + ymm1*4]
vpaddd ymm2, ymm2, ymm18
kxnorb k1, k1, k1
vpscatterdd [rcx + ymm1*4]{k1}, ymm2
vpaddd ymm16, ymm16, ymm4
vpxord ymm18, ymm18, ymm18
knotb k1, k2
vpgatherdd ymm18{k1}, [rcx + ymm16*4]
vpaddd ymm18, ymm18, ymm3
knotb k1, k2
vpscatterdd [rcx + ymm16*4]{k1}, ymm18
vpsrld ymm1, ymm0, 16
vpandd ymm1, ymm1, ymm5
vpsrld ymm16, ymm0, 24
vpcmpd k2, ymm1, ymm16, 0 ; second iter matches first iter?
vpblendmd ymm18{k2}, ymm3, ymm17 ; second_matches ? 2 : 1
vpaddd ymm1, ymm1, ymm4
vpxord ymm2, ymm2, ymm2
kxnorb k1, k1, k1
vpgatherdd ymm2{k1}, [rcx + ymm1*4]
vpaddd ymm2, ymm2, ymm18
kxnorb k1, k1, k1
vpscatterdd [rcx + ymm1*4]{k1}, ymm2
vpaddd ymm16, ymm16, ymm4
vpxord ymm2, ymm2, ymm2
knotb k1, k2
vpgatherdd ymm2{k1}, [rcx + ymm16*4]
vpaddd ymm2, ymm2, ymm3
knotb k1, k2
vpscatterdd [rcx + ymm16*4]{k1}, ymm2
dec r8
jnz .inner
.tail:
and r9, 31 ; masked count
jz .done
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.done:
vzeroupper
ret
; ----
global histo_asm_avx512_core_conflict
histo_asm_avx512_core_conflict:
vzeroupper
mov r9, r8 ; original count
shr r8, 4 ; trip count
jz .tail
vpbroadcastd zmm4, [rel vone]
vpbroadcastd zmm5, [rel v31]
vpbroadcastd zmm17, [rel vneg1]
; AVX-512, conflict detect
;
; This is the algorithm from the Intel Optimization Manual (example 15-18)
; with typos fixed.
;
; This one only writes to a single histogram slice, which, given how
; much extra work that implies, is probably not a good idea.
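;
; How the conflict handling works: VPCONFLICTD gives each lane a bitmask
; of earlier lanes holding the same value; 31 - lzcnt of that is the
; index of the closest earlier duplicate (or -1 if there is none). The
; .conflicts loop repeatedly pulls in (VPERMD) the increment accumulated
; by the pointed-to lane, adds it, and follows the chain until every
; remaining lane points at -1; only then are the merged increments added
; to the gathered counts and scattered back once.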
.inner:
vpmovzxbd zmm0, [rdx]
add rdx, 16
vpconflictd zmm1, zmm0 ; zmm1=conflicts
vmovaps zmm3, zmm4 ; vOne
vpxord zmm2, zmm2, zmm2
kxnorw k1, k1, k1
vpgatherdd zmm2{k1}, [rcx + zmm0*4]
vptestmd k1, zmm1, zmm1
kortestw k1, k1
jz .update
vplzcntd zmm1, zmm1
vpsubd zmm1, zmm5, zmm1
.conflicts:
vpermd zmm16{k1}{z}, zmm1, zmm3
vpermd zmm1{k1}, zmm1, zmm1
vpaddd zmm3{k1}, zmm3, zmm16
vpcmpd k1, zmm1, zmm17, 4
kortestw k1, k1
jnz .conflicts
.update:
vpaddd zmm2, zmm2, zmm3
kxnorw k1, k1, k1
vpscatterdd [rcx + zmm0*4]{k1}, zmm2
dec r8
jnz .inner
.tail:
and r9, 15 ; masked count
jz .done
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.done:
vzeroupper
ret
; ----
global histo_asm_avx512_core_8x
histo_asm_avx512_core_8x:
vzeroupper
mov r9, r8 ; original count
shr r8, 4 ; trip count
jz .tail
mov eax, 0xff
kmovw k3, eax ; "low half of lanes" mask
vbroadcasti32x8 zmm4, [rel vbase]
vpbroadcastd zmm5, [rel vone]
vpaddd zmm16, zmm5, zmm5 ; two
; AVX-512, 8 histogram slices
;
; This means we can do the conflict detection manually like in var3 above.
; That seems preferable overall, and is indeed significantly faster in
; my tests.
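;
; Concretely: valignd rotates the 16 input values by 8 lanes so lane i
; of the low half lines up with lane i+8. k2 (restricted to the low 8
; lanes via k3) marks low lanes whose byte equals the corresponding high
; lane; those lanes add 2 instead of 1, and kshiftlw/knotw drops the
; matching high lanes from the scatter so nothing is double-counted.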
.inner:
vpmovzxbd zmm0, [rdx]
add rdx, 16
vpaddd zmm1, zmm0, zmm4 ; add base lanes
valignd zmm2, zmm0, zmm0, 8 ; input bytes rotated by 8 lanes
vpcmpd k2{k3}, zmm0, zmm2, 0 ; check whether high half matches low half
; grab source bin values
vpxord zmm2, zmm2, zmm2
kxnorw k1, k1, k1
vpgatherdd zmm2{k1}, [rcx + zmm1*4]
; depending on whether we have a conflict between matching high and low lanes,
; add either 1 or 2
vpblendmd zmm0{k2}, zmm5, zmm16
vpaddd zmm2, zmm2, zmm0
; determine output scatter mask
kshiftlw k1, k2, 8
knotw k1, k1
vpscatterdd [rcx + zmm1*4]{k1}, zmm2
dec r8
jnz .inner
.tail:
and r9, 15 ; masked count
jz .done
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.done:
vzeroupper
ret
; ----
global histo_asm_avx512_core_16x
histo_asm_avx512_core_16x:
vzeroupper
mov r9, r8 ; original count
shr r8, 4 ; trip count
jz .tail
vmovdqu32 zmm4, [rel vbase]
vpbroadcastd zmm5, [rel vone]
; AVX-512, 16 histogram slices
;
; No conflict detection necessary.
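;
; With 16 slices and vbase adding lane*256, lane i only ever touches
; slice i, so the 16 addresses within one gather/scatter can never
; collide. The price is the larger summing pass at the end.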
.inner:
vpmovzxbd zmm0, [rdx]
add rdx, 16
vpaddd zmm1, zmm0, zmm4 ; add base lanes
; grab source bin values
vpxord zmm2, zmm2, zmm2
kxnorw k1, k1, k1
vpgatherdd zmm2{k1}, [rcx + zmm1*4]
; increment
vpaddd zmm2, zmm2, zmm5
kxnorw k1, k1, k1
vpscatterdd [rcx + zmm1*4]{k1}, zmm2
dec r8
jnz .inner
.tail:
and r9, 15 ; masked count
jz .sumit
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.sumit:
; summing loop
mov r9, (256-16)*4
.sum:
vmovdqa32 zmm0, [rcx + r9 + 0*1024]
%assign i 1
%rep 15
vpaddd zmm0, zmm0, [rcx + r9 + i*1024]
%assign i i+1
%endrep
vmovdqa32 [rcx + r9 + 0*1024], zmm0
sub r9, 16*4
jns .sum
.done:
vzeroupper
ret
; ----
global histo_asm_avx512_core_16x_var2
histo_asm_avx512_core_16x_var2:
vzeroupper
mov r9, r8 ; original count
shr r8, 5 ; trip count
jz .tail
vmovdqu32 zmm4, [rel vbase]
vpbroadcastd zmm5, [rel vone]
vpaddd zmm16, zmm5, zmm5 ; two
; AVX-512, 16 histogram slices
;
; Grab two runs of 16, detect conflicts between them.
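; Same 1-or-2 trick as the 8x kernel above, but between corresponding
; lanes of two consecutive 16-byte groups: where the lanes match, the
; first pass adds 2 and the second pass's gather/scatter runs under the
; inverted mask (knotw).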
.inner:
vpmovzxbd zmm0, [rdx]
vpmovzxbd zmm1, [rdx + 16]
add rdx, 32
vpcmpd k2, zmm0, zmm1, 0 ; check whether upper 16 match lower 16
vpaddd zmm0, zmm0, zmm4 ; add base lanes
vpaddd zmm1, zmm1, zmm4
; grab source bin values
vpxord zmm2, zmm2, zmm2
kxnorw k1, k1, k1
vpgatherdd zmm2{k1}, [rcx + zmm0*4]
; depending on whether we have a conflict between matching high and low lanes,
; add either 1 or 2
vpblendmd zmm3{k2}, zmm5, zmm16
vpaddd zmm2, zmm2, zmm3
; store updated
kxnorw k1, k1, k1
vpscatterdd [rcx + zmm0*4]{k1}, zmm2
; grab source bin values for second half
vpxord zmm2, zmm2, zmm2
knotw k1, k2
vpgatherdd zmm2{k1}, [rcx + zmm1*4]
; increment
vpaddd zmm2, zmm2, zmm5
; store updated
knotw k1, k2
vpscatterdd [rcx + zmm1*4]{k1}, zmm2
dec r8
jnz .inner
.tail:
and r9, 31 ; masked count
jz .sumit
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.sumit:
; summing loop
mov r9, (256-16)*4
.sum:
vmovdqa32 zmm0, [rcx + r9 + 0*1024]
%assign i 1
%rep 15
vpaddd zmm0, zmm0, [rcx + r9 + i*1024]
%assign i i+1
%endrep
vmovdqa32 [rcx + r9 + 0*1024], zmm0
sub r9, 16*4
jns .sum
.done:
vzeroupper
ret
; ----
global histo_asm_avx512_core_16x_var3
histo_asm_avx512_core_16x_var3:
vzeroupper
mov r9, r8 ; original count
shr r8, 5 ; trip count
jz .tail
vmovdqu32 zmm4, [rel vbase]
vpbroadcastd zmm5, [rel vone]
vpaddd zmm16, zmm5, zmm5 ; two
vpmovzxbd zmm17, [rdx]
vpmovzxbd zmm18, [rdx + 16]
; AVX-512, 16 histogram slices
;
; Grab two runs of 16 and detect conflicts between them, same as var2,
; but the loads for the next iteration are issued early (software
; pipelining). Note this reads up to 32 bytes past the end of the
; input buffer.
.inner:
vpcmpd k2, zmm17, zmm18, 0 ; check whether upper 16 match lower 16
vpaddd zmm0, zmm17, zmm4 ; add base lanes
vpaddd zmm1, zmm18, zmm4
; loads for next round
vpmovzxbd zmm17, [rdx + 32]
vpmovzxbd zmm18, [rdx + 48]
add rdx, 32
; grab source bin values
vpxord zmm2, zmm2, zmm2
kxnorw k1, k1, k1
vpgatherdd zmm2{k1}, [rcx + zmm0*4]
; depending on whether we have a conflict between matching high and low lanes,
; add either 1 or 2
vpblendmd zmm3{k2}, zmm5, zmm16
vpaddd zmm2, zmm2, zmm3
; store updated
kxnorw k1, k1, k1
vpscatterdd [rcx + zmm0*4]{k1}, zmm2
; grab source bin values for second half
vpxord zmm2, zmm2, zmm2
knotw k1, k2
vpgatherdd zmm2{k1}, [rcx + zmm1*4]
; increment
vpaddd zmm2, zmm2, zmm5
; store updated
knotw k1, k2
vpscatterdd [rcx + zmm1*4]{k1}, zmm2
dec r8
jnz .inner
.tail:
and r9, 31 ; masked count
jz .sumit
.taillp:
movzx eax, byte [rdx]
inc rdx
inc dword [rcx + rax*4]
dec r9
jnz .taillp
.sumit:
; summing loop
mov r9, (256-16)*4
.sum:
vmovdqa32 zmm0, [rcx + r9 + 0*1024]
%assign i 1
%rep 15
vpaddd zmm0, zmm0, [rcx + r9 + i*1024]
%assign i i+1
%endrep
vmovdqa32 [rcx + r9 + 0*1024], zmm0
sub r9, 16*4
jns .sum
.done:
vzeroupper
ret
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#ifdef _MSC_VER
#include <intrin.h>
#endif
#include <emmintrin.h>
typedef uint8_t U8;
typedef uint32_t U32;
typedef uint64_t U64;
//#define NOVALIDATE
static U8 *read_file(const char *filename, size_t *out_size)
{
U8 *ret = NULL;
FILE *f = fopen(filename, "rb");
if (f)
{
fseek(f, 0, SEEK_END);
long size = ftell(f);
fseek(f, 0, SEEK_SET);
U8 *mem = new U8[size];
if (mem)
{
if (fread(mem, size, 1, f) == 1)
{
if (out_size)
*out_size = size;
ret = mem;
}
else
delete[] mem;
}
fclose(f);
}
return ret;
}
static inline U64 cycle_timer()
{
#ifdef _MSC_VER
return __rdtsc();
#else
U32 lo, hi;
__asm__ volatile("rdtsc" : "=a"(lo), "=d"(hi) );
return lo | ((U64)hi << 32);
#endif
}
static inline uint32_t read32(const void *p)
{
uint32_t x;
memcpy(&x, p, sizeof(x));
return x;
}
static void histo_ref(U32 * counts, const U8 * rawArray, size_t rawLen)
{
memset(counts,0,256*sizeof(U32));
for (size_t i = 0; i < rawLen; i++)
counts[rawArray[i]]++;
}
static void histo_cpp_1x(U32 * counts, const U8 * rawArray, size_t rawLen)
{
memset(counts,0,256*sizeof(U32));
const U8 * rawPtr = rawArray;
const U8 * rawEnd = rawArray+rawLen;
const U8 * rawEndMul4 = rawArray+(rawLen&~3);
while(rawPtr < rawEndMul4)
{
U32 x = read32(rawPtr);
counts[x & 0xff]++; x >>= 8;
counts[x & 0xff]++; x >>= 8;
counts[x & 0xff]++; x >>= 8;
counts[x] ++; // last one doesn't need to mask
rawPtr += 4;
}
// finish the last few bytes one at a time
while(rawPtr < rawEnd)
counts[ *rawPtr++ ] ++;
}
static void histo_cpp_2x(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[2][256];
memset(countsArray,0,sizeof(countsArray));
const U8 * rawPtr = rawArray;
const U8 * rawEnd = rawArray+rawLen;
const U8 * rawEndMul4 = rawArray+(rawLen&~3);
while(rawPtr < rawEndMul4)
{
U32 x = read32(rawPtr);
countsArray[0][x & 0xff]++; x >>= 8;
countsArray[1][x & 0xff]++; x >>= 8;
countsArray[0][x & 0xff]++; x >>= 8;
countsArray[1][x] ++; // last one doesn't need to mask
rawPtr += 4;
}
// finish the last few bytes (just throw them into array 0, doesn't matter)
while(rawPtr < rawEnd)
countsArray[0][ *rawPtr++ ] ++;
// sum the countsarrays together
int k=0;
for(;k<256;k+=4)
{
__m128i sum = _mm_load_si128((const __m128i *) &countsArray[0][k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &countsArray[1][k]);
_mm_storeu_si128((__m128i *)&counts[k], sum);
}
}
static void finish_histo_4x(U32 * counts_out, const U32 * counts4x)
{
for (size_t k=0; k<256; k+=4)
{
__m128i sum = _mm_load_si128((const __m128i *) &counts4x[0*256 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts4x[1*256 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts4x[2*256 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts4x[3*256 + k]);
_mm_storeu_si128((__m128i *)&counts_out[k], sum);
}
}
static void finish_histo_8x(U32 * counts_out, const U32 * counts8x)
{
for (size_t k=0; k<256; k+=4)
{
__m128i sum = _mm_load_si128((const __m128i *) &counts8x[0*256 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[1*256 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[2*256 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[3*256 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[4*256 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[5*256 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[6*256 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[7*256 + k]);
_mm_storeu_si128((__m128i *)&counts_out[k], sum);
}
}
static void finish_histo_8x_260(U32 * counts_out, const U32 * counts8x)
{
for (size_t k=0; k<256; k+=4)
{
__m128i sum = _mm_load_si128((const __m128i *) &counts8x[0*260 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[1*260 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[2*260 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[3*260 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[4*260 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[5*260 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[6*260 + k]);
sum = _mm_add_epi32(sum, *(const __m128i *) &counts8x[7*260 + k]);
_mm_storeu_si128((__m128i *)&counts_out[k], sum);
}
}
static void histo_cpp_4x(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[4*256];
memset(countsArray,0,sizeof(countsArray));
const U8 * rawPtr = rawArray;
const U8 * rawEnd = rawArray+rawLen;
const U8 * rawEndMul4 = rawArray+(rawLen&~3);
while(rawPtr < rawEndMul4)
{
U32 x = read32(rawPtr);
countsArray[0*256 + (x & 0xff)]++; x >>= 8;
countsArray[1*256 + (x & 0xff)]++; x >>= 8;
countsArray[2*256 + (x & 0xff)]++; x >>= 8;
countsArray[3*256 + x] ++; // last one doesn't need to mask
rawPtr += 4;
}
// finish the last few bytes (just throw them into slice 0, doesn't matter)
while(rawPtr < rawEnd)
countsArray[ *rawPtr++ ] ++;
finish_histo_4x(counts, countsArray);
}
extern "C" void histo_asm_scalar4_core(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_scalar8_core(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_scalar8_var_core(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_scalar8_var2_core(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_scalar8_var3_core(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_scalar8_var4_core(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_scalar8_var5_core(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_sse4_core(U32 * histo, const U8 * bytes, size_t nbytes);
static void histo_asm_scalar4(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[4*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_scalar4_core(countsArray, rawArray, rawLen);
finish_histo_4x(counts, countsArray);
}
static void histo_asm_scalar8(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[4*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_scalar8_core(countsArray, rawArray, rawLen);
finish_histo_4x(counts, countsArray);
}
static void histo_asm_scalar8_var(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[4*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_scalar8_var_core(countsArray, rawArray, rawLen);
finish_histo_4x(counts, countsArray);
}
static void histo_asm_scalar8_var2(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[4*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_scalar8_var2_core(countsArray, rawArray, rawLen);
finish_histo_4x(counts, countsArray);
}
static void histo_asm_scalar8_var3(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[8*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_scalar8_var3_core(countsArray, rawArray, rawLen);
finish_histo_8x(counts, countsArray);
}
static void histo_asm_scalar8_var4(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[8*260];
memset(countsArray,0,sizeof(countsArray));
histo_asm_scalar8_var4_core(countsArray, rawArray, rawLen);
finish_histo_8x_260(counts, countsArray);
}
static void histo_asm_scalar8_var5(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[4*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_scalar8_var5_core(countsArray, rawArray, rawLen);
finish_histo_4x(counts, countsArray);
}
static void histo_asm_sse4(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[4*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_sse4_core(countsArray, rawArray, rawLen);
finish_histo_4x(counts, countsArray);
}
extern "C" void histo_asm_avx256_8x_core1(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_avx256_8x_core2(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_avx256_8x_core3(U32 * histo, const U8 * bytes, size_t nbytes);
static void histo_asm_avx256_8x_1(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[8*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_avx256_8x_core1(countsArray, rawArray, rawLen);
finish_histo_8x(counts, countsArray);
}
static void histo_asm_avx256_8x_2(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[8*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_avx256_8x_core2(countsArray, rawArray, rawLen);
finish_histo_8x(counts, countsArray);
}
static void histo_asm_avx256_8x_3(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[8*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_avx256_8x_core3(countsArray, rawArray, rawLen);
finish_histo_8x(counts, countsArray);
}
extern "C" void histo_asm_avx512_core_conflict(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_avx512_core_8x(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_avx512_core_16x(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_avx512_core_16x_var2(U32 * histo, const U8 * bytes, size_t nbytes);
extern "C" void histo_asm_avx512_core_16x_var3(U32 * histo, const U8 * bytes, size_t nbytes);
static void histo_asm_avx512_conflict(U32 * counts, const U8 * rawArray, size_t rawLen)
{
memset(counts,0,256*sizeof(U32));
histo_asm_avx512_core_conflict(counts, rawArray, rawLen);
}
static void histo_asm_avx512_8x(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(16)) countsArray[8*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_avx512_core_8x(countsArray, rawArray, rawLen);
finish_histo_8x(counts, countsArray);
}
static void histo_asm_avx512_16x(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(64)) countsArray[16*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_avx512_core_16x(countsArray, rawArray, rawLen);
memcpy(counts, countsArray, 256*sizeof(U32));
}
static void histo_asm_avx512_16x_var2(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(64)) countsArray[16*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_avx512_core_16x_var2(countsArray, rawArray, rawLen);
memcpy(counts, countsArray, 256*sizeof(U32));
}
static void histo_asm_avx512_16x_var3(U32 * counts, const U8 * rawArray, size_t rawLen)
{
U32 __declspec(align(64)) countsArray[16*256];
memset(countsArray,0,sizeof(countsArray));
histo_asm_avx512_core_16x_var3(countsArray, rawArray, rawLen);
memcpy(counts, countsArray, 256*sizeof(U32));
}
typedef void histo_func(U32 * histo, const U8 * bytes, size_t nbytes);
U32 g_histo[256]; // global sink
static void testit(const char *name, histo_func *fn, const U8 * bytes, size_t nbytes)
{
#ifndef NOVALIDATE
// check correctness
U32 ref_histo[256];
U32 our_histo[256];
histo_ref(ref_histo, bytes, nbytes);
fn(our_histo, bytes, nbytes);
if (memcmp(ref_histo, our_histo, sizeof(our_histo)) != 0)
{
printf("%30s: incorrect result!\n", name);
return;
}
#endif
// check timing
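// Take the minimum over many runs to filter out scheduler/interrupt
// noise. rdtsc counts at the fixed TSC reference rate (invariant TSC),
// so the per-byte figure is in reference cycles, not core clock cycles.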
U64 min_dur = ~(U64)0;
for (int run = 0; run < 7500; run++)
//for (int run = 0; run < 15000; run++)
//for (int run = 0; run < 50000; run++)
{
U64 start = cycle_timer();
fn(g_histo, bytes, nbytes);
U64 duration = cycle_timer() - start;
if (duration < min_dur)
min_dur = duration;
}
printf("%30s: %7lld (%.2f/byte)\n", name, min_dur, 1.0*min_dur/nbytes);
}
static void test_file(const char *label, const U8 *bytes, size_t size)
{
printf("%s: %d bytes\n", label, (int)size);
#define TESTIT(name) testit(#name,name,bytes,size)
#if 1
TESTIT(histo_ref);
TESTIT(histo_cpp_1x);
TESTIT(histo_cpp_2x);
TESTIT(histo_cpp_4x);
TESTIT(histo_asm_scalar4);
TESTIT(histo_asm_scalar8);
TESTIT(histo_asm_scalar8_var);
TESTIT(histo_asm_scalar8_var2);
TESTIT(histo_asm_scalar8_var3);
TESTIT(histo_asm_scalar8_var4);
TESTIT(histo_asm_scalar8_var5);
TESTIT(histo_asm_sse4);
TESTIT(histo_asm_avx256_8x_1);
TESTIT(histo_asm_avx256_8x_2);
TESTIT(histo_asm_avx256_8x_3);
TESTIT(histo_asm_avx512_conflict);
TESTIT(histo_asm_avx512_8x);
TESTIT(histo_asm_avx512_16x);
TESTIT(histo_asm_avx512_16x_var2);
TESTIT(histo_asm_avx512_16x_var3);
#endif
//TESTIT(histo_ref);
//TESTIT(histo_cpp_4x);
//TESTIT(histo_asm_scalar8_var2);
//TESTIT(histo_asm_scalar8_var3);
//TESTIT(histo_asm_scalar8_var4);
#undef TESTIT
}
int main(int argc, char **argv)
{
if (argc < 2)
{
fprintf(stderr, "Usage: histotest <filename>\n");
return 1;
}
const char *filename = argv[1];
size_t size;
U8 *bytes = read_file(filename, &size);
if (!bytes)
{
fprintf(stderr, "Error loading '%s'\n", filename);
return 1;
}
test_file(filename, bytes, size);
static U8 zeroes[128*1024];
test_file("zeroes", zeroes, 128*1024);
delete[] bytes;
return 0;
}