Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
SSE/AVX matrix multiply
#include <immintrin.h>
#include <intrin.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// 4x4 single-precision matrix with two overlapping views of the same storage:
// scalar access through m[i][j], or one SSE vector per row through row[i]
// (row[i] overlays m[i][0..3]).
union Mat44 {
float m[4][4];
__m128 row[4];
};
// Scalar reference multiply: out = A * B, with out[r][c] = sum_k A[r][k] * B[k][c].
// Every other variant in this file is checked against this one.
void matmult_ref(Mat44 &out, const Mat44 &A, const Mat44 &B)
{
    // Build the product in a temporary and copy it out at the very end, so the
    // inputs are fully read before out is written.
    Mat44 product;
    for (int r = 0; r < 4; ++r)
    {
        const float *a = A.m[r];
        float *dst = product.m[r];
        for (int c = 0; c < 4; ++c)
        {
            dst[c] = a[0]*B.m[0][c] + a[1]*B.m[1][c] + a[2]*B.m[2][c] + a[3]*B.m[3][c];
        }
    }
    out = product;
}
// Linear combination of the rows of B weighted by the components of a:
//   a[0] * B.row[0] + a[1] * B.row[1] + a[2] * B.row[2] + a[3] * B.row[3]
// (i.e. one row of the product when a is a row of the left-hand matrix).
static inline __m128 lincomb_SSE(const __m128 &a, const Mat44 &B)
{
    // Replicate each component of a across all four lanes.
    const __m128 a0 = _mm_shuffle_ps(a, a, _MM_SHUFFLE(0, 0, 0, 0));
    const __m128 a1 = _mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1));
    const __m128 a2 = _mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2));
    const __m128 a3 = _mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3));

    // Accumulate in the same left-to-right order as the scalar reference.
    __m128 acc = _mm_mul_ps(a0, B.row[0]);
    acc = _mm_add_ps(acc, _mm_mul_ps(a1, B.row[1]));
    acc = _mm_add_ps(acc, _mm_mul_ps(a2, B.row[2]));
    acc = _mm_add_ps(acc, _mm_mul_ps(a3, B.row[3]));
    return acc;
}
// SSE multiply: out = A * B, one lincomb_SSE call per output row.
// This is the right approach for SSE ... SSE4.2.
void matmult_SSE(Mat44 &out, const Mat44 &A, const Mat44 &B)
{
    // out_ij = sum_k a_ik b_kj
    // => row i of out = a_i0 * B.row[0] + a_i1 * B.row[1] + a_i2 * B.row[2] + a_i3 * B.row[3]
    __m128 rows[4];
    for (int r = 0; r < 4; ++r)
        rows[r] = lincomb_SSE(A.row[r], B);

    // Store only after all four rows are computed, so no input row is
    // overwritten while it is still needed.
    for (int r = 0; r < 4; ++r)
        out.row[r] = rows[r];
}
// Same linear combination as lincomb_SSE, but the four weights are read from
// memory at a[0..3] and splatted with the AVX broadcast instruction (still
// operating on 128-bit XMM registers).
static inline __m128 lincomb_AVX_4mem(const float *a, const Mat44 &B)
{
    const __m128 w0 = _mm_broadcast_ss(a + 0);
    const __m128 w1 = _mm_broadcast_ss(a + 1);
    const __m128 w2 = _mm_broadcast_ss(a + 2);
    const __m128 w3 = _mm_broadcast_ss(a + 3);

    __m128 acc = _mm_mul_ps(w0, B.row[0]);
    acc = _mm_add_ps(acc, _mm_mul_ps(w1, B.row[1]));
    acc = _mm_add_ps(acc, _mm_mul_ps(w2, B.row[2]));
    acc = _mm_add_ps(acc, _mm_mul_ps(w3, B.row[3]));
    return acc;
}
// AVX multiply, 4 floats wide: out = A * B, one lincomb_AVX_4mem call per
// output row. This can be better than the SSE version if A is in memory.
void matmult_AVX_4mem(Mat44 &out, const Mat44 &A, const Mat44 &B)
{
    _mm256_zeroupper();

    __m128 rows[4];
    for (int r = 0; r < 4; ++r)
        rows[r] = lincomb_AVX_4mem(A.m[r], B);

    // All rows are computed before the first store to out.
    for (int r = 0; r < 4; ++r)
        out.row[r] = rows[r];
}
// Two linear combinations at once on YMM registers: A01 holds two rows of the
// left-hand matrix (one per 128-bit half); the result holds the two
// corresponding rows of the product with B.
static inline __m256 twolincomb_AVX_8(__m256 A01, const Mat44 &B)
{
    // Each shuffle replicates component k of both rows within its own half.
    const __m256 w0 = _mm256_shuffle_ps(A01, A01, _MM_SHUFFLE(0, 0, 0, 0));
    const __m256 w1 = _mm256_shuffle_ps(A01, A01, _MM_SHUFFLE(1, 1, 1, 1));
    const __m256 w2 = _mm256_shuffle_ps(A01, A01, _MM_SHUFFLE(2, 2, 2, 2));
    const __m256 w3 = _mm256_shuffle_ps(A01, A01, _MM_SHUFFLE(3, 3, 3, 3));

    // Each row of B is duplicated into both halves before the multiply.
    __m256 acc = _mm256_mul_ps(w0, _mm256_broadcast_ps(&B.row[0]));
    acc = _mm256_add_ps(acc, _mm256_mul_ps(w1, _mm256_broadcast_ps(&B.row[1])));
    acc = _mm256_add_ps(acc, _mm256_mul_ps(w2, _mm256_broadcast_ps(&B.row[2])));
    acc = _mm256_add_ps(acc, _mm256_mul_ps(w3, _mm256_broadcast_ps(&B.row[3])));
    return acc;
}
// AVX multiply, 8 floats wide: out = A * B, processing two rows per
// twolincomb_AVX_8 call.
// This should be noticeably faster with actual 256-bit wide vector units
// (Intel); not sure about double-pumped 128-bit (AMD), would need to check.
void matmult_AVX_8(Mat44 &out, const Mat44 &A, const Mat44 &B)
{
    _mm256_zeroupper();

    // Unaligned 8-float loads: rows 0-1 and rows 2-3 of A.
    const float *src = &A.m[0][0];
    const __m256 a_rows01 = _mm256_loadu_ps(src);
    const __m256 a_rows23 = _mm256_loadu_ps(src + 8);

    const __m256 out_rows01 = twolincomb_AVX_8(a_rows01, B);
    const __m256 out_rows23 = twolincomb_AVX_8(a_rows23, B);

    // Both halves are computed before anything is written to out.
    float *dst = &out.m[0][0];
    _mm256_storeu_ps(dst, out_rows01);
    _mm256_storeu_ps(dst + 8, out_rows23);
}
// ---- testing stuff
// Returns a pseudo-random float in [-16, 16), in steps of 1/1024.
static float randf()
{
    // Use only the low 15 bits of rand(): the original formula assumed
    // RAND_MAX == 32767 (VC++). With a larger RAND_MAX the unmasked value
    // would fall far outside the intended range and lose bits when converted
    // to float. With VC++ the mask is a no-op, so results there are unchanged.
    return ((rand() & 0x7fff) - 16384.0f) / 1024.0f;
}
// Fills all 16 entries of M with values from randf(), row by row.
static void randmat(Mat44 &M)
{
    for (int r = 0; r < 4; ++r)
    {
        float *dst = M.m[r];
        for (int c = 0; c < 4; ++c)
            dst[c] = randf();
    }
}
// Index mask applied by the run_* benchmark loops (j = i & the_mask).
// With the value 0 set here, every iteration selects element 0 of the arrays.
int the_mask = 0; // global so the compiler can't be sure what its value is for opt.
// Benchmark loop: calls matmult_ref count times. The operand index for
// iteration i is i & the_mask.
static void run_ref(Mat44 *out, const Mat44 *A, const Mat44 *B, int count)
{
    int iter = 0;
    while (iter < count)
    {
        const int slot = iter & the_mask;
        matmult_ref(out[slot], A[slot], B[slot]);
        ++iter;
    }
}
// Benchmark loop: calls matmult_SSE count times. The operand index for
// iteration i is i & the_mask.
static void run_SSE(Mat44 *out, const Mat44 *A, const Mat44 *B, int count)
{
    int iter = 0;
    while (iter < count)
    {
        const int slot = iter & the_mask;
        matmult_SSE(out[slot], A[slot], B[slot]);
        ++iter;
    }
}
// Benchmark loop: calls matmult_AVX_4mem count times. The operand index for
// iteration i is i & the_mask.
static void run_AVX_4mem(Mat44 *out, const Mat44 *A, const Mat44 *B, int count)
{
    int iter = 0;
    while (iter < count)
    {
        const int slot = iter & the_mask;
        matmult_AVX_4mem(out[slot], A[slot], B[slot]);
        ++iter;
    }
}
// Benchmark loop: calls matmult_AVX_8 count times. The operand index for
// iteration i is i & the_mask.
static void run_AVX_8(Mat44 *out, const Mat44 *A, const Mat44 *B, int count)
{
    int iter = 0;
    while (iter < count)
    {
        const int slot = iter & the_mask;
        matmult_AVX_8(out[slot], A[slot], B[slot]);
        ++iter;
    }
}
// Test driver.
//  1. Correctness: multiplies 1,000,000 random matrix pairs with every variant
//     and requires a bit-exact match (memcmp) with matmult_ref. On the first
//     mismatch it prints the variant name to stderr and exits with status 1.
//  2. Performance: for each run_* loop, times nruns batches of muls_per_run
//     multiplies with __rdtsc and prints the best batch time per multiply.
// Returns 0 on success. argc and argv are not used.
int main(int argc, char **argv)
{
    (void) argc;
    (void) argv;

    static const struct {
        const char *name;
        void (*matmult)(Mat44 &out, const Mat44 &A, const Mat44 &B);
    } variants[] = {
        { "ref", matmult_ref },
        { "SSE", matmult_SSE },
        { "AVX_4mem", matmult_AVX_4mem },
        { "AVX_8", matmult_AVX_8 },
    };
    static const int nvars = (int) (sizeof(variants) / sizeof(*variants));

    srand(1234); // deterministic random tests(TM)

    // correctness tests
    // when compiled with /arch:SSE (or SSE2/AVX), all functions are
    // supposed to return the exact same results!
    for (int i=0; i < 1000000; i++)
    {
        Mat44 A, B, out, ref_out;
        randmat(A);
        randmat(B);
        matmult_ref(ref_out, A, B);
        for (int j=0; j < nvars; j++)
        {
            variants[j].matmult(out, A, B);
            if (memcmp(&out, &ref_out, sizeof(out)) != 0)
            {
                fprintf(stderr, "%s fails test\n", variants[j].name);
                exit(1);
            }
        }
    }
    printf("all ok.\n");

    // perf tests
    // as usual with such microbenchmarks, this isn't measuring anything
    // terribly useful, but here goes.
    static const struct {
        const char *name;
        void (*run)(Mat44 *out, const Mat44 *A, const Mat44 *B, int count);
    } perf_variants[] = {
        { "ref", run_ref },
        { "SSE", run_SSE },
        { "AVX_4mem", run_AVX_4mem },
        { "AVX_8", run_AVX_8 },
    };
    static const int nperfvars = (int) (sizeof(perf_variants) / sizeof(*perf_variants));
    /*
    results on my sandy bridge laptop when compiling the code in x64
    mode with VC2010 using /arch:AVX:
    all ok.
    ref: 59.00 cycles
    SSE: 20.52 cycles
    AVX_4mem: 15.64 cycles
    AVX_8: 14.13 cycles
    */
    Mat44 Aperf, Bperf, out;
    randmat(Aperf);
    randmat(Bperf);

    // Iterate over perf_variants using its own element count; the previous
    // bound (nvars) belonged to the correctness table and only worked because
    // both tables happen to have the same length.
    for (int i=0; i < nperfvars; i++)
    {
        static const int nruns = 4096;
        static const int muls_per_run = 4096;
        unsigned long long best_time = ~0ull;
        for (int run=0; run < nruns; run++)
        {
            unsigned long long time = __rdtsc();
            perf_variants[i].run(&out, &Aperf, &Bperf, muls_per_run);
            time = __rdtsc() - time;
            // keep the fastest batch to filter out interrupts and other noise
            if (time < best_time)
                best_time = time;
        }
        double cycles_per_run = (double) best_time / (double) muls_per_run;
        printf("%12s: %.2f cycles\n", perf_variants[i].name, cycles_per_run);
    }
    return 0;
}
@Lovejane-N
Copy link

Lovejane-N commented Oct 25, 2021

// Manually unrolled variant of the reference multiply: each of the 16 output
// entries is written out explicitly as
//   t.m[i][j] = A.m[i][0]*B.m[0][j] + A.m[i][1]*B.m[1][j] + A.m[i][2]*B.m[2][j] + A.m[i][3]*B.m[3][j]
// with no loops.
void matmult_ref(Mat44 &out, const Mat44 &A, const Mat44 &B)
{
Mat44 t; // write to temp
// row 0
t.m[0][0] = A.m[0][0]*B.m[0][0] + A.m[0][1]*B.m[1][0] + A.m[0][2]*B.m[2][0] + A.m[0][3]*B.m[3][0];
t.m[0][1] = A.m[0][0]*B.m[0][1] + A.m[0][1]*B.m[1][1] + A.m[0][2]*B.m[2][1] + A.m[0][3]*B.m[3][1];
t.m[0][2] = A.m[0][0]*B.m[0][2] + A.m[0][1]*B.m[1][2] + A.m[0][2]*B.m[2][2] + A.m[0][3]*B.m[3][2];
t.m[0][3] = A.m[0][0]*B.m[0][3] + A.m[0][1]*B.m[1][3] + A.m[0][2]*B.m[2][3] + A.m[0][3]*B.m[3][3];

// row 1
t.m[1][0] = A.m[1][0]*B.m[0][0] + A.m[1][1]*B.m[1][0] + A.m[1][2]*B.m[2][0] + A.m[1][3]*B.m[3][0];
t.m[1][1] = A.m[1][0]*B.m[0][1] + A.m[1][1]*B.m[1][1] + A.m[1][2]*B.m[2][1] + A.m[1][3]*B.m[3][1];
t.m[1][2] = A.m[1][0]*B.m[0][2] + A.m[1][1]*B.m[1][2] + A.m[1][2]*B.m[2][2] + A.m[1][3]*B.m[3][2];
t.m[1][3] = A.m[1][0]*B.m[0][3] + A.m[1][1]*B.m[1][3] + A.m[1][2]*B.m[2][3] + A.m[1][3]*B.m[3][3];

// row 2
t.m[2][0] = A.m[2][0]*B.m[0][0] + A.m[2][1]*B.m[1][0] + A.m[2][2]*B.m[2][0] + A.m[2][3]*B.m[3][0];
t.m[2][1] = A.m[2][0]*B.m[0][1] + A.m[2][1]*B.m[1][1] + A.m[2][2]*B.m[2][1] + A.m[2][3]*B.m[3][1];
t.m[2][2] = A.m[2][0]*B.m[0][2] + A.m[2][1]*B.m[1][2] + A.m[2][2]*B.m[2][2] + A.m[2][3]*B.m[3][2];
t.m[2][3] = A.m[2][0]*B.m[0][3] + A.m[2][1]*B.m[1][3] + A.m[2][2]*B.m[2][3] + A.m[2][3]*B.m[3][3];

// row 3
t.m[3][0] = A.m[3][0]*B.m[0][0] + A.m[3][1]*B.m[1][0] + A.m[3][2]*B.m[2][0] + A.m[3][3]*B.m[3][0];
t.m[3][1] = A.m[3][0]*B.m[0][1] + A.m[3][1]*B.m[1][1] + A.m[3][2]*B.m[2][1] + A.m[3][3]*B.m[3][1];
t.m[3][2] = A.m[3][0]*B.m[0][2] + A.m[3][1]*B.m[1][2] + A.m[3][2]*B.m[2][2] + A.m[3][3]*B.m[3][2];
t.m[3][3] = A.m[3][0]*B.m[0][3] + A.m[3][1]*B.m[1][3] + A.m[3][2]*B.m[2][3] + A.m[3][3]*B.m[3][3];

out = t;

}

Using a for-loop slows the code down significantly; that is why the -O1/-O2/-O3 optimization options can make the code faster.

@Catoverflow
Copy link

Catoverflow commented Jun 2, 2022

g++ (GCC) 12.1.0, Arch linux 5.18.1, Intel i7-9750H
g++ -mavx -O2

all ok.
         ref: 185.01 cycles
         SSE: 120.06 cycles
    AVX_4mem: 118.66 cycles
       AVX_8: 72.41 cycles

g++ -mavx -O3

all ok.
         ref: 7.39 cycles
         SSE: 12.07 cycles
    AVX_4mem: 12.44 cycles
       AVX_8: 8.17 cycles

Also, intrin.h is not included


I checked the assembly code in Compiler Explorer: with the -O2 option the ref function is compiled into AVX instructions, and the inner loop (line 17) is unrolled. With -O3 the outer loop (line 16) is also unrolled, exactly as in @Lovejane-N's comment above. Interestingly, with -O3 both the native ref and AVX_8 have no loop, but the former runs faster.


As for @rygorous's questions: I'm doing a Computer Architecture experiment, and this code is a great help!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment