Skip to content

Instantly share code, notes, and snippets.

@Const-me
Last active March 1, 2023 22:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Const-me/15111c4f9502b001eef31ffde4aa7770 to your computer and use it in GitHub Desktop.
Save Const-me/15111c4f9502b001eef31ffde4aa7770 to your computer and use it in GitHub Desktop.
#include <immintrin.h>
// Compute product of width*16 column major matrix by vector of length `width`,
// the result is a vector of length 16, stored to `rdi`.
//
// Data layout / requirements (implied by the intrinsics used below):
//  * Column-major: column `c` of the matrix is the 16 consecutive floats at mat + 16*c.
//  * `mat` and `rdi` must be 32-byte aligned — _mm256_load_ps / _mm256_store_ps
//    fault on unaligned addresses.
//  * `vec` needs no particular alignment (vbroadcastf128 / vbroadcastsd accept
//    unaligned pointers), but the rem==2/3 path reads vec[0..1] as one 8-byte load,
//    so `vec` must be at least `width` floats with those elements readable.
//  * Requires AVX + FMA3 at runtime (and the matching compiler target options).
// BTW, according to godbolt.org, gcc does better than clang for this code.
void multiplyInner_avx16( const float* mat, const float* vec, size_t width, float* rdi )
{
// Using 4 accumulators per row, 4*16=64 scalars in 8 AVX vectors.
// Four independent accumulator pairs let consecutive FMAs overlap in the pipeline
// instead of serializing on one accumulator's latency chain.
// aXY = accumulator pair X (one of 4), half Y (lower/upper 8 of the 16 rows).
__m256 a00 = _mm256_setzero_ps();
__m256 a01 = _mm256_setzero_ps();
__m256 a10 = _mm256_setzero_ps();
__m256 a11 = _mm256_setzero_ps();
__m256 a20 = _mm256_setzero_ps();
__m256 a21 = _mm256_setzero_ps();
__m256 a30 = _mm256_setzero_ps();
__m256 a31 = _mm256_setzero_ps();
// Compute these products.
// Round `width` down to a multiple of 4; the main loop consumes 4 columns per iteration.
constexpr size_t maskAlign4 = ~(size_t)3;
const float* const vecEndAligned = vec + ( width & maskAlign4 );
while( vec < vecEndAligned )
{
// Each iteration of this loop consumes 4 elements from the vector, and 4 columns = 64 elements from the matrix
// Broadcast 4 elements from the vector into both 128-bit lanes of v4,
// so each in-lane permute below splats one element across all 8 lanes.
const __m256 v4 = _mm256_broadcast_ps( ( const __m128* )vec );
vec += 4;
// Column #0: splat vec[0], multiply-accumulate with the column's 16 floats
__m256 v = _mm256_permute_ps( v4, _MM_SHUFFLE( 0, 0, 0, 0 ) );
a00 = _mm256_fmadd_ps( v, _mm256_load_ps( mat ), a00 );
a01 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 8 ), a01 );
// Column #1
v = _mm256_permute_ps( v4, _MM_SHUFFLE( 1, 1, 1, 1 ) );
a10 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 16 ), a10 );
a11 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 24 ), a11 );
// Column #2
v = _mm256_permute_ps( v4, _MM_SHUFFLE( 2, 2, 2, 2 ) );
a20 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 32 ), a20 );
a21 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 40 ), a21 );
// Column #3
v = _mm256_permute_ps( v4, _MM_SHUFFLE( 3, 3, 3, 3 ) );
a30 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 48 ), a30 );
a31 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 56 ), a31 );
mat += 64;
}
// Handle the remainder (width % 4 trailing columns, 0..3 of them)
// The branches are predictable, same outcome every time this function is called
const size_t rem = width % 4;
if( rem == 1 )
{
// Column #0
const __m256 v = _mm256_broadcast_ss( vec );
a00 = _mm256_fmadd_ps( v, _mm256_load_ps( mat ), a00 );
a01 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 8 ), a01 );
}
else if( rem > 1 )
{
// Broadcast 2 elements from the vector: vbroadcastsd replicates the 8-byte pair
// { vec[0], vec[1] } into all four 64-bit lanes of v2
const __m256 v2 = _mm256_castpd_ps( _mm256_broadcast_sd( (const double*)vec ) );
// Column #0: moveldup duplicates the even float lanes => vec[0] in all 8 lanes
__m256 v = _mm256_moveldup_ps( v2 );
a00 = _mm256_fmadd_ps( v, _mm256_load_ps( mat ), a00 );
a01 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 8 ), a01 );
// Column #1: movehdup duplicates the odd float lanes => vec[1] in all 8 lanes
v = _mm256_movehdup_ps( v2 );
a10 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 16 ), a10 );
a11 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 24 ), a11 );
if( rem > 2 )
{
// Column #2 (rem == 3)
v = _mm256_broadcast_ss( vec + 2 );
a20 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 32 ), a20 );
a21 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 40 ), a21 );
}
}
// Reduce 64 accumulated scalars (8 vectors) to 32 (4 vectors)
a00 = _mm256_add_ps( a00, a20 );
a01 = _mm256_add_ps( a01, a21 );
a10 = _mm256_add_ps( a10, a30 );
a11 = _mm256_add_ps( a11, a31 );
// Reduce 32 accumulated scalars to the final 16
a00 = _mm256_add_ps( a00, a10 );
a01 = _mm256_add_ps( a01, a11 );
// Finally, store the products (rdi must be 32-byte aligned)
_mm256_store_ps( rdi, a00 );
_mm256_store_ps( rdi + 8, a01 );
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment