#include <immintrin.h>
#include <time.h>
#include <stdint.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
// Union of 16 doubles, accessible either as 4 x 4-wide AVX row
// vectors or as a 4 x 4 array of doubles.
typedef union u_m4d
{
    __m256d m256d[4];
    double  d[4][4];
} t_m4d;
// Union of 4 doubles, accessible either as a single AVX register
// or as a 4-wide row vector of doubles.
typedef union u_v4d
{
    __m256d m256d;
    double  d[4];
} t_v4d;
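// Optional sanity checks (a minimal sketch, assuming a C11 compiler):
// the unions should overlay exactly 16 and 4 doubles, and the __m256d
// members give them the natural AVX alignment.
_Static_assert(sizeof(t_m4d) == 16 * sizeof(double), "t_m4d must overlay 16 doubles");
_Static_assert(sizeof(t_v4d) == 4 * sizeof(double), "t_v4d must overlay 4 doubles");
_Static_assert(_Alignof(t_m4d) == _Alignof(__m256d), "t_m4d must inherit AVX alignment");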
// Example matrices.
//
// left[0] (1 2 3 4) | right[0] (4 2 3 4)
// left[1] (2 2 3 4) | right[1] (3 2 3 4)
// left[2] (3 2 3 4) | right[2] (2 2 3 4)
// left[3] (4 2 3 4) | right[3] (1 2 3 4)
__attribute__((noinline))
void m4x4d_avx_mul(
    t_m4d *restrict dst,
    const t_m4d *restrict left,
    const t_m4d *restrict right)
{
    __m256d ymm0;
    __m256d ymm1;
    __m256d ymm2;
    __m256d ymm3;
    // Broadcast each element of the first row of the left-hand
    // matrix into its own register.
    //
    // left[0][0] = 1 -> ymm0 (1 1 1 1)
    // left[0][1] = 2 -> ymm1 (2 2 2 2)
    // left[0][2] = 3 -> ymm2 (3 3 3 3)
    // left[0][3] = 4 -> ymm3 (4 4 4 4)
    ymm0 = _mm256_broadcast_sd(&left->d[0][0]);
    ymm1 = _mm256_broadcast_sd(&left->d[0][1]);
    ymm2 = _mm256_broadcast_sd(&left->d[0][2]);
    ymm3 = _mm256_broadcast_sd(&left->d[0][3]);
    // Multiply the vector in ymm0 element-wise with right row[0].
    //
    // 1 1 1 1 <- ymm0
    // *
    // 4 2 3 4 <- right[0]
    // ----------
    // 4 2 3 4 <- ymm0
    ymm0 = _mm256_mul_pd(ymm0, right->m256d[0]);
    // Fused multiply-add: multiply the vector in ymm1 element-wise
    // with right row[1] and add the running sum in ymm0 to the result.
    //
    // 2 2 2 2 <- ymm1
    // *
    // 3 2 3 4 <- right[1]
    // +
    // 4 2 3 4 <- ymm0
    // ----------
    // 10 6 9 12 <- ymm0
    ymm0 = _mm256_fmadd_pd(ymm1, right->m256d[1], ymm0);
    // We repeat the same pattern for ymm2 and ymm3.
    //
    // 3 3 3 3 <- ymm2
    // *
    // 2 2 3 4 <- right[2]
    // ----------
    // 6 6 9 12 <- ymm2
    //
    // 4 4 4 4 <- ymm3
    // *
    // 1 2 3 4 <- right[3]
    // +
    // 6 6 9 12 <- ymm2
    // ----------
    // 10 14 21 28 <- ymm2
    ymm2 = _mm256_mul_pd(ymm2, right->m256d[2]);
    ymm2 = _mm256_fmadd_pd(ymm3, right->m256d[3], ymm2);
    // Sum the accumulated vectors in ymm0 and ymm2.
    //
    // 10  6  9 12 <- ymm0
    // +
    // 10 14 21 28 <- ymm2
    // ----------
    // 20 20 30 40 <- dst[0] First row!
    dst->m256d[0] = _mm256_add_pd(ymm0, ymm2);
    // Calculate dst[1]
    ymm0 = _mm256_broadcast_sd(&left->d[1][0]);
    ymm1 = _mm256_broadcast_sd(&left->d[1][1]);
    ymm2 = _mm256_broadcast_sd(&left->d[1][2]);
    ymm3 = _mm256_broadcast_sd(&left->d[1][3]);
    ymm0 = _mm256_mul_pd(ymm0, right->m256d[0]);
    ymm0 = _mm256_fmadd_pd(ymm1, right->m256d[1], ymm0);
    ymm2 = _mm256_mul_pd(ymm2, right->m256d[2]);
    ymm2 = _mm256_fmadd_pd(ymm3, right->m256d[3], ymm2);
    dst->m256d[1] = _mm256_add_pd(ymm0, ymm2);
    // Calculate dst[2]
    ymm0 = _mm256_broadcast_sd(&left->d[2][0]);
    ymm1 = _mm256_broadcast_sd(&left->d[2][1]);
    ymm2 = _mm256_broadcast_sd(&left->d[2][2]);
    ymm3 = _mm256_broadcast_sd(&left->d[2][3]);
    ymm0 = _mm256_mul_pd(ymm0, right->m256d[0]);
    ymm0 = _mm256_fmadd_pd(ymm1, right->m256d[1], ymm0);
    ymm2 = _mm256_mul_pd(ymm2, right->m256d[2]);
    ymm2 = _mm256_fmadd_pd(ymm3, right->m256d[3], ymm2);
    dst->m256d[2] = _mm256_add_pd(ymm0, ymm2);
    // Calculate dst[3]
    ymm0 = _mm256_broadcast_sd(&left->d[3][0]);
    ymm1 = _mm256_broadcast_sd(&left->d[3][1]);
    ymm2 = _mm256_broadcast_sd(&left->d[3][2]);
    ymm3 = _mm256_broadcast_sd(&left->d[3][3]);
    ymm0 = _mm256_mul_pd(ymm0, right->m256d[0]);
    ymm0 = _mm256_fmadd_pd(ymm1, right->m256d[1], ymm0);
    ymm2 = _mm256_mul_pd(ymm2, right->m256d[2]);
    ymm2 = _mm256_fmadd_pd(ymm3, right->m256d[3], ymm2);
    dst->m256d[3] = _mm256_add_pd(ymm0, ymm2);
}
// Resulting matrix:
//
// 20 20 30 40
// 24 22 33 44
// 28 24 36 48
// 32 26 39 52
// Same as above, but without fmadd: each multiply + fmadd pair is
// replaced by two multiplies and an add. Seems to perform the best.
// We use __attribute__((noinline)) so that the compiler won't inline
// everything into the benchmark loop in main.
__attribute__((noinline))
void m4x4d_avx_mul2(
    t_m4d *restrict dst,
    const t_m4d *restrict left,
    const t_m4d *restrict right)
{
    __m256d ymm[4];
    for (int i = 0; i < 4; i++)
    {
        ymm[0] = _mm256_broadcast_sd(&left->d[i][0]);
        ymm[1] = _mm256_broadcast_sd(&left->d[i][1]);
        ymm[2] = _mm256_broadcast_sd(&left->d[i][2]);
        ymm[3] = _mm256_broadcast_sd(&left->d[i][3]);
        ymm[0] = _mm256_mul_pd(ymm[0], right->m256d[0]);
        ymm[1] = _mm256_mul_pd(ymm[1], right->m256d[1]);
        ymm[0] = _mm256_add_pd(ymm[0], ymm[1]);
        ymm[2] = _mm256_mul_pd(ymm[2], right->m256d[2]);
        ymm[3] = _mm256_mul_pd(ymm[3], right->m256d[3]);
        ymm[2] = _mm256_add_pd(ymm[2], ymm[3]);
        dst->m256d[i] = _mm256_add_pd(ymm[0], ymm[2]);
    }
}
// Basic scalar 4 x 4 matrix multiplication: d[i][j] = sum over k of l[i][k] * r[k][j].
__attribute__((noinline))
void m4x4d_mul(double d[4][4], double l[4][4], double r[4][4])
{
    d[0][0] = l[0][0] * r[0][0] + l[0][1] * r[1][0] + l[0][2] * r[2][0] + l[0][3] * r[3][0];
    d[0][1] = l[0][0] * r[0][1] + l[0][1] * r[1][1] + l[0][2] * r[2][1] + l[0][3] * r[3][1];
    d[0][2] = l[0][0] * r[0][2] + l[0][1] * r[1][2] + l[0][2] * r[2][2] + l[0][3] * r[3][2];
    d[0][3] = l[0][0] * r[0][3] + l[0][1] * r[1][3] + l[0][2] * r[2][3] + l[0][3] * r[3][3];
    d[1][0] = l[1][0] * r[0][0] + l[1][1] * r[1][0] + l[1][2] * r[2][0] + l[1][3] * r[3][0];
    d[1][1] = l[1][0] * r[0][1] + l[1][1] * r[1][1] + l[1][2] * r[2][1] + l[1][3] * r[3][1];
    d[1][2] = l[1][0] * r[0][2] + l[1][1] * r[1][2] + l[1][2] * r[2][2] + l[1][3] * r[3][2];
    d[1][3] = l[1][0] * r[0][3] + l[1][1] * r[1][3] + l[1][2] * r[2][3] + l[1][3] * r[3][3];
    d[2][0] = l[2][0] * r[0][0] + l[2][1] * r[1][0] + l[2][2] * r[2][0] + l[2][3] * r[3][0];
    d[2][1] = l[2][0] * r[0][1] + l[2][1] * r[1][1] + l[2][2] * r[2][1] + l[2][3] * r[3][1];
    d[2][2] = l[2][0] * r[0][2] + l[2][1] * r[1][2] + l[2][2] * r[2][2] + l[2][3] * r[3][2];
    d[2][3] = l[2][0] * r[0][3] + l[2][1] * r[1][3] + l[2][2] * r[2][3] + l[2][3] * r[3][3];
    d[3][0] = l[3][0] * r[0][0] + l[3][1] * r[1][0] + l[3][2] * r[2][0] + l[3][3] * r[3][0];
    d[3][1] = l[3][0] * r[0][1] + l[3][1] * r[1][1] + l[3][2] * r[2][1] + l[3][3] * r[3][1];
    d[3][2] = l[3][0] * r[0][2] + l[3][1] * r[1][2] + l[3][2] * r[2][2] + l[3][3] * r[3][2];
    d[3][3] = l[3][0] * r[0][3] + l[3][1] * r[1][3] + l[3][2] * r[2][3] + l[3][3] * r[3][3];
}
///////////////////////////////////////////////////////////////////////////////
//
// Main and utils for testing.
t_v4d v4d_set(double n0, double n1, double n2, double n3)
{
    t_v4d v;
    v.d[0] = n0;
    v.d[1] = n1;
    v.d[2] = n2;
    v.d[3] = n3;
    return (v);
}
t_m4d m4d_set(t_v4d v0, t_v4d v1, t_v4d v2, t_v4d v3)
{
    t_m4d m;
    m.m256d[0] = v0.m256d;
    m.m256d[1] = v1.m256d;
    m.m256d[2] = v2.m256d;
    m.m256d[3] = v3.m256d;
    return (m);
}
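// Helper for inspecting results (a hypothetical addition, not used by the
// benchmark below): prints a t_m4d row by row so the outputs of the three
// multiply routines can be compared against the expected matrix above.
void m4d_print(const char *name, const t_m4d *m)
{
    printf("%s\n", name);
    for (int i = 0; i < 4; i++)
        printf("%6.1f %6.1f %6.1f %6.1f\n",
            m->d[i][0], m->d[i][1], m->d[i][2], m->d[i][3]);
}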
int main(int argc, char **argv)
{
    t_m4d left;
    t_m4d right;
    t_m4d res;
    t_m4d ctr;

    if (argc != 2)
        return (printf("usage: avx4x4 [iters]\n"));
    left = m4d_set(
        v4d_set(1, 2, 3, 4),
        v4d_set(2, 2, 3, 4),
        v4d_set(3, 2, 3, 4),
        v4d_set(4, 2, 3, 4));
    right = m4d_set(
        v4d_set(4, 2, 3, 4),
        v4d_set(3, 2, 3, 4),
        v4d_set(2, 2, 3, 4),
        v4d_set(1, 2, 3, 4));
    size_t iters;
    clock_t begin;
    clock_t end;
    double time_spent;
    // Test 1
    m4x4d_mul(ctr.d, left.d, right.d);
    iters = atoi(argv[1]);
    begin = clock();
    for (size_t i = 0; i < iters; i++)
    {
        m4x4d_mul(res.d, left.d, right.d);
        // To prevent loop unrolling with optimisation flags.
        __asm__ volatile ("" : "+g" (i));
    }
    end = clock();
    time_spent = (double)(end - begin) / CLOCKS_PER_SEC;
    printf("\nNORMAL\n\ntime: %lf\n", time_spent);
    // Test 2
    m4x4d_avx_mul(&ctr, &left, &right);
    iters = atoi(argv[1]);
    begin = clock();
    for (size_t i = 0; i < iters; i++)
    {
        m4x4d_avx_mul(&res, &left, &right);
        __asm__ volatile ("" : "+g" (i));
    }
    end = clock();
    time_spent = (double)(end - begin) / CLOCKS_PER_SEC;
    printf("\nAVX MUL + FMADD\n\ntime: %lf\n", time_spent);
    // Test 3
    m4x4d_avx_mul2(&ctr, &left, &right);
    iters = atoi(argv[1]);
    begin = clock();
    for (size_t i = 0; i < iters; i++)
    {
        m4x4d_avx_mul2(&res, &left, &right);
        __asm__ volatile ("" : "+g" (i));
    }
    end = clock();
    time_spent = (double)(end - begin) / CLOCKS_PER_SEC;
    printf("\nAVX MUL + MUL + ADD\n\ntime: %lf\n", time_spent);
}
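// Building and running (a sketch, assuming GCC or Clang on an x86-64 CPU
// with AVX and FMA support, and that the file is saved as avx4x4.c):
//
//   cc -O2 -mavx -mfma avx4x4.c -o avx4x4
//   ./avx4x4 100000000
//
// Each test prints the CPU time measured with clock() for the requested
// number of iterations; the iteration count is the only argument.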