Skip to content

Instantly share code, notes, and snippets.

@Const-me
Created March 1, 2021 06:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Const-me/14da47903393acd2c3fb92c0b2eb090a to your computer and use it in GitHub Desktop.
Save Const-me/14da47903393acd2c3fb92c0b2eb090a to your computer and use it in GitHub Desktop.
// Counts the byte positions i in [0, count) where the absolute difference
// |lhs[i] - rhs[i]| is at least 16.
//
// lhs, rhs: input buffers, each at least `count` bytes long.
// count:    number of bytes to compare; any value is accepted, including 0.
// Returns the number of positions whose difference reaches the threshold.
//
// Whole 16-byte blocks are processed with NEON; the remaining 0-15 bytes are
// handled by a scalar loop so the vector loads never read past `count`.
int NeonTest( const uint8_t* lhs, const uint8_t* rhs, size_t count )
{
	const uint8_t* const lhsEnd = lhs + count;
	// End of the part that can be consumed in full 16-byte vectors
	const uint8_t* const lhsBlocksEnd = lhs + ( count & ~(size_t)15 );

	int32x4_t acc = vdupq_n_s32( 0 );
	// The threshold is a power of 2, so `v >= 16` is the same as
	// "any of the 4 high bits is set", which is a single bit-test instruction.
	const uint8x16_t thresholdBitMask = vdupq_n_u8( 0xF0 );

	while( lhs < lhsBlocksEnd )
	{
		const uint8x16_t a = vld1q_u8( lhs );
		const uint8x16_t b = vld1q_u8( rhs );
		lhs += 16;
		rhs += 16;

		// Unsigned absolute difference, per byte
		const uint8x16_t absDiff = vabdq_u8( a, b );
		// Compare with the threshold: each lane becomes 0xFF (true) or 0 (false)
		const uint8x16_t aboveThresholdU = vtstq_u8( absDiff, thresholdBitMask );
		// Reinterpreted as signed, 0xFF is -1, so summing lanes yields minus the count
		const int8x16_t aboveThresholdS = vreinterpretq_s8_u8( aboveThresholdU );
		// Long pairwise add: widen int8 to int16 and add adjacent pairs
		const int16x8_t sum16 = vpaddlq_s8( aboveThresholdS );
		// Long pairwise add + accumulate: widen int16 to int32, add adjacent pairs into acc
		acc = vpadalq_s16( acc, sum16 );
	}

	// Horizontal sum of the 4 lanes in acc
	int32x2_t res = vadd_s32( vget_low_s32( acc ), vget_high_s32( acc ) );
	res = vpadd_s32( res, res );
	// The lanes accumulated -1 per match, so negate to get the positive count
	int result = -vget_lane_s32( res, 0 );

	// Scalar tail for the last ( count % 16 ) bytes
	for( ; lhs < lhsEnd; lhs++, rhs++ )
	{
		const int diff = (int)*lhs - (int)*rhs;
		const int absDiff = ( diff < 0 ) ? -diff : diff;
		if( absDiff >= 16 )
			result++;
	}
	return result;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment