Skip to content

Instantly share code, notes, and snippets.

@nmoinvaz
Created April 18, 2022 04:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nmoinvaz/e59da03f7f86dcea7f7a71417fc1be56 to your computer and use it in GitHub Desktop.
Save nmoinvaz/e59da03f7f86dcea7f7a71417fc1be56 to your computer and use it in GitHub Desktop.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <arm_neon.h>
#include <byteswap.h>
#include <stdio.h>
#include <stdint.h>
#include <arm_neon.h>
/*
 * Debug helper: print "<name>: " followed by the first `n` bytes at `a` as
 * two-digit lowercase hex, with a space after every 4th byte, then a newline.
 * When n <= 0 only the label and the newline are printed.
 * The buffer is only read, so both pointers are const-qualified.
 */
void print_uint8x16(const char *name, const void *a, int n) {
    const uint8_t *bytes = a;

    printf("%s: ", name);
    for (int i = 0; i < n; i++) {
        printf("%02x", (unsigned)bytes[i]);
        /* Group the output into 32-bit words for readability. */
        if ((i + 1) % 4 == 0)
            printf(" ");
    }
    printf("\n");
}
/*
 * NEON emulation of SSE2 _mm_movemask_epi8: collapse a 16-byte vector into a
 * 16-bit mask where bit i comes from byte i.
 * Adapted from https://stackoverflow.com/questions/11870910/
 *
 * Each byte is ANDed with a power of two (1..128, repeated for the upper
 * eight bytes). For comparison results, where every byte is 0x00 or 0xFF,
 * bit i of the result is set iff byte i is 0xFF. For other inputs, bit i
 * reflects bit (i % 8) of byte i.
 *
 * Returns a value in [0, 0xFFFF].
 */
static inline uint32_t vmovmaskq_u8(uint8x16_t input) {
    static const uint8_t powers_tbl[16] __attribute__((aligned(16))) = {
        1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128
    };
    uint8x16_t powers = vld1q_u8(powers_tbl);

    /* Widening pairwise add, then two pairwise adds: afterwards lane 0 holds
     * the weighted sum of bytes 0-7 and lane 1 the weighted sum of bytes 8-15.
     * Each sum is an independent 8-bit mask. */
    uint16x8_t sums = vpaddlq_u8(vandq_u8(input, powers));
    sums = vpaddq_u16(sums, sums);
    sums = vpaddq_u16(sums, sums);

    /* The high half must be shifted into bits 8-15. Adding the halves would
     * merge the two masks, e.g. all-set would give 510 instead of 0xFFFF. */
    uint32_t lo = vgetq_lane_u16(sums, 0);
    uint32_t hi = vgetq_lane_u16(sums, 1);
    return lo | (hi << 8);
}
/*
 * Return the number of leading bytes that are identical in src0 and src1,
 * examining at most 256 bytes. The result is in [0, 256].
 *
 * Data is processed in 16-byte chunks and each chunk is loaded in full, so
 * callers should make 256 bytes of each buffer readable.
 *
 * Each chunk is XORed, so any non-zero byte marks a mismatch. The XOR is read
 * back as two 64-bit lanes. In the first non-zero lane, the index of the
 * first differing byte is the trailing-zero count divided by 8 (lowest
 * address = least significant byte).
 * NOTE(review): this relies on little-endian lane layout (the usual AArch64
 * configuration).
 */
uint32_t compare256_neon_static(const uint8_t *src0, const uint8_t *src1) {
    uint32_t len = 0;
    do {
        uint8x16_t a = vld1q_u8(src0);
        uint8x16_t b = vld1q_u8(src1);
        uint64x2_t diff = vreinterpretq_u64_u8(veorq_u8(a, b));

        uint64_t lane = vgetq_lane_u64(diff, 0);
        if (lane != 0)
            return len + (uint32_t)__builtin_ctzll(lane) / 8;
        len += 8;

        lane = vgetq_lane_u64(diff, 1);
        if (lane != 0)
            return len + (uint32_t)__builtin_ctzll(lane) / 8;
        len += 8;

        src0 += 16;
        src1 += 16;
    } while (len < 256);
    return 256;
}
/*
 * Debug/exploration variant of a 256-byte compare. Returns the number of
 * leading identical bytes in src0 and src1 (at most 256). For every 16-byte
 * chunk it examines, it prints the intermediate NEON vectors and per-lane
 * results to stdout via print_uint8x16/printf.
 *
 * Approach:
 *   1. Invert the byte-equality mask, so mismatching bytes become 0xFF.
 *   2. Reverse the bytes inside each 32-bit lane.
 *   3. Count leading zeros per 32-bit lane.
 * A count of 32 means all four bytes of that lane matched. Otherwise the
 * count divided by 8 is the index of the first mismatch within the lane.
 * NOTE(review): the byte reversal assumes little-endian lane layout.
 *
 * Up to 256 bytes of each buffer may be read, 16 at a time.
 */
uint32_t compare256(const uint8_t *src0, const uint8_t *src1) {
    uint32_t len = 0;
    do {
        uint8x16_t a = vld1q_u8(src0);
        uint8x16_t b = vld1q_u8(src1);
        uint8x16_t cmp = vceqq_u8(a, b);
        print_uint8x16("vceqq_u8", &cmp, 16);
        cmp = vmvnq_u8(cmp);
        cmp = vrev32q_u8(cmp);
        /* Despite the label, this dump is taken after vrev32q_u8 too. */
        print_uint8x16("vmvnq_u8", &cmp, 16);

        /* cnt1 and cnt3 are only printed, for comparison with the 32-bit
         * counts that actually drive the result. */
        uint8x16_t cnt1 = vclzq_u8(cmp);
        print_uint8x16("vclzq_u8", &cnt1, 16);
        uint16x8_t cnt3 = vclzq_u16(vreinterpretq_u16_u8(cmp));
        print_uint8x16("vclzq_u16", &cnt3, 16);
        uint32x4_t cnt4 = vclzq_u32(vreinterpretq_u32_u8(cmp));
        print_uint8x16("vclzq_u32", &cnt4, 16);

        /* vgetq_lane_u32 needs a constant lane index, hence the unrolling. */
        uint32_t idx0 = vgetq_lane_u32(cnt4, 0);
        print_uint8x16("idx0", &idx0, 4);
        printf("idx0: %u\n", (unsigned)idx0);
        if (idx0 != 32)
            return len + idx0 / 8;
        len += 4;

        uint32_t idx1 = vgetq_lane_u32(cnt4, 1);
        print_uint8x16("idx1", &idx1, 4);
        printf("idx1: %u\n", (unsigned)idx1);
        if (idx1 != 32)
            return len + idx1 / 8;
        len += 4;

        uint32_t idx2 = vgetq_lane_u32(cnt4, 2);
        print_uint8x16("idx2", &idx2, 4);
        printf("idx2: %u\n", (unsigned)idx2);
        if (idx2 != 32)
            return len + idx2 / 8;
        len += 4;

        uint32_t idx3 = vgetq_lane_u32(cnt4, 3);
        print_uint8x16("idx3", &idx3, 4);
        printf("idx3: %u\n", (unsigned)idx3);
        if (idx3 != 32)
            return len + idx3 / 8;
        len += 4;

        /* len already advanced by 16 through the four lane steps above. */
        src0 += 16;
        src1 += 16;
    } while (len < 256);
    return len;
}
/*
 * Self-test for compare256_neon_static. For each position i in [0, 256),
 * a single mismatching byte is planted at str2[i] and the reported match
 * length is expected to be i. For i == 256 the buffers are identical and the
 * expected result is 256. The loop stops at the first failure.
 * Exits with EXIT_FAILURE if any check failed, so scripts can detect it.
 */
int main(void) {
    uint8_t str1[256];
    uint8_t str2[256];
    int failed = 0;

    memset(str1, 'a', sizeof(str1));
    memset(str2, 'a', sizeof(str2));

    for (size_t i = 0; i <= sizeof(str1); i++) {
        if (i < sizeof(str1))
            str2[i] = 0;
        uint32_t match_len = compare256_neon_static(str1, str2);
        /* Restore before reporting, so str2 is clean even after a break. */
        if (i < sizeof(str1))
            str2[i] = 'a';

        if (match_len == i) {
            printf("ok %zu == %u\n", i, (unsigned)match_len);
        } else {
            failed++;
            printf("failed %zu != %u\n", i, (unsigned)match_len);
            break;
        }
    }
    printf("failed count %d\n", failed);
    return failed ? EXIT_FAILURE : EXIT_SUCCESS;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment