Created
February 28, 2023 13:18
-
-
Save Const-me/2b383e9129c2f343fb874c56055d189e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <immintrin.h> | |
// Multiply a matrix with 16 rows and `width` columns, stored column-major, by a vector
// of length `width`. The result is a vector of 16 floats.
//
// Layout: column j occupies mat[ j * 16 ] .. mat[ j * 16 + 15 ], so `mat` holds width * 16
// floats, and on return rdi[ i ] = sum over j of mat[ j * 16 + i ] * vec[ j ].
//
// mat   - width * 16 floats, column-major as described above
// vec   - `width` floats; no alignment requirement
// width - number of columns; 0 is allowed and produces 16 zeros in rdi
// rdi   - output, receives 16 floats
//
// Requires a CPU with AVX and FMA3, and a build that enables those instruction sets.
// Loads and stores are unaligned-tolerant, so `mat` and `rdi` need not be 32-byte aligned.
// Keeping them 32-byte aligned is still recommended: it prevents 32-byte accesses from
// straddling cache lines.
void multiplyInner_avx16( const float* mat, const float* vec, size_t width, float* rdi )
{
    // Two accumulator pairs: even columns feed a00/a01, odd columns feed a10/a11.
    // Splitting the sum this way works around the FMA data dependency on a single accumulator.
    __m256 a00 = _mm256_setzero_ps();   // rows 0-7,  even columns
    __m256 a01 = _mm256_setzero_ps();   // rows 8-15, even columns
    __m256 a10 = _mm256_setzero_ps();   // rows 0-7,  odd columns
    __m256 a11 = _mm256_setzero_ps();   // rows 8-15, odd columns

    // Main loop: two columns (32 matrix floats, 2 vector floats) per iteration
    constexpr size_t maskAlign2 = ~static_cast<size_t>( 1 );
    const float* const vecEndAligned = vec + ( width & maskAlign2 );
    while( vec < vecEndAligned )
    {
        // A single 8-byte broadcast loads vec[0] and vec[1] together.
        // The lanes become [ vec[0], vec[1], vec[0], vec[1], ... ].
        const __m256 v2 = _mm256_castpd_ps( _mm256_broadcast_sd( reinterpret_cast<const double*>( vec ) ) );
        vec += 2;

        // First column of the pair: moveldup copies the even lanes, so every lane holds vec[0]
        __m256 v = _mm256_moveldup_ps( v2 );
        a00 = _mm256_fmadd_ps( v, _mm256_loadu_ps( mat ), a00 );
        a01 = _mm256_fmadd_ps( v, _mm256_loadu_ps( mat + 8 ), a01 );

        // Second column: movehdup copies the odd lanes, so every lane holds vec[1]
        v = _mm256_movehdup_ps( v2 );
        a10 = _mm256_fmadd_ps( v, _mm256_loadu_ps( mat + 16 ), a10 );
        a11 = _mm256_fmadd_ps( v, _mm256_loadu_ps( mat + 24 ), a11 );
        mat += 32;
    }

    // Odd width: exactly one column remains
    if( 0 != ( width & 1 ) )
    {
        const __m256 v = _mm256_broadcast_ss( vec );
        a00 = _mm256_fmadd_ps( v, _mm256_loadu_ps( mat ), a00 );
        a01 = _mm256_fmadd_ps( v, _mm256_loadu_ps( mat + 8 ), a01 );
    }

    // Merge the even-column and odd-column partial sums: 32 scalars -> 16
    a00 = _mm256_add_ps( a00, a10 );
    a01 = _mm256_add_ps( a01, a11 );

    // Store the 16 results
    _mm256_storeu_ps( rdi, a00 );
    _mm256_storeu_ps( rdi + 8, a01 );
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment