Skip to content

Instantly share code, notes, and snippets.

@xen0n
Created July 21, 2023 12:57
Show Gist options
  • Save xen0n/09d333ed83b7716e3be50baad9be627a to your computer and use it in GitHub Desktop.
Save xen0n/09d333ed83b7716e3be50baad9be627a to your computer and use it in GitHub Desktop.
Sketching LoongArch SIMD acceleration for Linux XOR ops
// SPDX-License-Identifier: GPL-2.0-or-later
/*
$ gcc -O3 -o linux-xor-simd-test linux-xor-simd-test.c
$ ./linux-xor-simd-test
ref (size=4096 ) passed 16383 times: 0.005044150 s total, 0.000000307 s per pass, 12687.191 MiB/s
lsx_32b (size=4096 ) passed 16383 times: 0.002663250 s total, 0.000000162 s per pass, 24029.323 MiB/s
lsx_64b (size=4096 ) passed 16383 times: 0.002517970 s total, 0.000000153 s per pass, 25415.749 MiB/s
lsx_128b (size=4096 ) passed 16383 times: 0.002517590 s total, 0.000000153 s per pass, 25419.585 MiB/s
lasx_32b (size=4096 ) passed 16383 times: 0.001935550 s total, 0.000000118 s per pass, 33063.519 MiB/s
lasx_64b (size=4096 ) passed 16383 times: 0.001813990 s total, 0.000000110 s per pass, 35279.188 MiB/s
lasx_128b (size=4096 ) passed 16383 times: 0.001756910 s total, 0.000000107 s per pass, 36425.368 MiB/s
*/
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <sys/random.h>
// may not be widely available yet
// #include <lsxintrin.h>
// #define DATA_SIZE_MIN_ORDER 9 // 512
// #define DATA_SIZE_MAX_ORDER 20 // 1MiB
// same as crypto/xor.c
// Benchmarked buffer sizes are powers of two: run_order() uses 1 << order
// bytes for every order in [DATA_SIZE_MIN_ORDER, DATA_SIZE_MAX_ORDER].
#define DATA_SIZE_MIN_ORDER 12 // 4KiB
#define DATA_SIZE_MAX_ORDER 12 // 4KiB
// Number of timed passes per implementation. XOR-ing the same source into
// the destination an even number of times restores the destination, so the
// count has to be odd for the final buffer to match the single-pass
// reference result it is compared against.
#define TIMES 16383 // must be odd
// Signature shared by every benchmarked implementation:
// (destination, source, length in bytes); destination ^= source.
typedef void (*xor_impl_t)(void * __restrict, const void * __restrict, size_t);
// taken from linux include/asm-generic/xor.h
/*
 * Scalar reference implementation: p1[i] ^= p2[i], processed in "lines" of
 * eight unsigned longs.
 *
 * bytes: number of bytes to process. Only whole lines
 *        (8 * sizeof(long) bytes) are handled; a trailing partial line is
 *        ignored, and a length shorter than one line is a no-op.
 * p1:    destination buffer, updated in place.
 * p2:    source buffer, must not overlap p1.
 *
 * The original do/while form always executed one line, which touched
 * 8 words of both buffers even when bytes was 0 or smaller than a line.
 */
static void
xor_8regs_2(unsigned long bytes, unsigned long * __restrict p1,
	    const unsigned long * __restrict p2)
{
	unsigned long lines = bytes / (sizeof (long)) / 8;

	for (; lines > 0; lines--) {
		p1[0] ^= p2[0];
		p1[1] ^= p2[1];
		p1[2] ^= p2[2];
		p1[3] ^= p2[3];
		p1[4] ^= p2[4];
		p1[5] ^= p2[5];
		p1[6] ^= p2[6];
		p1[7] ^= p2[7];
		p1 += 8;
		p2 += 8;
	}
}
// Adapter exposing the scalar xor_8regs_2() through the xor_impl_t
// signature: takes (dst, src, len) and forwards them as (bytes, p1, p2).
// Also used by run_order() to compute the expected result.
static void reference_xor(void * __restrict a, const void * __restrict b, size_t len)
{
xor_8regs_2(len, a, b);
}
//
// LSX
//
/*
 * p1 ^= p2 using LSX vld/vxor.v/vst, 32 bytes (two 16-byte vectors) per
 * loop iteration.
 *
 * bytes: number of bytes to process; only whole 32-byte lines are handled,
 *        so a length below 32 (including 0) is a no-op instead of touching
 *        one full line as the original do/while did.
 * p1:    destination buffer, updated in place.
 * p2:    source buffer, must not overlap p1.
 */
static void
xor_lsx_32b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
{
	/* Byte pointers: arithmetic on void * is a GNU extension. */
	unsigned char *dst = p1;
	const unsigned char *src = p2;
	unsigned long lines = bytes / 32;

	for (; lines > 0; lines--) {
		asm volatile (
			"vld $vr0, %[dst], 0\n\t"
			"vld $vr1, %[dst], 16\n\t"
			"vld $vr2, %[src], 0\n\t"
			"vld $vr3, %[src], 16\n\t"
			"vxor.v $vr0, $vr0, $vr2\n\t"
			"vxor.v $vr1, $vr1, $vr3\n\t"
			"vst $vr0, %[dst], 0\n\t"
			"vst $vr1, %[dst], 16\n\t"
			: : [dst] "r"(dst), [src] "r"(src)
			/*
			 * $vr0-$vr3 are overwritten by the template. They share
			 * storage with the FP registers $f0-$f3, which is the
			 * name used here to tell the compiler about it.
			 */
			: "memory", "$f0", "$f1", "$f2", "$f3"
		);
		dst += 32;
		src += 32;
	}
}
// Adapter exposing xor_lsx_32b() through the xor_impl_t signature: takes
// (dst, src, len) and forwards them as (bytes, p1, p2).
static void lsx_32b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
xor_lsx_32b(len, a, b);
}
/*
 * p1 ^= p2 using LSX vld/vxor.v/vst, 64 bytes (four 16-byte vectors) per
 * loop iteration.
 *
 * bytes: number of bytes to process; only whole 64-byte lines are handled,
 *        so a length below 64 (including 0) is a no-op instead of touching
 *        one full line as the original do/while did.
 * p1:    destination buffer, updated in place.
 * p2:    source buffer, must not overlap p1.
 */
static void
xor_lsx_64b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
{
	/* Byte pointers: arithmetic on void * is a GNU extension. */
	unsigned char *dst = p1;
	const unsigned char *src = p2;
	unsigned long lines = bytes / 64;

	for (; lines > 0; lines--) {
		asm volatile (
			"vld $vr0, %[dst], 0\n\t"
			"vld $vr1, %[dst], 16\n\t"
			"vld $vr2, %[dst], 32\n\t"
			"vld $vr3, %[dst], 48\n\t"
			"vld $vr4, %[src], 0\n\t"
			"vld $vr5, %[src], 16\n\t"
			"vld $vr6, %[src], 32\n\t"
			"vld $vr7, %[src], 48\n\t"
			"vxor.v $vr0, $vr0, $vr4\n\t"
			"vxor.v $vr1, $vr1, $vr5\n\t"
			"vxor.v $vr2, $vr2, $vr6\n\t"
			"vxor.v $vr3, $vr3, $vr7\n\t"
			"vst $vr0, %[dst], 0\n\t"
			"vst $vr1, %[dst], 16\n\t"
			"vst $vr2, %[dst], 32\n\t"
			"vst $vr3, %[dst], 48\n\t"
			: : [dst] "r"(dst), [src] "r"(src)
			/*
			 * $vr0-$vr7 are overwritten by the template. They share
			 * storage with the FP registers $f0-$f7, which is the
			 * name used here to tell the compiler about it.
			 */
			: "memory", "$f0", "$f1", "$f2", "$f3",
			  "$f4", "$f5", "$f6", "$f7"
		);
		dst += 64;
		src += 64;
	}
}
// Adapter exposing xor_lsx_64b() through the xor_impl_t signature: takes
// (dst, src, len) and forwards them as (bytes, p1, p2).
static void lsx_64b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
xor_lsx_64b(len, a, b);
}
/*
 * p1 ^= p2 using LSX vld/vxor.v/vst, 128 bytes (eight 16-byte vectors) per
 * loop iteration.
 *
 * bytes: number of bytes to process; only whole 128-byte lines are handled,
 *        so a length below 128 (including 0) is a no-op instead of touching
 *        one full line as the original do/while did.
 * p1:    destination buffer, updated in place.
 * p2:    source buffer, must not overlap p1.
 */
static void
xor_lsx_128b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
{
	/* Byte pointers: arithmetic on void * is a GNU extension. */
	unsigned char *dst = p1;
	const unsigned char *src = p2;
	unsigned long lines = bytes / 128;

	for (; lines > 0; lines--) {
		asm volatile (
			"vld $vr0, %[dst], 0\n\t"
			"vld $vr1, %[dst], 16\n\t"
			"vld $vr2, %[dst], 32\n\t"
			"vld $vr3, %[dst], 48\n\t"
			"vld $vr4, %[dst], 64\n\t"
			"vld $vr5, %[dst], 80\n\t"
			"vld $vr6, %[dst], 96\n\t"
			"vld $vr7, %[dst], 112\n\t"
			"vld $vr8, %[src], 0\n\t"
			"vld $vr9, %[src], 16\n\t"
			"vld $vr10, %[src], 32\n\t"
			"vld $vr11, %[src], 48\n\t"
			"vld $vr12, %[src], 64\n\t"
			"vld $vr13, %[src], 80\n\t"
			"vld $vr14, %[src], 96\n\t"
			"vld $vr15, %[src], 112\n\t"
			"vxor.v $vr0, $vr0, $vr8\n\t"
			"vxor.v $vr1, $vr1, $vr9\n\t"
			"vxor.v $vr2, $vr2, $vr10\n\t"
			"vxor.v $vr3, $vr3, $vr11\n\t"
			"vxor.v $vr4, $vr4, $vr12\n\t"
			"vxor.v $vr5, $vr5, $vr13\n\t"
			"vxor.v $vr6, $vr6, $vr14\n\t"
			"vxor.v $vr7, $vr7, $vr15\n\t"
			"vst $vr0, %[dst], 0\n\t"
			"vst $vr1, %[dst], 16\n\t"
			"vst $vr2, %[dst], 32\n\t"
			"vst $vr3, %[dst], 48\n\t"
			"vst $vr4, %[dst], 64\n\t"
			"vst $vr5, %[dst], 80\n\t"
			"vst $vr6, %[dst], 96\n\t"
			"vst $vr7, %[dst], 112\n\t"
			: : [dst] "r"(dst), [src] "r"(src)
			/*
			 * $vr0-$vr15 are overwritten by the template. They share
			 * storage with the FP registers $f0-$f15, which is the
			 * name used here to tell the compiler about it.
			 */
			: "memory", "$f0", "$f1", "$f2", "$f3",
			  "$f4", "$f5", "$f6", "$f7",
			  "$f8", "$f9", "$f10", "$f11",
			  "$f12", "$f13", "$f14", "$f15"
		);
		dst += 128;
		src += 128;
	}
}
// Adapter exposing xor_lsx_128b() through the xor_impl_t signature: takes
// (dst, src, len) and forwards them as (bytes, p1, p2).
static void lsx_128b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
xor_lsx_128b(len, a, b);
}
//
// LASX
//
/*
 * p1 ^= p2 using LASX xvld/xvxor.v/xvst, 32 bytes (one 32-byte vector) per
 * loop iteration.
 *
 * bytes: number of bytes to process; only whole 32-byte lines are handled,
 *        so a length below 32 (including 0) is a no-op instead of touching
 *        one full line as the original do/while did.
 * p1:    destination buffer, updated in place.
 * p2:    source buffer, must not overlap p1.
 */
static void
xor_lasx_32b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
{
	/* Byte pointers: arithmetic on void * is a GNU extension. */
	unsigned char *dst = p1;
	const unsigned char *src = p2;
	unsigned long lines = bytes / 32;

	for (; lines > 0; lines--) {
		asm volatile (
			"xvld $xr0, %[dst], 0\n\t"
			"xvld $xr1, %[src], 0\n\t"
			"xvxor.v $xr0, $xr0, $xr1\n\t"
			"xvst $xr0, %[dst], 0\n\t"
			: : [dst] "r"(dst), [src] "r"(src)
			/*
			 * $xr0-$xr1 are overwritten by the template. They share
			 * storage with the FP registers $f0-$f1, which is the
			 * name used here to tell the compiler about it.
			 */
			: "memory", "$f0", "$f1"
		);
		dst += 32;
		src += 32;
	}
}
// Adapter exposing xor_lasx_32b() through the xor_impl_t signature: takes
// (dst, src, len) and forwards them as (bytes, p1, p2).
static void lasx_32b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
xor_lasx_32b(len, a, b);
}
/*
 * p1 ^= p2 using LASX xvld/xvxor.v/xvst, 64 bytes (two 32-byte vectors) per
 * loop iteration.
 *
 * bytes: number of bytes to process; only whole 64-byte lines are handled,
 *        so a length below 64 (including 0) is a no-op instead of touching
 *        one full line as the original do/while did.
 * p1:    destination buffer, updated in place.
 * p2:    source buffer, must not overlap p1.
 */
static void
xor_lasx_64b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
{
	/* Byte pointers: arithmetic on void * is a GNU extension. */
	unsigned char *dst = p1;
	const unsigned char *src = p2;
	unsigned long lines = bytes / 64;

	for (; lines > 0; lines--) {
		asm volatile (
			"xvld $xr0, %[dst], 0\n\t"
			"xvld $xr1, %[dst], 32\n\t"
			"xvld $xr2, %[src], 0\n\t"
			"xvld $xr3, %[src], 32\n\t"
			"xvxor.v $xr0, $xr0, $xr2\n\t"
			"xvxor.v $xr1, $xr1, $xr3\n\t"
			"xvst $xr0, %[dst], 0\n\t"
			"xvst $xr1, %[dst], 32\n\t"
			: : [dst] "r"(dst), [src] "r"(src)
			/*
			 * $xr0-$xr3 are overwritten by the template. They share
			 * storage with the FP registers $f0-$f3, which is the
			 * name used here to tell the compiler about it.
			 */
			: "memory", "$f0", "$f1", "$f2", "$f3"
		);
		dst += 64;
		src += 64;
	}
}
// Adapter exposing xor_lasx_64b() through the xor_impl_t signature: takes
// (dst, src, len) and forwards them as (bytes, p1, p2).
static void lasx_64b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
xor_lasx_64b(len, a, b);
}
/*
 * p1 ^= p2 using LASX xvld/xvxor.v/xvst, 128 bytes (four 32-byte vectors)
 * per loop iteration.
 *
 * bytes: number of bytes to process; only whole 128-byte lines are handled,
 *        so a length below 128 (including 0) is a no-op instead of touching
 *        one full line as the original do/while did.
 * p1:    destination buffer, updated in place.
 * p2:    source buffer, must not overlap p1.
 */
static void
xor_lasx_128b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
{
	/* Byte pointers: arithmetic on void * is a GNU extension. */
	unsigned char *dst = p1;
	const unsigned char *src = p2;
	unsigned long lines = bytes / 128;

	for (; lines > 0; lines--) {
		asm volatile (
			"xvld $xr0, %[dst], 0\n\t"
			"xvld $xr1, %[dst], 32\n\t"
			"xvld $xr2, %[dst], 64\n\t"
			"xvld $xr3, %[dst], 96\n\t"
			"xvld $xr4, %[src], 0\n\t"
			"xvld $xr5, %[src], 32\n\t"
			"xvld $xr6, %[src], 64\n\t"
			"xvld $xr7, %[src], 96\n\t"
			"xvxor.v $xr0, $xr0, $xr4\n\t"
			"xvxor.v $xr1, $xr1, $xr5\n\t"
			"xvxor.v $xr2, $xr2, $xr6\n\t"
			"xvxor.v $xr3, $xr3, $xr7\n\t"
			"xvst $xr0, %[dst], 0\n\t"
			"xvst $xr1, %[dst], 32\n\t"
			"xvst $xr2, %[dst], 64\n\t"
			"xvst $xr3, %[dst], 96\n\t"
			: : [dst] "r"(dst), [src] "r"(src)
			/*
			 * $xr0-$xr7 are overwritten by the template. They share
			 * storage with the FP registers $f0-$f7, which is the
			 * name used here to tell the compiler about it.
			 */
			: "memory", "$f0", "$f1", "$f2", "$f3",
			  "$f4", "$f5", "$f6", "$f7"
		);
		dst += 128;
		src += 128;
	}
}
// Adapter exposing xor_lasx_128b() through the xor_impl_t signature: takes
// (dst, src, len) and forwards them as (bytes, p1, p2).
static void lasx_128b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
xor_lasx_128b(len, a, b);
}
//
// helpers
//
/*
 * Fill buf[0..len) with bytes from getrandom(), aborting the process if
 * that is not possible.
 *
 * getrandom() may return fewer bytes than requested, so the call is
 * repeated until the whole buffer is filled. An interruption by a signal
 * (EINTR) is retried rather than treated as fatal; any other error aborts.
 * A zero len is a no-op.
 */
static void must_fill_randomness(void *buf, size_t len)
{
	/* Byte pointer: arithmetic on void * is a GNU extension. */
	unsigned char *p = buf;

	while (len) {
		ssize_t ret = getrandom(p, len, 0);

		if (ret < 0) {
			if (errno == EINTR)
				continue;
			abort();
		}
		p += ret;
		len -= (size_t)ret;
	}
}
/*
 * Return *time1 - *time0 as a timespec, borrowing one second when the
 * nanosecond difference comes out negative.
 */
static struct timespec diff_timespec(
		const struct timespec *time1,
		const struct timespec *time0)
{
	time_t sec = time1->tv_sec - time0->tv_sec;
	long nsec = time1->tv_nsec - time0->tv_nsec;

	if (nsec < 0) {
		/* borrow one second's worth of nanoseconds */
		nsec += 1000000000;
		sec--;
	}

	return (struct timespec){ .tv_sec = sec, .tv_nsec = nsec };
}
/*
 * Divide the duration x by denom (integer division on the total number of
 * nanoseconds) and return the quotient as a timespec.
 *
 * The whole duration is folded into a single long of nanoseconds, so x is
 * expected to be small enough for that not to overflow.
 */
static struct timespec div_timespec(struct timespec x, int denom)
{
	long total_ns = x.tv_sec * 1000000000 + x.tv_nsec;
	long share_ns = total_ns / denom;

	return (struct timespec){
		.tv_sec = share_ns / 1000000000,
		.tv_nsec = share_ns % 1000000000,
	};
}
/*
 * Return the throughput in bytes per second achieved by processing `size`
 * bytes `times` times within the duration `elapsed`.
 */
static double get_throughput(int size, struct timespec elapsed, int times)
{
	const double nsec_per_sec = 1e9;
	double elapsed_ns = (double)(elapsed.tv_sec * 1000000000l + (long)(elapsed.tv_nsec));
	double bytes_moved = (double)((long)size * (long)times);

	return bytes_moved / (elapsed_ns / nsec_per_sec);
}
/*
 * Benchmark and verify one implementation on buffers of 1 << order bytes.
 *
 * Two random buffers are generated; the expected result is one application
 * of reference_xor(). fn is then applied TIMES times under a thread CPU
 * clock, the final destination is compared with the expected result, and a
 * one-line report is printed.
 *
 * Returns 0 when the result matches the reference, 1 otherwise. Allocation
 * or clock failures abort the process.
 */
static int run_order(int order, const char *desc, xor_impl_t fn)
{
	struct timespec start, end;
	int size = 1 << order;
	int pass;

	void *a = malloc(size);
	if (a == NULL)
		abort();
	void *b = malloc(size);
	if (b == NULL)
		abort();
	void *ref = malloc(size);
	if (ref == NULL)
		abort();

	must_fill_randomness(a, size);
	must_fill_randomness(b, size);

	/* expected outcome: a single XOR of b into a copy of a */
	memcpy(ref, a, size);
	reference_xor(ref, b, size);

	/* timed section */
	if (clock_gettime(CLOCK_THREAD_CPUTIME_ID, &start) != 0)
		abort();
	for (pass = 0; pass < TIMES; pass++)
		fn(a, b, size);
	if (clock_gettime(CLOCK_THREAD_CPUTIME_ID, &end) != 0)
		abort();

	struct timespec elapsed = diff_timespec(&end, &start);
	struct timespec pass_time = div_timespec(elapsed, TIMES);
	int mismatch = memcmp(a, ref, size) != 0;
	double mib_per_sec = get_throughput(size, elapsed, TIMES) / 1048576.0;

	printf(
		"%-10s(size=%-7d) %s %d times: %ld.%09ld s total, %ld.%09ld s per pass, %.3lf MiB/s\n",
		desc,
		size,
		mismatch ? "failed" : "passed",
		TIMES,
		elapsed.tv_sec,
		elapsed.tv_nsec,
		pass_time.tv_sec,
		pass_time.tv_nsec,
		mib_per_sec
	);

	free(ref);
	free(b);
	free(a);
	return mismatch;
}
/*
 * Run run_order() for every order from DATA_SIZE_MIN_ORDER through
 * DATA_SIZE_MAX_ORDER inclusive and OR the results together, so the return
 * value is non-zero if any size failed.
 */
static int try_all_orders(const char *desc, xor_impl_t fn)
{
	int failed = 0;
	int order = DATA_SIZE_MIN_ORDER;

	while (order <= DATA_SIZE_MAX_ORDER) {
		failed |= run_order(order, desc, fn);
		order++;
	}

	return failed;
}
/*
 * Benchmark the scalar reference followed by each LSX and LASX variant.
 * The exit status is non-zero if any variant produced a wrong result.
 */
int main(int argc, const char *argv[])
{
	int failed = try_all_orders("ref", reference_xor);

	/* LSX variants */
	failed |= try_all_orders("lsx_32b", lsx_32b_glue);
	failed |= try_all_orders("lsx_64b", lsx_64b_glue);
	failed |= try_all_orders("lsx_128b", lsx_128b_glue);

	/* LASX variants */
	failed |= try_all_orders("lasx_32b", lasx_32b_glue);
	failed |= try_all_orders("lasx_64b", lasx_64b_glue);
	failed |= try_all_orders("lasx_128b", lasx_128b_glue);

	return failed;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment