Skip to content

Instantly share code, notes, and snippets.

@alexshen
Last active August 29, 2015 13:57
Show Gist options
  • Save alexshen/9783468 to your computer and use it in GitHub Desktop.
Save alexshen/9783468 to your computer and use it in GitHub Desktop.
#include <xmmintrin.h>
#include <chrono>
#include <iostream>
#include <iomanip>
using namespace std;
using namespace chrono;
// Spell "align this type to 16 bytes" in a form the current compiler accepts.
// sse_mul reads and writes each row with _mm_load_ps/_mm_store_ps, and those
// intrinsics require 16-byte-aligned addresses.
#ifdef _MSC_VER
#  define MATRIX44_ALIGN16 __declspec(align(16))
#else
#  define MATRIX44_ALIGN16 alignas(16)
#endif

// 4x4 single-precision matrix indexed as a[row][col]. Each row holds four
// floats (16 bytes). Because the whole struct is 16-byte aligned, every row
// begins on a 16-byte boundary.
struct MATRIX44_ALIGN16 matrix44
{
    float a[4][4];
};

#undef MATRIX44_ALIGN16
// Reference 4x4 matrix product m1 * m2 using a plain triple loop.
// Each result element is the dot product of row `row` of m1 and column `col`
// of m2, summed left to right (k = 0, 1, 2, 3).
matrix44 std_mul(matrix44 const& m1, matrix44 const& m2)
{
    matrix44 product;
    for (int row = 0; row < 4; ++row)
    {
        for (int col = 0; col < 4; ++col)
        {
            // Seed with the k = 0 term, then add the remaining terms in order.
            float sum = m1.a[row][0] * m2.a[0][col];
            for (int k = 1; k < 4; ++k)
                sum += m1.a[row][k] * m2.a[k][col];
            product.a[row][col] = sum;
        }
    }
    return product;
}
// 4x4 matrix product m1 * m2 with the inner dot product written out by hand
// (no k loop). The terms are accumulated in the same order as std_mul.
matrix44 unroll_mul(matrix44 const& m1, matrix44 const& m2)
{
    matrix44 out;
    for (int r = 0; r < 4; ++r)
    {
        float const* lhs = m1.a[r];  // row r of the left operand
        for (int c = 0; c < 4; ++c)
        {
            float acc = lhs[0] * m2.a[0][c];
            acc += lhs[1] * m2.a[1][c];
            acc += lhs[2] * m2.a[2][c];
            acc += lhs[3] * m2.a[3][c];
            out.a[r][c] = acc;
        }
    }
    return out;
}
// 4x4 matrix product m1 * m2 using SSE, computing one result row at a time:
//   out row i = m1[i][0]*m2row0 + m1[i][1]*m2row1 + m1[i][2]*m2row2 + m1[i][3]*m2row3
// Each scalar m1[i][k] is broadcast to all four lanes and multiplied by the
// whole row k of m2. Aligned loads/stores are used, so every row must be
// 16-byte aligned (matrix44 is declared with 16-byte alignment).
matrix44 sse_mul(matrix44 const& m1, matrix44 const& m2)
{
    // The rows of m2 do not depend on i, so load them once up front.
    __m128 const m2row0 = _mm_load_ps(m2.a[0]);
    __m128 const m2row1 = _mm_load_ps(m2.a[1]);
    __m128 const m2row2 = _mm_load_ps(m2.a[2]);
    __m128 const m2row3 = _mm_load_ps(m2.a[3]);

    matrix44 out;
    for (int i = 0; i < 4; ++i)
    {
        __m128 acc = _mm_mul_ps(_mm_set1_ps(m1.a[i][0]), m2row0);
        acc = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(m1.a[i][1]), m2row1), acc);
        acc = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(m1.a[i][2]), m2row2), acc);
        acc = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(m1.a[i][3]), m2row3), acc);
        _mm_store_ps(out.a[i], acc);
    }
    return out;
}
// Scoped timer. It records the current time when constructed. When it goes
// out of scope it prints "<name>: <elapsed nanoseconds>" to cout, with the
// name right-aligned in a 10-character field.
//
// Uses steady_clock rather than high_resolution_clock. The latter may be an
// alias for system_clock, which is not monotonic (it can be adjusted while a
// measurement is running) and is therefore unsuitable for timing intervals.
struct timer
{
    char const* name;  // not owned; must stay valid until the destructor runs
    steady_clock::time_point start;

    explicit timer(char const* p)
        : name{ p }
        , start{ steady_clock::now() }
    {
    }

    // A timer is tied to one scope. A copy would report the same
    // measurement a second time, so copying is disallowed.
    timer(timer const&) = delete;
    timer& operator=(timer const&) = delete;

    ~timer()
    {
        // Take the end timestamp before doing any output work.
        auto const elapsed = steady_clock::now() - start;
        cout << setw(10) << name << ": " << duration_cast<nanoseconds>(elapsed).count() << endl;
    }
};
template<typename F>
void time_it(char const* name, F f, int count)
{
matrix44 res;
{
timer t{ name };
matrix44 a =
{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1
};
for (int i = 0; i < count; ++i)
{
res = f(a, a);
}
}
cout << res.a[0][0] << endl;
}
// Benchmark driver. Each pass times every multiply routine over `count`
// calls, and the whole set is repeated for three passes (presumably so the
// later passes run warm; confirm intent).
int main()
{
    constexpr int count = 1000000;
    for (int pass = 0; pass != 3; ++pass)
    {
        time_it("std_mul", std_mul, count);
        time_it("unroll_mul", unroll_mul, count);
        time_it("sse_mul", sse_mul, count);
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment