Skip to content

Instantly share code, notes, and snippets.

@clausecker
Created September 2, 2020 14:19
Show Gist options
  • Save clausecker/42f4cc547c2eafa2c6ce6a493b4e2a73 to your computer and use it in GitHub Desktop.
Save clausecker/42f4cc547c2eafa2c6ce6a493b4e2a73 to your computer and use it in GitHub Desktop.
summing decimal digits with different approaches
// unsigned getnumericvalue(const char *ptr)
//
// Sums the decimal digits of the NUL-terminated string at ptr and
// returns the digital root of that sum: 0 if the sum is 0, otherwise
// 1 + (sum - 1) % 9, i.e. a value in 1..9.
//
// ABI:   System V AMD64
// In:    rdi = ptr, NUL-terminated, assumed to hold only 0..9 digits
//        (the vector path uses |byte - 0x30|, so other bytes do not
//        contribute a simple byte - 0x30)
// Out:   eax = digital root
// Clobb: rcx, rdx, rdi, zmm0-zmm3, k0, flags
// Needs: AVX-512F and AVX-512BW
//
// Strategy: scalar head until ptr is 64-byte aligned, then whole
// aligned cache lines with vpsadbw while no NUL is present, then a
// scalar tail over the cache line that holds the terminator.  Aligned
// 64-byte loads never cross a page boundary, so reading past the
// terminator inside the final cache line cannot fault.
        .section .text
        .type   getnumericvalue, @function
        .globl  getnumericvalue
getnumericvalue:
        xor     %eax, %eax              // rax = running digit sum
        // process the string bytewise until ptr is 64-byte aligned
        test    $64-1, %dil             // is ptr aligned to 64 byte?
        jz      0f
1:      movzbl  (%rdi), %edx            // load a byte from the string
        inc     %rdi                    // advance pointer
        test    %edx, %edx              // is this the NUL byte?
        jz      .Lend                   // if yes, finish this function
        sub     $'0', %edx              // turn ASCII character into digit
        add     %edx, %eax              // and add to sum
        test    $64-1, %dil             // is ptr aligned to 64 byte?
        jnz     1b                      // if not, process more data

        // process whole cache lines until one contains the NUL byte
0:      vpbroadcastb zero(%rip), %zmm1  // zmm1 = 64 copies of the digit 0
        vpxor   %xmm3, %xmm3, %xmm3     // zmm3 = 8 qword partial sums
        vmovdqa32 (%rdi), %zmm0         // load one cache line from the string
        vptestmb %zmm0, %zmm0, %k0      // k0 bit i is clear iff byte i is NUL
        kortestq %k0, %k0               // CF = 1 iff no byte is NUL
        jnc     0f                      // skip loop if a NUL byte is found
        .align  16
1:      add     $64, %rdi               // advance pointer
        vpsadbw %zmm1, %zmm0, %zmm0     // per qword: sum of |byte - 0x30|
                                        // over its 8 bytes
        vpaddq  %zmm3, %zmm0, %zmm3     // accumulate into partial sums
        vmovdqa32 (%rdi), %zmm0         // load next cache line
        vptestmb %zmm0, %zmm0, %k0      // k0 bit i is clear iff byte i is NUL
        kortestq %k0, %k0               // CF = 1 iff no byte is NUL
        jc      1b                      // go on unless a NUL byte was found

        // reduce the 8 qword partial sums in zmm3 into rdx
0:      vextracti64x4 $1, %zmm3, %ymm2  // extract high 4 qwords
        vpaddq  %ymm2, %ymm3, %ymm3     // and add them to the low qwords
        vextracti128 $1, %ymm3, %xmm2   // extract high 2 qwords
        vpaddq  %xmm2, %xmm3, %xmm3     // and add them to the low qwords
        vpshufd $0x4e, %xmm3, %xmm2     // swap qwords into xmm2
        vpaddq  %xmm2, %xmm3, %xmm3     // and add to xmm3
        vmovq   %xmm3, %rdx             // move vector sum to rdx
        add     %rdx, %rax              // and add to sum from scalar head

        // process the rest of the final cache line bytewise
1:      movzbl  (%rdi), %edx            // load a byte from the string
        inc     %rdi                    // advance pointer
        test    %edx, %edx              // is this the NUL byte?
        jz      .Lend                   // if yes, finish this function
        sub     $'0', %edx              // turn ASCII character into digit
        add     %rdx, %rax              // and add to sum
        jmp     1b                      // only the NUL byte ends this loop

        // digital root: 0 for a zero sum, else 1 + (sum - 1) % 9
.Lend:  test    %rax, %rax              // is the sum zero?
        jz      2f                      // if yes, return 0 (eax is 0)
        xor     %edx, %edx              // zero-extend RAX into RDX:RAX
        mov     $9, %ecx                // divide by 9
        div     %rcx                    // rdx = sum % 9
        mov     %edx, %eax              // move remainder to result register
        test    %eax, %eax              // is the remainder zero?
        cmovz   %ecx, %eax              // nonzero multiple of 9: result is 9
2:      vzeroupper                      // avoid AVX/SSE transition penalties
        ret                             // and return
        .size   getnumericvalue, .-getnumericvalue

        // constants
        .section .rodata
zero:   .byte   '0'

        // this object does not need an executable stack
        .section .note.GNU-stack, "", @progbits
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <stdint.h>
extern unsigned getnumericvalue_simple(const char *in_str);
extern unsigned getnumericvalue_naive(const char *ptr);
extern unsigned getnumericvalue_parallel(const char *ptr);
extern unsigned getnumericvalue(const char *ptr);
/*
 * Time one implementation: call fun(p) 10000 times, then print the
 * implementation name, the digit count, the sum of all returned values
 * (printed so the results are used) and the mean time per call in msec.
 */
static void measure(const char *name, int digits, const char *p, unsigned(*fun)(const char*)) {
    const int reps = 10000;
    unsigned total = 0;
    clock_t t0 = clock();

    for (int k = 0; k < reps; ++k) {
        total += fun(p);
    }

    double elapsed_ms = (clock() - t0) * 1000.0 / CLOCKS_PER_SEC;
    printf("%-9s %d digits -> %u, %7.3f msec\n", name, digits, total, elapsed_ms / reps);
}
/*
 * Benchmark driver.  Builds a string of argv[1] digits (default 1000000)
 * from the repeating pattern "0123456789123456" and times every
 * implementation on it.  Returns EXIT_FAILURE for a malformed or negative
 * digit count, or when the buffer cannot be allocated.
 */
int main(int argc, char *argv[]) {
    int digits = 1000000;

    if (argc >= 2) {
        char *end;
        long requested = strtol(argv[1], &end, 0);

        /* reject empty or trailing-garbage arguments instead of using 0 */
        if (end == argv[1] || *end != '\0') {
            fprintf(stderr, "usage: %s [digits]\n", argv[0]);
            return EXIT_FAILURE;
        }
        digits = requested;
    }

    /* a negative count would make malloc's size and p[digits] invalid */
    if (digits < 0) {
        fprintf(stderr, "digit count must not be negative\n");
        return EXIT_FAILURE;
    }

    /* widen before adding 1 so digits == INT_MAX cannot overflow int */
    char *p = malloc((size_t)digits + 1);
    if (p == NULL) {
        perror("malloc");
        return EXIT_FAILURE;
    }

    for (int i = 0; i < digits; i++)
        p[i] = "0123456789123456"[i & 15];
    p[digits] = '\0';

    measure("simple", digits, p, getnumericvalue_simple);
    measure("naive", digits, p, getnumericvalue_naive);
    measure("parallel", digits, p, getnumericvalue_parallel);
    measure("simd", digits, p, getnumericvalue);

    free(p);
    return 0;
}
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
/*
 * Digital root of the digit sum of the NUL-terminated string in_str:
 * 0 when the sum is 0, otherwise 1 + (sum - 1) % 9.
 *
 * '9' digits are skipped because they do not change the digital root of
 * a nonzero sum.  If the skipped nines were the only nonzero digits, the
 * remaining sum is 0 but the true digital root is 9, so that case is
 * tracked separately.
 */
unsigned getnumericvalue_simple(const char *in_str) {
    unsigned long number = 0;
    int saw_nine = 0;
    const char *ptr = in_str;

    /* check for the terminator before reading, so "" is handled safely */
    while (*ptr != 0) {
        if (*ptr != '9')
            number += (*ptr - '0');
        else
            saw_nine = 1;   /* excluded from the sum, but not from the result */
        ptr++;
    }

    if (number == 0)
        return saw_nine ? 9 : 0;
    return number <= 9 ? number : ((number - 1) % 9) + 1;
}
/*
 * Reference implementation: add up every character of the NUL-terminated
 * string as (c - '0'), then reduce the total to its digital root
 * (0 for a zero total, otherwise 1 + (total - 1) % 9).
 */
unsigned getnumericvalue_naive(const char *ptr) {
    unsigned long total;

    for (total = 0; *ptr != '\0'; ++ptr)
        total += *ptr - '0';

    if (total == 0)
        return 0;
    return 1 + (total - 1) % 9;
}
/*
 * Digital root of the digit sum of the NUL-terminated string ptr,
 * processing 8 bytes per step in one 64-bit word (one byte lane per
 * character).  Returns 0 for a zero sum, otherwise 1 + (sum - 1) % 9.
 *
 * Assumes ptr holds only '0'..'9'.  A word counts as "all digits" when
 * every byte has both 0x30 bits set; that test also admits ':'..'?' and
 * bytes >= 0xF0, whose lane values could overflow a byte.
 *
 * The aligned 8-byte loads may read past the terminator, but only within
 * the aligned word that contains it.
 */
unsigned getnumericvalue_parallel(const char *ptr) {
    unsigned long long number = 0;
    unsigned long long pack1, pack2, pack3;

    /* align source on ull boundary */
    while ((uintptr_t)ptr & (sizeof(unsigned long long) - 1)) {
        if (*ptr == '\0')
            return number ? 1 + (number - 1) % 9 : 0;
        number += *ptr++ - '0';
    }

    /* scan 8 bytes at a time */
    pack3 = 0x3030303030303030;     /* '0' in every byte lane */
    for (;;) {
        pack1 = 0;                  /* 8 per-lane partial sums */
#define REP8(x) x;x;x;x;x;x;x;x
        /* 28 * 9 = 252 fits in a byte lane, so 28 words cannot overflow */
#define REP28(x) x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x
        /* memcpy is the well-defined way to read the chars as one word;
         * compilers turn it into a single load */
        REP28(memcpy(&pack2, ptr, sizeof pack2);
              if ((pack2 & pack3) != pack3)
                  break;
              ptr += sizeof(unsigned long long);
              pack1 += pack2 - pack3);
        REP8(number += pack1 & 0xFF; pack1 >>= 8);
    }
    /* fold the lanes accumulated before the break */
    REP8(number += pack1 & 0xFF; pack1 >>= 8);
#undef REP28
#undef REP8

    /* finish trailing bytes */
    while (*ptr) {
        number += *ptr++ - '0';
    }
    return number ? 1 + (number - 1) % 9 : 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment