Skip to content

Instantly share code, notes, and snippets.

@clausecker
Created August 4, 2020 18:18
Show Gist options
  • Save clausecker/9ad5fe5589106d19bb09776727c1a768 to your computer and use it in GitHub Desktop.
Save clausecker/9ad5fe5589106d19bb09776727c1a768 to your computer and use it in GitHub Desktop.
Vectorised positional popcount for Go
#include "textflag.h"
// func PospopcntMem(counts *[8]int32, buf []byte)
//
// PospopcntMem adds to counts[j], for each j in 0..7, the number of bytes in
// buf that have bit j set. The counters are updated directly in memory.
//
// Full 32-byte blocks are processed by the vector loop, which uses 256-bit
// vector instructions and POPCNT. No CPU feature check is made in this
// routine, so it must only be run on CPUs that support those instructions.
// Any remaining bytes are processed one at a time by the scalar loop.
//
// Register use: DI = counts, SI = read pointer into buf, CX = byte counter,
// AX = scratch, Y0 = current 32-byte block.
TEXT ·PospopcntMem(SB),NOSPLIT,$0-32
MOVQ counts+0(FP), DI // DI = counts
MOVQ buf_base+8(FP), SI // SI = &buf[0]
MOVQ buf_len+16(FP), CX // CX = len(buf)
SUBQ $32, CX // pre-subtract 32 bytes from CX
JL scalar // fewer than 32 bytes: skip the vector loop
// Vector loop: one 32-byte block per iteration. VPMOVMSKB gathers the top bit
// of every byte, so the first sample counts bit 7. Each VPADDD doubles the
// 32-bit lanes, which moves the next lower bit of every byte into the top
// position. Bits carried in from the neighbouring byte only fill the low
// positions and never reach the top bit within the seven doublings.
vector: VMOVDQU (SI), Y0 // load 32 bytes from buf
ADDQ $32, SI // advance SI past them
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, 4*7(DI) // add to counter for bit 7
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, 4*6(DI) // add to counter for bit 6
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, 4*5(DI) // add to counter for bit 5
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, 4*4(DI) // add to counter for bit 4
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, 4*3(DI) // add to counter for bit 3
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, 4*2(DI) // add to counter for bit 2
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, 4*1(DI) // add to counter for bit 1
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, 4*0(DI) // add to counter for bit 0
SUBQ $32, CX // account for the block just processed
JGE vector // repeat as long as a full 32-byte block is left
scalar: ADDQ $32, CX // undo last subtraction; CX = number of tail bytes
JE done // if CX=0, there's nothing left
// Scalar loop: BTL copies the selected bit of AX into the carry flag and
// ADCL $0 adds that carry to the matching counter.
loop: MOVBLZX (SI), AX // load a byte from buf
INCQ SI // advance past it
BTL $0, AX // is bit 0 set?
ADCL $0, 4*0(DI) // add it to the counters
BTL $1, AX // is bit 1 set?
ADCL $0, 4*1(DI) // add it to the counters
BTL $2, AX // is bit 2 set?
ADCL $0, 4*2(DI) // add it to the counters
BTL $3, AX // is bit 3 set?
ADCL $0, 4*3(DI) // add it to the counters
BTL $4, AX // is bit 4 set?
ADCL $0, 4*4(DI) // add it to the counters
BTL $5, AX // is bit 5 set?
ADCL $0, 4*5(DI) // add it to the counters
BTL $6, AX // is bit 6 set?
ADCL $0, 4*6(DI) // add it to the counters
BTL $7, AX // is bit 7 set?
ADCL $0, 4*7(DI) // add it to the counters
DECQ CX // mark this byte as done
JNE loop // and proceed if any bytes are left
done: VZEROUPPER // restore SSE-compatibility
RET
package pospopcnt
import "math/rand"
import "testing"
import "strconv"
// testSizes lists the buffer sizes, in bytes, that outerHarness benchmarks.
// Each size is allocated as a single []byte, so the largest entry needs
// roughly 1 GB of memory.
var testSizes = []int{ 10, 32, 1000, 2000, 4000, 10000, 100000, 10000000, 1000000000 }
// TestScalarReg runs PospopcntScalarReg through testHarness, which compares
// it with PospopcntReference.
func TestScalarReg(t *testing.T) {
testHarness(PospopcntScalarReg, t)
}
// TestScalarMem runs PospopcntScalarMem through testHarness, which compares
// it with PospopcntReference.
func TestScalarMem(t *testing.T) {
testHarness(PospopcntScalarMem, t)
}
// TestReg runs PospopcntReg through testHarness, which compares it with
// PospopcntReference.
func TestReg(t *testing.T) {
testHarness(PospopcntReg, t)
}
// TestMem runs PospopcntMem through testHarness, which compares it with
// PospopcntReference.
func TestMem(t *testing.T) {
testHarness(PospopcntMem, t)
}
// BenchmarkReference benchmarks PospopcntReference at every size in testSizes.
func BenchmarkReference(b *testing.B) {
outerHarness(PospopcntReference, b)
}
// BenchmarkScalarReg benchmarks PospopcntScalarReg at every size in testSizes.
func BenchmarkScalarReg(b *testing.B) {
outerHarness(PospopcntScalarReg, b)
}
// BenchmarkScalarMem benchmarks PospopcntScalarMem at every size in testSizes.
func BenchmarkScalarMem(b *testing.B) {
outerHarness(PospopcntScalarMem, b)
}
// BenchmarkReg benchmarks PospopcntReg at every size in testSizes.
func BenchmarkReg(b *testing.B) {
outerHarness(PospopcntReg, b)
}
// BenchmarkMem benchmarks PospopcntMem at every size in testSizes.
func BenchmarkMem(b *testing.B) {
outerHarness(PospopcntMem, b)
}
// testHarness checks that pospopcnt produces exactly the same counts as
// PospopcntReference.
//
// The comparison is repeated for several prefixes of one random buffer: the
// empty buffer, lengths just below, at and above 32 bytes (the block size the
// vectorised implementations consume per iteration), and 12345, an
// intentionally odd number that leaves a tail after the last full block.
// The counters start from non-zero values so that an implementation which
// overwrites counts instead of adding to them is caught as well.
func testHarness(pospopcnt func(*[8]int32, []byte), t *testing.T) {
	t.Helper()

	lengths := []int{0, 1, 31, 32, 33, 64, 12345}
	buf := make([]byte, 12345) // an intentionally odd number
	rand.Read(buf)

	initial := [8]int32{1234112, 12341234, 5635635, 423452345, 2345232, 32452345, 23452452, 2342542}

	for _, n := range lengths {
		refCounts := initial
		testCounts := initial
		PospopcntReference(&refCounts, buf[:n])
		pospopcnt(&testCounts, buf[:n])
		if refCounts != testCounts {
			t.Errorf("counts don't match for len(buf) = %d: got %v, want %v", n, testCounts, refCounts)
		}
	}
}
// outerHarness benchmarks pospopcnt once for every entry of testSizes. Each
// size runs as its own sub-benchmark, named after the size in bytes.
func outerHarness(pospopcnt func(*[8]int32, []byte), b *testing.B) {
	for _, size := range testSizes {
		name := strconv.Itoa(size)
		b.Run(name, func(sub *testing.B) {
			innerHarness(pospopcnt, sub, size)
		})
	}
}
// innerHarness benchmarks pospopcnt on a buffer of n random bytes.
//
// The buffer is allocated and filled before the timer is reset, so only the
// calls to pospopcnt are timed. b.SetBytes(n) records the bytes processed per
// iteration so that the testing package can report throughput. A non-positive
// n aborts the benchmark immediately.
func innerHarness(pospopcnt func(*[8]int32, []byte), b *testing.B, n int) {
	b.Helper()
	if n <= 0 {
		// Fatalf rather than Errorf: carrying on would panic in make for a
		// negative n and would benchmark nothing for n == 0.
		b.Fatalf("buffer size must be positive: %d", n)
	}

	b.SetBytes(int64(n))
	buf := make([]byte, n)
	rand.Read(buf)

	var counts [8]int32
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		pospopcnt(&counts, buf)
	}
}
package pospopcnt
// PospopcntReg is the vectorised positional popcount with counters in
// registers. The declaration has no body; the function is implemented in
// assembly.
func PospopcntReg(counts *[8]int32, buf []byte)
// PospopcntMem is the vectorised positional popcount with counters in
// memory. The declaration has no body; the function is implemented in
// assembly.
func PospopcntMem(counts *[8]int32, buf []byte)
// PospopcntScalarReg is the scalar positional popcount with counters in
// registers. The declaration has no body; the function is implemented in
// assembly.
func PospopcntScalarReg(counts *[8]int32, buf []byte)
// PospopcntScalarMem is the scalar positional popcount with counters in
// memory. The declaration has no body; the function is implemented in
// assembly.
func PospopcntScalarMem(counts *[8]int32, buf []byte)
// PospopcntReference is the portable reference implementation of positional
// popcount: for every byte of buf it adds the value of bit j of that byte
// (0 or 1) to counts[j], for j = 0 through 7.
func PospopcntReference(counts *[8]int32, buf []byte) {
	for _, b := range buf {
		for bit := range counts {
			counts[bit] += int32(b>>uint(bit)) & 1
		}
	}
}
#include "textflag.h"
// func PospopcntReg(counts *[8]int32, buf []byte)
//
// PospopcntReg adds to counts[j], for each j in 0..7, the number of bytes in
// buf that have bit j set. The eight counters are loaded into R8--R15 on
// entry, updated in registers, and written back to counts before returning.
//
// Full 32-byte blocks are processed by the vector loop, which uses 256-bit
// vector instructions and POPCNT. No CPU feature check is made in this
// routine, so it must only be run on CPUs that support those instructions.
// Any remaining bytes are processed one at a time by the scalar loop.
//
// Register use: DI = counts, SI = read pointer into buf, CX = byte counter,
// AX = scratch, Y0 = current 32-byte block, R8--R15 = counters for bits 0--7.
//
// NOTE(review): R14 and R15 have reserved roles in some Go toolchain
// configurations (goroutine pointer under the register-based ABI, GOT access
// in dynamic-link mode) -- confirm that clobbering them here is safe for the
// Go versions and build modes this code targets.
TEXT ·PospopcntReg(SB),NOSPLIT,$0-32
MOVQ counts+0(FP), DI // DI = counts
MOVQ buf_base+8(FP), SI // SI = &buf[0]
MOVQ buf_len+16(FP), CX // CX = len(buf)
// load counts into register R8--R15
MOVL 4*0(DI), R8
MOVL 4*1(DI), R9
MOVL 4*2(DI), R10
MOVL 4*3(DI), R11
MOVL 4*4(DI), R12
MOVL 4*5(DI), R13
MOVL 4*6(DI), R14
MOVL 4*7(DI), R15
SUBQ $32, CX // pre-subtract 32 bytes from CX
JL scalar // fewer than 32 bytes: skip the vector loop
// Vector loop: one 32-byte block per iteration. VPMOVMSKB gathers the top bit
// of every byte, so the first sample counts bit 7. Each VPADDD doubles the
// 32-bit lanes, which moves the next lower bit of every byte into the top
// position. Bits carried in from the neighbouring byte only fill the low
// positions and never reach the top bit within the seven doublings.
vector: VMOVDQU (SI), Y0 // load 32 bytes from buf
ADDQ $32, SI // advance SI past them
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R15 // add to counter for bit 7
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R14 // add to counter for bit 6
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R13 // add to counter for bit 5
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R12 // add to counter for bit 4
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R11 // add to counter for bit 3
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R10 // add to counter for bit 2
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R9 // add to counter for bit 1
VPADDD Y0, Y0, Y0 // shift Y0 left by one place
VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R8 // add to counter for bit 0
SUBQ $32, CX // account for the block just processed
JGE vector // repeat as long as a full 32-byte block is left
scalar: ADDQ $32, CX // undo last subtraction; CX = number of tail bytes
JE done // if CX=0, there's nothing left
// Scalar loop: BTL copies the selected bit of AX into the carry flag and
// ADCL $0 adds that carry to the matching counter register.
loop: MOVBLZX (SI), AX // load a byte from buf
INCQ SI // advance past it
BTL $0, AX // is bit 0 set?
ADCL $0, R8 // add it to R8
BTL $1, AX // is bit 1 set?
ADCL $0, R9 // add it to R9
BTL $2, AX // is bit 2 set?
ADCL $0, R10 // add it to R10
BTL $3, AX // is bit 3 set?
ADCL $0, R11 // add it to R11
BTL $4, AX // is bit 4 set?
ADCL $0, R12 // add it to R12
BTL $5, AX // is bit 5 set?
ADCL $0, R13 // add it to R13
BTL $6, AX // is bit 6 set?
ADCL $0, R14 // add it to R14
BTL $7, AX // is bit 7 set?
ADCL $0, R15 // add it to R15
DECQ CX // mark this byte as done
JNE loop // and proceed if any bytes are left
// write R8--R15 back to counts
done: MOVL R8, 4*0(DI)
MOVL R9, 4*1(DI)
MOVL R10, 4*2(DI)
MOVL R11, 4*3(DI)
MOVL R12, 4*4(DI)
MOVL R13, 4*5(DI)
MOVL R14, 4*6(DI)
MOVL R15, 4*7(DI)
VZEROUPPER // restore SSE-compatibility
RET
#include "textflag.h"
// func PospopcntScalarReg(counts *[8]int32, buf []byte)
//
// PospopcntScalarReg adds to counts[j], for each j in 0..7, the number of
// bytes in buf that have bit j set, processing one byte per loop iteration.
// The eight counters are loaded into R8--R15 on entry, updated in registers,
// and written back to counts before returning.
//
// Register use: DI = counts, SI = read pointer into buf, CX = bytes left,
// AX = current byte, R8--R15 = counters for bits 0--7.
//
// NOTE(review): R14 and R15 have reserved roles in some Go toolchain
// configurations (goroutine pointer under the register-based ABI, GOT access
// in dynamic-link mode) -- confirm that clobbering them here is safe for the
// Go versions and build modes this code targets.
TEXT ·PospopcntScalarReg(SB),NOSPLIT,$0-32
MOVQ counts+0(FP), DI // DI = counts
MOVQ buf_base+8(FP), SI // SI = &buf[0]
MOVQ buf_len+16(FP), CX // CX = len(buf)
// load counts into register R8--R15
MOVL 4*0(DI), R8
MOVL 4*1(DI), R9
MOVL 4*2(DI), R10
MOVL 4*3(DI), R11
MOVL 4*4(DI), R12
MOVL 4*5(DI), R13
MOVL 4*6(DI), R14
MOVL 4*7(DI), R15
TESTQ CX, CX // is the buffer empty?
JE done // if CX=0, there's nothing left
// BTL copies the selected bit of AX into the carry flag and ADCL $0 adds
// that carry to the matching counter register.
loop: MOVBLZX (SI), AX // load a byte from buf
INCQ SI // advance past it
BTL $0, AX // is bit 0 set?
ADCL $0, R8 // add it to R8
BTL $1, AX // is bit 1 set?
ADCL $0, R9 // add it to R9
BTL $2, AX // is bit 2 set?
ADCL $0, R10 // add it to R10
BTL $3, AX // is bit 3 set?
ADCL $0, R11 // add it to R11
BTL $4, AX // is bit 4 set?
ADCL $0, R12 // add it to R12
BTL $5, AX // is bit 5 set?
ADCL $0, R13 // add it to R13
BTL $6, AX // is bit 6 set?
ADCL $0, R14 // add it to R14
BTL $7, AX // is bit 7 set?
ADCL $0, R15 // add it to R15
DECQ CX // mark this byte as done
JNE loop // and proceed if any bytes are left
// write R8--R15 back to counts
done: MOVL R8, 4*0(DI)
MOVL R9, 4*1(DI)
MOVL R10, 4*2(DI)
MOVL R11, 4*3(DI)
MOVL R12, 4*4(DI)
MOVL R13, 4*5(DI)
MOVL R14, 4*6(DI)
MOVL R15, 4*7(DI)
RET
// func PospopcntScalarMem(counts *[8]int32, buf []byte)
//
// PospopcntScalarMem adds to counts[j], for each j in 0..7, the number of
// bytes in buf that have bit j set, processing one byte per loop iteration.
// The counters are updated directly in memory.
//
// Register use: DI = counts, SI = read pointer into buf, CX = bytes left,
// AX = current byte.
TEXT ·PospopcntScalarMem(SB),NOSPLIT,$0-32
MOVQ counts+0(FP), DI // DI = counts
MOVQ buf_base+8(FP), SI // SI = &buf[0]
MOVQ buf_len+16(FP), CX // CX = len(buf)
TESTQ CX, CX // is the buffer empty?
JE done // if CX=0, there's nothing left
// BTL copies the selected bit of AX into the carry flag and ADCL $0 adds
// that carry to the matching counter in memory.
loop: MOVBLZX (SI), AX // load a byte from buf
INCQ SI // advance past it
BTL $0, AX // is bit 0 set?
ADCL $0, 4*0(DI) // add it to the counters
BTL $1, AX // is bit 1 set?
ADCL $0, 4*1(DI) // add it to the counters
BTL $2, AX // is bit 2 set?
ADCL $0, 4*2(DI) // add it to the counters
BTL $3, AX // is bit 3 set?
ADCL $0, 4*3(DI) // add it to the counters
BTL $4, AX // is bit 4 set?
ADCL $0, 4*4(DI) // add it to the counters
BTL $5, AX // is bit 5 set?
ADCL $0, 4*5(DI) // add it to the counters
BTL $6, AX // is bit 6 set?
ADCL $0, 4*6(DI) // add it to the counters
BTL $7, AX // is bit 7 set?
ADCL $0, 4*7(DI) // add it to the counters
DECQ CX // mark this byte as done
JNE loop // and proceed if any bytes are left
done: RET
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment