// Gist by @Const-me (last active July 7, 2023):
// https://gist.github.com/Const-me/3ade77faad47f0fbb0538965ae7f8e04
// Computes (number of 's' bytes) - (number of 'p' bytes) in a '\0'-terminated string,
// with an AVX2 implementation and a scalar reference implementation.
// Written for MSVC (<intrin.h>, __popcnt); needs a CPU with AVX2, BMI1, BMI2 and POPCNT.
#include <stdint.h>
#include <immintrin.h>
#include <intrin.h>
#include <stdio.h>
// Returns popcount( plus ) - popcount( minus ).
// Each popcount of a 32-bit value is in [ 0 .. 32 ], so the result lies in [ -32 .. +32 ].
inline int popCntDiff( uint32_t plus, uint32_t minus )
{
	const int plusBits = (int)__popcnt( plus );
	const int minusBits = (int)__popcnt( minus );
	return plusBits - minusBits;
}
// Horizontal sum of all 4 int64_t elements in the AVX2 vector
inline int64_t hadd_epi64( __m256i v32 )
{
__m128i v = _mm256_extracti128_si256( v32, 1 );
v = _mm_add_epi64( v, _mm256_castsi256_si128( v32 ) );
const int64_t high = _mm_extract_epi64( v, 1 );
const int64_t low = _mm_cvtsi128_si64( v );
return high + low;
}
// Counts set bits of `plus` minus set bits of `minus`, considering only the bit positions
// below the lowest set bit of `zeros` (i.e. the bytes before the first '\0' of a block).
// Callers pass a non-zero `zeros`.
static inline int diffBeforeTerminator( uint32_t plus, uint32_t minus, uint32_t zeros )
{
	const uint32_t keep = _tzcnt_u32( zeros );
	return popCntDiff( _bzhi_u32( plus, keep ), _bzhi_u32( minus, keep ) );
}

// AVX2 version of the scalar `computeWithSwitches`: returns (number of 's' bytes) minus
// (number of 'p' bytes) in the '\0'-terminated string `input`.
// Requires AVX2, BMI1, BMI2 and POPCNT.
//
// The string is processed in aligned 32-byte blocks. That permits aligned loads, and more
// importantly an aligned 32-byte load never crosses a page boundary. VMEM permissions are
// defined per aligned 4kb page, so these loads do not fault even though they may read up to
// 31 bytes before `input` and up to 31 bytes after the terminator. Those bytes lie outside the
// string object, which the language standard treats as UB.
int computeWithAvx2( const char* input )
{
	const __m256i zero = _mm256_setzero_si256();
	const __m256i plusByte = _mm256_set1_epi8( 's' );
	const __m256i minusByte = _mm256_set1_epi8( 'p' );

	// Four int64 partial sums of 0xFF * ( 's' count - 'p' count ).
	// The scale comes from _mm256_cmpeq_epi8 producing 0xFF per match,
	// which _mm256_sad_epu8 then sums per 64-bit lane.
	__m256i counter = zero;

	const size_t misalignment = ( (size_t)input ) & 31;
	if( misalignment != 0 )
	{
		// Load the whole aligned block containing the first byte.
		// It is on the same VMEM page, so there are no access violations.
		input -= misalignment;
		const __m256i block = _mm256_load_si256( ( const __m256i* )input );
		const uint32_t shift = (uint32_t)misalignment;
		// Shift out the bits that belong to bytes before the start of the string.
		const uint32_t zeros = (uint32_t)_mm256_movemask_epi8( _mm256_cmpeq_epi8( block, zero ) ) >> shift;
		const uint32_t pluses = (uint32_t)_mm256_movemask_epi8( _mm256_cmpeq_epi8( block, plusByte ) ) >> shift;
		const uint32_t minuses = (uint32_t)_mm256_movemask_epi8( _mm256_cmpeq_epi8( block, minusByte ) ) >> shift;

		// Tiny input: the terminator is already in this first block.
		if( zeros != 0 )
			return diffBeforeTerminator( pluses, minuses, zeros );

		// Seed the low 64-bit lane with this block's contribution, using the same 0xFF scale as the loop.
		const __m128i seed = _mm_cvtsi64_si128( popCntDiff( pluses, minuses ) * 0xFF );
		counter = _mm256_blend_epi32( zero, _mm256_castsi128_si256( seed ), 0b00000011 );
		input += 32;
	}

	// From here on `input` is 32-byte aligned.
	for( ;; input += 32 )
	{
		const __m256i block = _mm256_load_si256( ( const __m256i* )input );
		const __m256i eqPlus = _mm256_cmpeq_epi8( block, plusByte );
		const __m256i eqMinus = _mm256_cmpeq_epi8( block, minusByte );
		const uint32_t zeros = (uint32_t)_mm256_movemask_epi8( _mm256_cmpeq_epi8( block, zero ) );

		if( zeros != 0 )
		{
			// The terminator is in this block. The counter holds an exact multiple of 0xFF.
			const int total = (int)( hadd_epi64( counter ) / 0xFF );
			const uint32_t pluses = (uint32_t)_mm256_movemask_epi8( eqPlus );
			const uint32_t minuses = (uint32_t)_mm256_movemask_epi8( eqMinus );
			return total + diffBeforeTerminator( pluses, minuses, zeros );
		}

		// No terminator, so all 32 bytes count. Each lane gains 0xFF * ( pluses - minuses ) for its 8 bytes.
		const __m256i delta = _mm256_sub_epi64( _mm256_sad_epu8( eqPlus, zero ), _mm256_sad_epu8( eqMinus, zero ) );
		counter = _mm256_add_epi64( counter, delta );
	}
}
#include <vector>
#include <random>
// Builds a vector of `length` chars drawn uniformly from { 's', 'p', '0', '1' }.
// The last element is overwritten with the '\0' terminator.
// Returns an empty vector when `length` is 0, since there is no room for a terminator.
// The RNG is deliberately seeded with 0, so the content is reproducible for a given standard
// library. std::uniform_int_distribution's algorithm is implementation-defined, so different
// standard libraries may still generate different strings.
std::vector<char> nullTerminatedRandom( size_t length )
{
	std::vector<char> result( length );
	if( length == 0 )
		return result; // writing result[ length - 1 ] would be out of bounds

	std::mt19937 gen( 0 );
	const char pattern[ 4 ]{ 's', 'p', '0', '1' };
	// uniform_int_distribution bounds are inclusive, so the upper bound is the last valid index.
	std::uniform_int_distribution<size_t> distrib( 0, sizeof( pattern ) - 1 );
	for( char& c : result )
		c = pattern[ distrib( gen ) ];

	// Write terminating `\0` into the last element
	result[ length - 1 ] = '\0';
	return result;
}
// Scalar reference implementation: walks the '\0'-terminated string byte by byte.
// It adds 1 for every 's', subtracts 1 for every 'p' and ignores all other bytes.
int computeWithSwitches( const char* input )
{
	int balance = 0;
	for( const char* cursor = input; *cursor != '\0'; ++cursor )
	{
		if( *cursor == 's' )
			++balance;
		else if( *cursor == 'p' )
			--balance;
	}
	return balance;
}
// Benchmark driver. Builds a ~1 GiB '\0'-terminated test string and scores it with either the
// AVX2 or the scalar implementation. Prints the score and the elapsed time in TSC ticks.
int main()
{
	// Set to `false` to time the scalar reference implementation instead.
	constexpr bool useAvx = true;
	// Using odd length slightly over 1GB, just for lulz
	const size_t len = 1024 * 1024 * 1024 + 17;
	const auto data = nullTerminatedRandom( len );
	const char* const rsi = data.data();
	// Compute the result using either of these two methods, measuring the time in TSC ticks
	const int64_t tscStart = __rdtsc();
	const int sum = useAvx ? computeWithAvx2( rsi ) : computeWithSwitches( rsi );
	const int64_t tscElapsed = __rdtsc() - tscStart;
	// `%lli` expects `long long`. int64_t is `long` on LP64 platforms, so cast explicitly.
	printf( "%i; elapsed time: %lli\n", sum, (long long)tscElapsed );
	return 0;
}
// (end of gist)