Last active
August 20, 2022 11:59
-
-
Save Catoverflow/e2c0976e7b6f477e722c738479f583f0 to your computer and use it in GitHub Desktop.
Accelerate matrix multiplication with AVX and tiling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Credit: the AVX part is based on https://chryswoods.com/vector_c++/immintrin.html
 * and https://gist.github.com/rygorous/4172889
 * Documentation: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
 */
#include <immintrin.h> | |
#include <iostream> | |
#include <stdlib.h> | |
#include <string.h> //memcmp | |
#include <time.h> | |
// Matrix dimension: every matrix in this file is N x N floats.
#define N (1 << 8)
// Number of 8-float __m256 vectors in one matrix row.
// Derived from N so that the AVX loops always cover exactly one full row.
#define AVX_N (N >> 3)
// Number of output columns processed per tile in baseline_AVX_tiling.
#define TILE_BLOCK_SIZE (1 << 5)
// Upper bound of the random values written by randomize().
// If MAX_FLOAT is large, the precision of the results will be limited.
#define MAX_FLOAT 1
// Square Dim x Dim float matrix that can also be addressed as rows of
// 8-float AVX vectors. Both members overlay the same storage, so switching
// between the scalar and the vector view needs no copy.
// The vector view holds Dim / 8 (rounded down) __m256 values per row.
// NOTE(review): writing through one member and reading through the other
// relies on compiler support for union type punning.
template <unsigned Dim>
union Matrix
{
    float f[Dim][Dim];      // scalar view: f[row][column]
    __m256 m[Dim][Dim / 8]; // vector view: m[row][8-float group]
};
// One AVX register viewed either as a whole (__m256) or as its
// individual single-precision lanes.
union Vec_8
{
    float f[sizeof(__m256) / sizeof(float)]; // the 8 float lanes
    __m256 m;                                // the same bits as one vector
};
//*** tools ***// | |
inline float sum(const Vec_8 &T) | |
{ | |
float t = 0; | |
for (int i = 0; i < 8; i++) | |
t += T.f[i]; | |
return t; | |
} | |
bool verify(const Matrix<N> &A, const Matrix<N> &B) | |
{ | |
for (int i = 0; i < N; i++) | |
if (memcmp(A.f[i], B.f[i], N)) | |
return false; | |
return true; | |
} | |
// Fills every element of T with a pseudo-random float obtained by dividing
// rand() by (RAND_MAX / MAX_FLOAT). Elements are written in row-major order,
// one rand() call per element.
void randomize(Matrix<N> &T)
{
    // The divisor is computed with integer division first and only then
    // converted to float.
    const float divisor = static_cast<float>(RAND_MAX / MAX_FLOAT);
    for (auto &row : T.f)
        for (float &cell : row)
            cell = rand() / divisor;
}
void transpose(const Matrix<N> &T, Matrix<N> &T_t) | |
{ | |
for (int i = 0; i < N; i++) | |
for (int j = 0; j < N; j++) | |
T_t.f[i][j] = T.f[j][i]; | |
} | |
void transpose_8(const Matrix<8> &T, Matrix<8> &T_t) | |
{ | |
// note that set ps take input backwards | |
for (int i = 0; i < 8; i++) | |
T_t.m[i][0] = _mm256_set_ps(T.f[7][i], T.f[6][i], T.f[5][i], T.f[4][i], | |
T.f[3][i], T.f[2][i], T.f[1][i], T.f[0][i]); | |
} | |
//*** baseline main part ***// | |
inline void baseline(const Matrix<N> &A, const Matrix<N> &B, Matrix<N> &C) | |
{ | |
for (int i = 0; i < N; i++) | |
for (int j = 0; j < N; j++) | |
{ | |
C.f[i][j] = 0; | |
for (int k = 0; k < N; k++) | |
C.f[i][j] += A.f[i][k] * B.f[k][j]; | |
} | |
} | |
inline void baseline_AVX(const Matrix<N> &A, const Matrix<N> &B_t, Matrix<N> &C) | |
{ | |
Vec_8 tmp_sum; | |
float zero = 0; | |
for (int i = 0; i < N; i++) | |
for (int j = 0; j < N; j++) | |
{ | |
tmp_sum.m = _mm256_broadcast_ss(&zero); | |
for (int k = 0; k < AVX_N; k++) | |
tmp_sum.m = _mm256_add_ps(tmp_sum.m, _mm256_mul_ps(A.m[i][k], B_t.m[j][k])); | |
C.f[i][j] = sum(tmp_sum); | |
} | |
} | |
inline void baseline_AVX_tiling(const Matrix<N> &A, const Matrix<N> &B_t, Matrix<N> &C) | |
{ | |
Vec_8 tmp_sum; | |
float zero = 0; | |
for (int j = 0; j < N; j += TILE_BLOCK_SIZE) | |
for (int i = 0; i < N; i++) | |
for (int j_ = 0; j_ < TILE_BLOCK_SIZE; j_++) | |
{ | |
tmp_sum.m = _mm256_broadcast_ss(&zero); | |
for (int k = 0; k < AVX_N; k++) | |
tmp_sum.m = _mm256_add_ps(tmp_sum.m, _mm256_mul_ps(A.m[i][k], B_t.m[j + j_][k])); | |
C.f[i][j + j_] = sum(tmp_sum); | |
} | |
} | |
int main() | |
{ | |
static Matrix<N> A, B, C, D, E, B_t; | |
size_t t0, t1, t2, t3; | |
srand(time(NULL)); | |
randomize(A), randomize(B); | |
t0 = __rdtsc(); | |
baseline(A, B, C); | |
t1 = __rdtsc(); | |
transpose(B, B_t); | |
baseline_AVX(A, B_t, D); | |
t2 = __rdtsc(); | |
transpose(B, B_t); | |
baseline_AVX_tiling(A, B_t, E); | |
t3 = __rdtsc(); | |
// note: this verify will usually return false(not passed) | |
// because float arithmetics are not associative, different methods return slightly different result | |
std::cout << verify(C, D) << std::endl; | |
std::cout << verify(C, E) << std::endl; | |
std::cout << "Naive:\t" << (t1 - t0) << " cycles" << std::endl; | |
std::cout << "AVX:\t" << (t2 - t1) << " cycles" << std::endl; | |
std::cout << "AVX_tiled:\t" << (t3 - t2) << " cycles" << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment