// ==== AVX2 decompressor for Q4_0 and Q4_1 compressed blocks ====
#include <array>
#include <stdint.h>
#include <stddef.h>
#include <immintrin.h>
#include <assert.h>
#include <float.h>
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
inline __m256i bytesFromNibbles( const uint8_t* rsi )
{
	// Load 16 bytes from memory
	__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
	// Expand bytes into uint16_t values
	__m256i bytes = _mm256_cvtepu8_epi16( tmp );
	// Unpack values into individual bytes
	const __m256i lowMask = _mm256_set1_epi8( 0xF );
	__m256i high = _mm256_andnot_si256( lowMask, bytes );
	__m256i low = _mm256_and_si256( lowMask, bytes );
	high = _mm256_slli_epi16( high, 4 );
	bytes = _mm256_or_si256( low, high );
	return bytes;
}
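// For reference, a scalar sketch of the same unpacking, documenting the output order;
// bytesFromNibblesScalar is a hypothetical helper added for illustration, not used below.
// The low nibble of source byte i lands in output byte 2*i, the high nibble in output byte 2*i+1
inline void bytesFromNibblesScalar( uint8_t* rdi, const uint8_t* rsi )
{
	for( size_t i = 0; i < 16; i++ )
	{
		rdi[ i * 2 ] = rsi[ i ] & 0xF;
		rdi[ i * 2 + 1 ] = rsi[ i ] >> 4;
	}
}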
// Convert the lower 8 bytes of the vector from int8_t into float lanes
inline __m256 makeFloats( __m128i bytes )
{
	__m256i i32 = _mm256_cvtepi8_epi32( bytes );
	return _mm256_cvtepi32_ps( i32 );
}
// Decompress a Q4_0 compressed block; the block size is 32
// The block payload contains 1 scale value (the first argument), and 32 4-bit values packed into 16 bytes (the second argument)
std::array<__m256, 4> decompressBlock40( const float* scaling, const uint8_t* rsi )
{
	// Unpack 4-bit fields into bytes
	__m256i bytes = bytesFromNibbles( rsi );
	// Now we have a vector with bytes in the [ 0 .. 15 ] interval; offset them into [ -8 .. +7 ]
	const __m256i off = _mm256_set1_epi8( 8 );
	bytes = _mm256_sub_epi8( bytes, off );
	// Broadcast the scale into an AVX vector
	const __m256 sv = _mm256_broadcast_ss( scaling );
	// Produce the result
	std::array<__m256, 4> arr;
	__m128i tmp = _mm256_castsi256_si128( bytes );
	arr[ 0 ] = _mm256_mul_ps( sv, makeFloats( tmp ) );
	tmp = _mm_srli_si128( tmp, 8 );
	arr[ 1 ] = _mm256_mul_ps( sv, makeFloats( tmp ) );
	tmp = _mm256_extracti128_si256( bytes, 1 );
	arr[ 2 ] = _mm256_mul_ps( sv, makeFloats( tmp ) );
	tmp = _mm_srli_si128( tmp, 8 );
	arr[ 3 ] = _mm256_mul_ps( sv, makeFloats( tmp ) );
	return arr;
}
// Decompress a Q4_1 compressed block; the block size is 32
// The block payload contains the min value, the scale value, and 32 4-bit values packed into 16 bytes
std::array<__m256, 4> decompressBlock41( const float* minValue, const float* scaling, const uint8_t* rsi )
{
	// Unpack 4-bit fields into bytes
	const __m256i bytes = bytesFromNibbles( rsi );
	// Broadcast both floats into AVX vectors
	const __m256 iv = _mm256_broadcast_ss( minValue );
	const __m256 sv = _mm256_broadcast_ss( scaling );
	// Produce the result
	std::array<__m256, 4> arr;
	__m128i tmp = _mm256_castsi256_si128( bytes );
	arr[ 0 ] = _mm256_fmadd_ps( sv, makeFloats( tmp ), iv );
	tmp = _mm_srli_si128( tmp, 8 );
	arr[ 1 ] = _mm256_fmadd_ps( sv, makeFloats( tmp ), iv );
	tmp = _mm256_extracti128_si256( bytes, 1 );
	arr[ 2 ] = _mm256_fmadd_ps( sv, makeFloats( tmp ), iv );
	tmp = _mm_srli_si128( tmp, 8 );
	arr[ 3 ] = _mm256_fmadd_ps( sv, makeFloats( tmp ), iv );
	return arr;
}
// Compute dot product of two vectors, both compressed into a sequence of Q4_0 blocks
float dotProductCompressed40( size_t len, const uint8_t* x, const uint8_t* y )
{
	assert( 0 == ( len % 32 ) );
	const size_t countBlocks = len / 32;
	// Prepare the source pointers
	const float* scalesX = (const float*)x;
	const float* scalesY = (const float*)y;
	const float* const sxEnd = scalesX + countBlocks;
	const uint8_t* bytesX = (const uint8_t*)( scalesX + countBlocks );
	const uint8_t* bytesY = (const uint8_t*)( scalesY + countBlocks );
	// Initialize accumulator with zeros
	__m256 acc = _mm256_setzero_ps();
	// Main loop
	while( scalesX < sxEnd )
	{
		// Compute combined scale for the block
		const __m256 scale = _mm256_mul_ps( _mm256_broadcast_ss( scalesX ), _mm256_broadcast_ss( scalesY ) );
		scalesX++;
		scalesY++;
		// Load 16 bytes, and unpack 4-bit fields into bytes, making 32 bytes
		__m256i bx = bytesFromNibbles( bytesX );
		__m256i by = bytesFromNibbles( bytesY );
		bytesX += 16;
		bytesY += 16;
		// Now we have vectors with bytes in the [ 0 .. 15 ] interval, and we need sum( (a-8)*(b-8) )
		// The value we're after is equal to sum( a*(b-8) - 8*(b-8) )
		const __m256i off = _mm256_set1_epi8( 8 );
		by = _mm256_sub_epi8( by, off );
		// These weird multiplication instructions compute a0*b0 + a1*b1 for uint8_t a, int8_t b
		__m256i p1 = _mm256_maddubs_epi16( bx, by );
		__m256i p2 = _mm256_maddubs_epi16( off, by );
		__m256i p16 = _mm256_sub_epi16( p1, p2 );
		// We have products of signed bytes, reduced pairwise to int16_t
		// Reduce pairs further to int32_t
		// The following preprocessor branches implement two equivalent methods of doing so
		// Which way is faster probably depends on the CPU
#if 0
		__m256i i32 = _mm256_slli_epi32( p16, 16 );
		// This works because the maximum absolute value of one product is (-8)*(-8) = +64
		// int16_t lanes don't overflow even with sums of 4 of these numbers
		i32 = _mm256_add_epi16( i32, p16 );
		// Arithmetic shift = sign extend
		i32 = _mm256_srai_epi32( i32, 16 );
#else
		// Competes for the same ports as _mm256_maddubs_epi16, needs a constant vector of ones,
		// and takes 3-5 cycles of latency
		// However, that's 1 instruction instead of 3
		__m256i i32 = _mm256_madd_epi16( p16, _mm256_set1_epi16( 1 ) );
#endif
		// Convert int32_t to float
		__m256 p = _mm256_cvtepi32_ps( i32 );
		// Apply the scale, and accumulate
		acc = _mm256_fmadd_ps( scale, p, acc );
	}
	// Return horizontal sum of the acc vector
	__m128 res = _mm256_extractf128_ps( acc, 1 );
	res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
	res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
	res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
	return _mm_cvtss_f32( res );
}
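// For reference, a scalar sketch of the same dot product; dotProductScalar40 is a
// hypothetical helper for illustration only. It documents the memory layout the
// vectorized version assumes: countBlocks scale floats, then countBlocks * 16 bytes of packed nibbles
inline float dotProductScalar40( size_t len, const uint8_t* x, const uint8_t* y )
{
	const size_t countBlocks = len / 32;
	const float* const scalesX = (const float*)x;
	const float* const scalesY = (const float*)y;
	const uint8_t* const bytesX = (const uint8_t*)( scalesX + countBlocks );
	const uint8_t* const bytesY = (const uint8_t*)( scalesY + countBlocks );
	float acc = 0;
	for( size_t i = 0; i < len; i++ )
	{
		// Extract the two 4-bit fields, and apply the [ -8 .. +7 ] offset
		const uint8_t nx = ( i % 2 ) ? ( bytesX[ i / 2 ] >> 4 ) : ( bytesX[ i / 2 ] & 0xF );
		const uint8_t ny = ( i % 2 ) ? ( bytesY[ i / 2 ] >> 4 ) : ( bytesY[ i / 2 ] & 0xF );
		const int prod = ( (int)nx - 8 ) * ( (int)ny - 8 );
		acc += scalesX[ i / 32 ] * scalesY[ i / 32 ] * (float)prod;
	}
	return acc;
}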
inline __m128i packNibbles( __m256i bytes )
{
	// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
	const __m256i lowByte = _mm256_set1_epi16( 0xFF );
	__m256i high = _mm256_andnot_si256( lowByte, bytes );
	__m256i low = _mm256_and_si256( lowByte, bytes );
	high = _mm256_srli_epi16( high, 4 );
	bytes = _mm256_or_si256( low, high );
	// Compress uint16_t lanes into bytes
	__m128i r0 = _mm256_castsi256_si128( bytes );
	__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
	return _mm_packus_epi16( r0, r1 );
}
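// For reference, a scalar sketch of packNibbles; packNibblesScalar is a hypothetical
// helper for illustration only. The input is viewed as 32 bytes, each in [ 0 .. 15 ]:
// adjacent pairs are merged into single bytes, with the first value in the low nibble
inline void packNibblesScalar( uint8_t* rdi, const uint8_t* rsi )
{
	for( size_t i = 0; i < 16; i++ )
		rdi[ i ] = (uint8_t)( rsi[ i * 2 ] | ( rsi[ i * 2 + 1 ] << 4 ) );
}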
// Compress row into Q4_0 compressed blocks, the block size is 32
void compressRow40( uint8_t* rdi, const float* rsi, size_t length )
{
	assert( 0 == ( length % 32 ) );
	const size_t countBlocks = length / 32;
	const float* const rsiEnd = rsi + length;
	float* rdiScale = (float*)( rdi );
	uint8_t* rdiBytes = (uint8_t*)( rdiScale + countBlocks );
	while( rsi < rsiEnd )
	{
		// Load elements into 4 AVX vectors
		__m256 v0 = _mm256_loadu_ps( rsi );
		__m256 v1 = _mm256_loadu_ps( rsi + 8 );
		__m256 v2 = _mm256_loadu_ps( rsi + 16 );
		__m256 v3 = _mm256_loadu_ps( rsi + 24 );
		rsi += 32;
		// Compute max(abs(e)) for the block
		const __m256 signBit = _mm256_set1_ps( -0.0f );
		__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
		maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
		maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
		maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
		__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
		max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
		max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
		const float maxScalar = _mm_cvtss_f32( max4 );
		// Quantize these floats
		const float d = maxScalar / 7.0f;
		*rdiScale = d;
		rdiScale++;
		const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
		const __m256 mul = _mm256_set1_ps( id );
		// Apply the multiplier
		v0 = _mm256_mul_ps( v0, mul );
		v1 = _mm256_mul_ps( v1, mul );
		v2 = _mm256_mul_ps( v2, mul );
		v3 = _mm256_mul_ps( v3, mul );
		// Round to nearest integer
		v0 = _mm256_round_ps( v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC );
		v1 = _mm256_round_ps( v1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC );
		v2 = _mm256_round_ps( v2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC );
		v3 = _mm256_round_ps( v3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC );
		// Convert floats to integers
		__m256i i0 = _mm256_cvtps_epi32( v0 );
		__m256i i1 = _mm256_cvtps_epi32( v1 );
		__m256i i2 = _mm256_cvtps_epi32( v2 );
		__m256i i3 = _mm256_cvtps_epi32( v3 );
		// Convert int32 to int16
		i0 = _mm256_packs_epi32( i0, i1 );
		i2 = _mm256_packs_epi32( i2, i3 );
		// Convert int16 to int8
		i0 = _mm256_packs_epi16( i0, i2 );
		// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
		const __m256i off = _mm256_set1_epi8( 8 );
		i0 = _mm256_add_epi8( i0, off );
		// Compress the vector into 4 bits/value
		__m128i res = packNibbles( i0 );
		// The AVX2 pack instructions above process 16-byte pieces independently
		// For this reason, the order of the values is now wrong; the following shuffle fixes that
		// vpshufb shuffles 16-byte vectors, about 3 times faster than vpermd which shuffles across the complete 32-byte vector
		const __m128i perm = _mm_setr_epi8( 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15 );
		res = _mm_shuffle_epi8( res, perm );
		// Store the vector
		_mm_storeu_si128( ( __m128i* )rdiBytes, res );
		rdiBytes += 16;
	}
}
// ==== Debug Functions ====
#include <cmath>
#include <string.h>
#include <stdio.h>
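// For reference, a scalar sketch of the compressRow40 quantizer for a single 32-element block;
// compressBlockScalar40 is a hypothetical helper for illustration only.
// std::nearbyint rounds half-to-even under the default FP environment, matching _mm256_round_ps above
inline void compressBlockScalar40( float* scale, uint8_t* bytes, const float* rsi )
{
	float maxAbs = 0;
	for( size_t i = 0; i < 32; i++ )
	{
		const float a = std::fabs( rsi[ i ] );
		if( a > maxAbs )
			maxAbs = a;
	}
	*scale = maxAbs / 7.0f;
	const float id = ( maxAbs != 0.0f ) ? 7.0f / maxAbs : 0.0f;
	for( size_t i = 0; i < 16; i++ )
	{
		const int low = (int)std::nearbyint( rsi[ i * 2 ] * id ) + 8;
		const int high = (int)std::nearbyint( rsi[ i * 2 + 1 ] * id ) + 8;
		bytes[ i ] = (uint8_t)( low | ( high << 4 ) );
	}
}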
inline void storeBlock( std::array<float, 32>& arr, std::array<__m256, 4> v )
{
	float* rdi = arr.data();
	_mm256_storeu_ps( rdi, v[ 0 ] );
	_mm256_storeu_ps( rdi + 8, v[ 1 ] );
	_mm256_storeu_ps( rdi + 16, v[ 2 ] );
	_mm256_storeu_ps( rdi + 24, v[ 3 ] );
}
float decompressScalar40( float scaling, uint8_t byte )
{
	assert( byte <= 15 );
	int8_t val = (int8_t)byte - 8;
	return scaling * val;
}
float decompressScalar41( float minValue, float scaling, uint8_t byte )
{
	assert( byte <= 15 );
	return std::fma( scaling, (float)byte, minValue );
}
int testDecompressor()
{
	const float scaling = 13;
	const float min = 44;
	// From random.org
	const std::array<uint8_t, 16> bytes = { 188, 56, 77, 68, 113, 245, 126, 231, 143, 225, 48, 216, 191, 53, 110, 118 };
	// Decompress these bytes with both decompressors, and store the results
	std::array<float, 32> b40, b41;
	storeBlock( b40, decompressBlock40( &scaling, bytes.data() ) );
	storeBlock( b41, decompressBlock41( &min, &scaling, bytes.data() ) );
	// Verify the data
	for( size_t i = 0; i < 32; i++ )
	{
		uint8_t byte = bytes[ i / 2 ];
		if( 0 == ( i % 2 ) )
			byte &= 0xF;
		else
			byte = byte >> 4;
		// Verify Q4_0 decompressor
		float fast = b40[ i ];
		float scalar = decompressScalar40( scaling, byte );
		if( fast != scalar )
			return 1;
		// Verify Q4_1 decompressor
		fast = b41[ i ];
		scalar = decompressScalar41( min, scaling, byte );
		if( fast != scalar )
			return 1;
	}
	printf( "Success!\n" );
	return 0;
}
struct CompressedBlock40
{
	float scale;
	std::array<uint8_t, 16> bytes;
	operator const uint8_t* ( ) const
	{
		return (const uint8_t*)this;
	}
	operator uint8_t* ( )
	{
		return (uint8_t*)this;
	}
};
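// Sanity check for the memory layout assumed by the cast operators above:
// a 4-byte scale followed by 16 packed bytes, with no padding
static_assert( sizeof( CompressedBlock40 ) == 20, "Unexpected padding in CompressedBlock40" );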
int testDotProduct()
{
	const CompressedBlock40 x
	{
		3.5f,
		{ 188, 56, 77, 68, 113, 245, 126, 231, 143, 225, 48, 216, 191, 53, 110, 118 }
	};
	const CompressedBlock40 y
	{
		4.17f,
		{ 194, 237, 156, 194, 32, 200, 60, 253, 21, 69, 120, 124, 63, 77, 150, 143 }
	};
	const float dotCompressed = dotProductCompressed40( 32, x, y );
	std::array<float, 32> xf, yf;
	storeBlock( xf, decompressBlock40( &x.scale, x.bytes.data() ) );
	storeBlock( yf, decompressBlock40( &y.scale, y.bytes.data() ) );
	double dotScalar = 0;
	for( size_t i = 0; i < 32; i++ )
		dotScalar += (double)xf[ i ] * yf[ i ];
	printf( "dotProductCompressed40: %g\nScalar: %g\n", dotCompressed, dotScalar );
	return 0;
}
int testCompressor()
{
	const CompressedBlock40 orig
	{
		// We want the multiplier to be a power of 2 because in this test we're comparing the compressed block for exact equality with memcmp()
		// Scaling floats by powers of 2 is lossless, both multiplication and division
		16.0f,
		// Generated by random.org, with the zero nibbles removed: the compressor only produces nibbles in [ +1 .. +15 ]
		{ 0x8f, 0xd1, 0x14, 0xfe, 0x3e, 0x4c, 0x3a, 0x31, 0xce, 0x15, 0x77, 0xc6, 0x43, 0x51, 0x8e, 0x71 }
	};
	std::array<float, 32> fp32;
	storeBlock( fp32, decompressBlock40( &orig.scale, orig.bytes.data() ) );
	CompressedBlock40 recompressed;
	compressRow40( recompressed, fp32.data(), 32 );
	const int cmp = memcmp( &orig, &recompressed, sizeof( CompressedBlock40 ) );
	if( 0 == cmp )
	{
		printf( "Success\n" );
		return 0;
	}
	else
	{
		printf( "Fail\n" );
		return 1;
	}
}
int main()
{
	// return testDecompressor();
	return testDotProduct();
	// return testCompressor();
}
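// Build notes (an assumption, not from the original source): the code requires AVX2 and FMA,
// e.g. `g++ -O2 -mavx2 -mfma main.cpp` with GCC/Clang, or `cl /O2 /arch:AVX2 main.cpp` with MSVC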