Last active
July 27, 2021 12:15
-
-
Save juliuskoskela/1a5ce40794325b2dac4a2a072ce5d404 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <immintrin.h> | |
#include <time.h> | |
#include <stdint.h> | |
#include <string.h> | |
#include <stdlib.h> | |
#include <stdio.h> | |
// Typedef of 16 doubles alternatively accessible as four 4-wide
// AVX row vectors or as a 4 x 4 scalar array.
typedef union u_m4d
{
	__m256d m256d[4]; // four 256-bit rows, one __m256d per matrix row
	double d[4][4];   // same 128 bytes viewed as scalars, d[row][col]
} t_m4d;
// Typedef of 4 doubles alternatively accessible as a 4-wide
// AVX row vector or as a scalar array.
typedef union u_v4d
{
	__m256d m256d; // one 256-bit AVX register holding 4 doubles
	double d[4];   // same 32 bytes viewed as scalars
} t_v4d;
// Example matrices used throughout the comments below.
//
// left[0] (1 2 3 4) | right[0] (4 2 3 4)
// left[1] (2 2 3 4) | right[1] (3 2 3 4)
// left[2] (3 2 3 4) | right[2] (2 2 3 4)
// left[3] (4 2 3 4) | right[3] (1 2 3 4)
// Multiply two 4 x 4 double matrices with AVX: dst = left * right.
//
// Row formulation: each result row is a linear combination of the
// rows of `right`, weighted by the four scalars of the matching row
// of `left`:
//
//   dst[i] = left[i][0]*right[0] + left[i][1]*right[1]
//          + left[i][2]*right[2] + left[i][3]*right[3]
//
// Per row: one mul, two fused multiply-adds, one add. NOTE(review):
// FMA rounds a*b+c once, so results may differ in the last ulp from
// the mul+add-only variant below.
//
// noinline so the benchmark loop in main() measures a real call.
__attribute__((noinline))
void m4x4d_avx_mul(
	t_m4d *restrict dst,
	const t_m4d *restrict left,
	const t_m4d *restrict right)
{
	__m256d ymm0;
	__m256d ymm1;
	__m256d ymm2;
	__m256d ymm3;
	// Broadcast the four scalars of left row 0, one per register.
	// With the example matrices:
	//
	// left row 0 = (1 2 3 4) -> ymm0 (1 1 1 1)
	//                           ymm1 (2 2 2 2)
	//                           ymm2 (3 3 3 3)
	//                           ymm3 (4 4 4 4)
	ymm0 = _mm256_broadcast_sd(&left->d[0][0]);
	ymm1 = _mm256_broadcast_sd(&left->d[0][1]);
	ymm2 = _mm256_broadcast_sd(&left->d[0][2]);
	ymm3 = _mm256_broadcast_sd(&left->d[0][3]);
	// Scale right row 0 by the first scalar.
	//
	// 1 1 1 1 <- ymm0
	// *
	// 4 2 3 4 <- right[0]
	// ----------
	// 4 2 3 4 <- ymm0
	ymm0 = _mm256_mul_pd(ymm0, right->m256d[0]);
	// Accumulate the second scaled row with a single fused
	// multiply-add into ymm0.
	//
	// 2 2 2 2 <- ymm1
	// *
	// 3 2 3 4 <- right[1]
	// +
	// 4 2 3 4 <- ymm0
	// ----------
	// 10 6 9 12 <- ymm0
	ymm0 = _mm256_fmadd_pd(ymm1, right->m256d[1], ymm0);
	// Build the other partial sum in ymm2 the same way.
	//
	// 3 3 3 3 <- ymm2
	// *
	// 2 2 3 4 <- right[2]
	// ----------
	// 6 6 9 12 <- ymm2
	//
	// 4 4 4 4 <- ymm3
	// *
	// 1 2 3 4 <- right[3]
	// +
	// 6 6 9 12 <- ymm2
	// ----------
	// 10 14 21 28 <- ymm2
	ymm2 = _mm256_mul_pd(ymm2, right->m256d[2]);
	ymm2 = _mm256_fmadd_pd(ymm3, right->m256d[3], ymm2);
	// Combine the two partial sums into the first result row.
	//
	// 10 6 9 12 <- ymm0
	// +
	// 10 14 21 28 <- ymm2
	// ----------
	// 20 20 30 40 <- dst[0] First row!
	dst->m256d[0] = _mm256_add_pd(ymm0, ymm2);
	// dst[1]: same pattern with the scalars of left row 1.
	ymm0 = _mm256_broadcast_sd(&left->d[1][0]);
	ymm1 = _mm256_broadcast_sd(&left->d[1][1]);
	ymm2 = _mm256_broadcast_sd(&left->d[1][2]);
	ymm3 = _mm256_broadcast_sd(&left->d[1][3]);
	ymm0 = _mm256_mul_pd(ymm0, right->m256d[0]);
	ymm0 = _mm256_fmadd_pd(ymm1, right->m256d[1], ymm0);
	ymm2 = _mm256_mul_pd(ymm2, right->m256d[2]);
	ymm2 = _mm256_fmadd_pd(ymm3, right->m256d[3], ymm2);
	dst->m256d[1] = _mm256_add_pd(ymm0, ymm2);
	// dst[2]: left row 2.
	ymm0 = _mm256_broadcast_sd(&left->d[2][0]);
	ymm1 = _mm256_broadcast_sd(&left->d[2][1]);
	ymm2 = _mm256_broadcast_sd(&left->d[2][2]);
	ymm3 = _mm256_broadcast_sd(&left->d[2][3]);
	ymm0 = _mm256_mul_pd(ymm0, right->m256d[0]);
	ymm0 = _mm256_fmadd_pd(ymm1, right->m256d[1], ymm0);
	ymm2 = _mm256_mul_pd(ymm2, right->m256d[2]);
	ymm2 = _mm256_fmadd_pd(ymm3, right->m256d[3], ymm2);
	dst->m256d[2] = _mm256_add_pd(ymm0, ymm2);
	// dst[3]: left row 3.
	ymm0 = _mm256_broadcast_sd(&left->d[3][0]);
	ymm1 = _mm256_broadcast_sd(&left->d[3][1]);
	ymm2 = _mm256_broadcast_sd(&left->d[3][2]);
	ymm3 = _mm256_broadcast_sd(&left->d[3][3]);
	ymm0 = _mm256_mul_pd(ymm0, right->m256d[0]);
	ymm0 = _mm256_fmadd_pd(ymm1, right->m256d[1], ymm0);
	ymm2 = _mm256_mul_pd(ymm2, right->m256d[2]);
	ymm2 = _mm256_fmadd_pd(ymm3, right->m256d[3], ymm2);
	dst->m256d[3] = _mm256_add_pd(ymm0, ymm2);
}
// Resulting matrix:
//
// 20 20 30 40
// 24 22 33 44
// 28 24 36 48
// 32 26 39 52
// Same as above but with no fmadd; instead two multiplies and an add.
// Seems to perform the best.
// We use attribute noinline so that the compiler won't inline
// everything into the main loop.
__attribute__((noinline)) | |
void m4x4d_avx_mul2( | |
t_m4d *restrict dst, | |
const t_m4d *restrict left, | |
const t_m4d *restrict right) | |
{ | |
__m256d ymm[4]; | |
for (int i = 0; i < 4; i++) | |
{ | |
ymm[0] = _mm256_broadcast_sd(&left->d[i][0]); | |
ymm[1] = _mm256_broadcast_sd(&left->d[i][1]); | |
ymm[2] = _mm256_broadcast_sd(&left->d[i][2]); | |
ymm[3] = _mm256_broadcast_sd(&left->d[i][3]); | |
ymm[0] = _mm256_mul_pd(ymm[0], right->m256d[0]); | |
ymm[1] = _mm256_mul_pd(ymm[1], right->m256d[1]); | |
ymm[0] = _mm256_add_pd(ymm[0], ymm[1]); | |
ymm[2] = _mm256_mul_pd(ymm[2], right->m256d[2]); | |
ymm[3] = _mm256_mul_pd(ymm[3], right->m256d[3]); | |
ymm[2] = _mm256_add_pd(ymm[2], ymm[3]); | |
dst->m256d[i] = _mm256_add_pd(ymm[0], ymm[2]); | |
} | |
} | |
// Basic 4 x 4 matrix multiplication. | |
// Basic scalar 4 x 4 matrix multiplication: d = l * r.
// d[i][j] is the dot product of row i of l with column j of r,
// accumulated left to right exactly like the fully unrolled
// original, so results are bit-identical.
//
// noinline so the benchmark loop in main() measures a real call.
// Fix: removed the stray ';' after the closing brace (an empty
// file-scope declaration, invalid in strict ISO C) and collapsed the
// 16 copy-pasted statements into loops.
__attribute__((noinline))
void m4x4d_mul(double d[4][4], double l[4][4], double r[4][4])
{
	for (int i = 0; i < 4; i++)
	{
		for (int j = 0; j < 4; j++)
		{
			d[i][j] = l[i][0] * r[0][j] + l[i][1] * r[1][j]
				+ l[i][2] * r[2][j] + l[i][3] * r[3][j];
		}
	}
}
///////////////////////////////////////////////////////////////////////////////
//
// Main and utilities for testing.
t_v4d v4d_set(double n0, double n1, double n2, double n3) | |
{ | |
t_v4d v; | |
v.d[0] = n0; | |
v.d[1] = n1; | |
v.d[2] = n2; | |
v.d[3] = n3; | |
return (v); | |
} | |
// Assemble a 4 x 4 matrix from four row vectors (v0 = row 0).
t_m4d m4d_set(t_v4d v0, t_v4d v1, t_v4d v2, t_v4d v3)
{
	return ((t_m4d){ .m256d = { v0.m256d, v1.m256d, v2.m256d, v3.m256d } });
}
int main(int argc, char **argv) | |
{ | |
t_m4d left; | |
t_m4d right; | |
t_m4d res; | |
t_m4d ctr; | |
if (argc != 2) | |
return (printf("usage: avx4x4 [iters]")); | |
left = m4d_set( | |
v4d_set(1, 2, 3, 4), | |
v4d_set(2, 2, 3, 4), | |
v4d_set(3, 2, 3, 4), | |
v4d_set(4, 2, 3, 4)); | |
right = m4d_set( | |
v4d_set(4, 2, 3, 4), | |
v4d_set(3, 2, 3, 4), | |
v4d_set(2, 2, 3, 4), | |
v4d_set(1, 2, 3, 4)); | |
size_t iters; | |
clock_t begin; | |
clock_t end; | |
double time_spent; | |
// Test 1 | |
m4x4d_mul(ctr.d, left.d, right.d); | |
iters = atoi(argv[1]); | |
begin = clock(); | |
for (size_t i = 0; i < iters; i++) | |
{ | |
m4x4d_mul(res.d, left.d, right.d); | |
// To prevent loop unrolling with optimisation flags. | |
__asm__ volatile ("" : "+g" (i)); | |
} | |
end = clock(); | |
time_spent = (double)(end - begin) / CLOCKS_PER_SEC; | |
printf("\nNORMAL\n\ntime: %lf\n", time_spent); | |
// Test 2 | |
m4x4d_avx_mul(&ctr, &left, &right); | |
iters = atoi(argv[1]); | |
begin = clock(); | |
for (size_t i = 0; i < iters; i++) | |
{ | |
m4x4d_avx_mul(&res, &left, &right); | |
__asm__ volatile ("" : "+g" (i)); | |
} | |
end = clock(); | |
time_spent = (double)(end - begin) / CLOCKS_PER_SEC; | |
printf("\nAVX MUL + FMADD\n\ntime: %lf\n", time_spent); | |
// Test 3 | |
m4x4d_avx_mul2(&ctr, &left, &right); | |
iters = atoi(argv[1]); | |
begin = clock(); | |
for (size_t i = 0; i < iters; i++) | |
{ | |
m4x4d_avx_mul2(&res, &left, &right); | |
__asm__ volatile ("" : "+g" (i)); | |
} | |
end = clock(); | |
time_spent = (double)(end - begin) / CLOCKS_PER_SEC; | |
printf("\nAVX MUL + MUL + ADD\n\ntime: %lf\n", time_spent); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment