Created
November 21, 2020 17:50
-
-
Save ashafq/9bb35ea63c4d1212118324622a15dfed to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stddef.h>

#include <immintrin.h>
/**
 * Multiply two 8x8 single-precision matrices: dst = src_a * src_b.
 *
 * Each matrix is 64 contiguous floats in row-major order (element (r, c) is
 * at index r * 8 + c). Requires a CPU with AVX support.
 *
 * The buffers do not need any particular alignment: unaligned load/store
 * intrinsics are used throughout.
 *
 * dst may be exactly the same buffer as src_a or src_b: all of src_b is read
 * before the first store, and row i of src_a is fully read before row i of
 * dst is written. Partially overlapping buffers are not supported.
 *
 * @param dst   Output matrix, 64 floats.
 * @param src_a Left operand, 64 floats.
 * @param src_b Right operand, 64 floats.
 */
static void mat8x8mul(float *dst, const float *src_a, const float *src_b)
{
    // Load the 8 rows of B into 8 YMM registers; they are reused for every
    // output row.
    const __m256 b0 = _mm256_loadu_ps(src_b + (0 * 8));
    const __m256 b1 = _mm256_loadu_ps(src_b + (1 * 8));
    const __m256 b2 = _mm256_loadu_ps(src_b + (2 * 8));
    const __m256 b3 = _mm256_loadu_ps(src_b + (3 * 8));
    const __m256 b4 = _mm256_loadu_ps(src_b + (4 * 8));
    const __m256 b5 = _mm256_loadu_ps(src_b + (5 * 8));
    const __m256 b6 = _mm256_loadu_ps(src_b + (6 * 8));
    const __m256 b7 = _mm256_loadu_ps(src_b + (7 * 8));

    for (size_t i = 0; i < 8; ++i) {
        const float *a_row = src_a + (i * 8);

        // Row i of the result is sum over k of A[i][k] * (row k of B).
        // Broadcast each scalar A[i][k] across all 8 lanes.
        const __m256 at0 = _mm256_broadcast_ss(a_row + 0);
        const __m256 at1 = _mm256_broadcast_ss(a_row + 1);
        const __m256 at2 = _mm256_broadcast_ss(a_row + 2);
        const __m256 at3 = _mm256_broadcast_ss(a_row + 3);
        const __m256 at4 = _mm256_broadcast_ss(a_row + 4);
        const __m256 at5 = _mm256_broadcast_ss(a_row + 5);
        const __m256 at6 = _mm256_broadcast_ss(a_row + 6);
        const __m256 at7 = _mm256_broadcast_ss(a_row + 7);

        // Scale each row of B by its matching element of A's row i.
        const __m256 t0 = _mm256_mul_ps(b0, at0);
        const __m256 t1 = _mm256_mul_ps(b1, at1);
        const __m256 t2 = _mm256_mul_ps(b2, at2);
        const __m256 t3 = _mm256_mul_ps(b3, at3);
        const __m256 t4 = _mm256_mul_ps(b4, at4);
        const __m256 t5 = _mm256_mul_ps(b5, at5);
        const __m256 t6 = _mm256_mul_ps(b6, at6);
        const __m256 t7 = _mm256_mul_ps(b7, at7);

        // Pairwise (tree) reduction of the 8 partial products:
        // ((t0 + t1) + (t2 + t3)) + ((t4 + t5) + (t6 + t7)).
        __m256 r0 = _mm256_add_ps(t0, t1);
        __m256 r1 = _mm256_add_ps(t2, t3);
        const __m256 r2 = _mm256_add_ps(t4, t5);
        const __m256 r3 = _mm256_add_ps(t6, t7);
        r0 = _mm256_add_ps(r0, r1);
        r1 = _mm256_add_ps(r2, r3);
        r0 = _mm256_add_ps(r0, r1);

        // Store row i of the result.
        _mm256_storeu_ps(dst + (i * 8), r0);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment