Skip to content

Instantly share code, notes, and snippets.

@kaja47
Last active December 1, 2016 02:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kaja47/241145603bbebf8766f16da943ce2bdd to your computer and use it in GitHub Desktop.
Save kaja47/241145603bbebf8766f16da943ce2bdd to your computer and use it in GitHub Desktop.
matrix multiplication
/*
 * Horizontal sum of the eight float lanes of x.
 *
 * _mm256_hadd_ps adds adjacent pairs within each 128-bit half, so after two
 * passes lane 0 holds x0+x1+x2+x3 and lane 4 holds x4+x5+x6+x7; the two are
 * then added. The lanes are read with cast/extract/cvtss intrinsics instead of
 * reinterpreting &x as a float pointer, which avoids pointer type-punning and
 * lets x stay in registers. The addition order (and therefore the rounding) is
 * the same as reading ((float *)&x)[0] + ((float *)&x)[4].
 */
static inline __attribute__((always_inline)) float hadd(__m256 x) {
    x = _mm256_hadd_ps(x, x);
    x = _mm256_hadd_ps(x, x);
    const __m128 lo = _mm256_castps256_ps128(x);   /* lanes 0..3 */
    const __m128 hi = _mm256_extractf128_ps(x, 1); /* lanes 4..7 */
    return _mm_cvtss_f32(lo) + _mm_cvtss_f32(hi);
}
/* Extra floats appended to every row of a and b: their row stride is len + PAD.
 * NOTE(review): presumably chosen to avoid power-of-two row strides while
 * keeping rows 32-byte aligned for _mm256_load_ps — confirm with the code that
 * allocates the matrices. */
#define PAD 16

/*
 * res = a * b^T for len x len float matrices, tiled at three levels
 * (TILE3 > TILE2 > TILE1), vectorised with AVX, and parallelised with OpenMP
 * over bands of TILE3 rows of res.
 *
 * Memory layout, as indexed below:
 *   a, b  row-major with row stride len + PAD floats. res[i][j] is the dot
 *         product of row i of a with row j of b, so b must be passed already
 *         transposed to obtain the ordinary product a * b.
 *   res   row-major with row stride len (no padding): element (i, j) is
 *         res[i*len + j]. Must hold at least len*len floats.
 *
 * TILE1, TILE2, TILE3 and S (length of each dot-product segment) are
 * compile-time macros defined outside this function.
 *
 * Assumptions the loop bounds rely on (none are checked):
 *   - len is a multiple of TILE3 and of S; TILE3 is a multiple of TILE2;
 *     TILE2 is a multiple of 2*TILE1; TILE1 is even; S is a multiple of 8.
 *   - a and b are 32-byte aligned and len + PAD is a multiple of 8, because
 *     _mm256_load_ps requires 32-byte-aligned addresses.
 */
void square_mat_mul_tiered(float *a, float *b, const int len, float *res) {
    const int stride = len + PAD; /* row stride of a and b, in floats */

    /* The final stores only touch res[i*len + j] for i, j < len, so clear
     * exactly len*len floats; clearing len*(len + PAD) would write len*PAD
     * floats past the end of a len*len result buffer.
     * NOTE(review): if res was meant to use the padded stride instead, it is
     * the store at the bottom of this function that should change. */
    for (int idx = 0; idx < len * len; idx++) {
        res[idx] = 0;
    }

    const int tile1 = TILE1;
    const int tile2 = TILE2;
    const int tile3 = TILE3;
    const int segment = S;
    /* Each innermost tile spans 2*tile1 rows of a and tile1 rows of b. */
    const int tile_rows = tile1 * 2;

#pragma omp parallel for
    for (int tilei3 = 0; tilei3 < len; tilei3 += tile3) {
        /* Vector accumulators for one innermost tile: sums[ii*tile1 + jj]
         * collects partial dot products for res[tilei1+ii][tilej1+jj].
         * Declared inside the parallel loop, so each thread has its own. */
        __m256 sums[tile_rows * tile1];

        for (int tilej3 = 0; tilej3 < len; tilej3 += tile3) {
            for (int tilei2 = tilei3; tilei2 < tilei3 + tile3; tilei2 += tile2) {
                for (int tilej2 = tilej3; tilej2 < tilej3 + tile3; tilej2 += tile2) {
                    for (int tilei1 = tilei2; tilei1 < tilei2 + tile2; tilei1 += tile_rows) {
                        for (int tilej1 = tilej2; tilej1 < tilej2 + tile2; tilej1 += tile1) {
                            for (int s = 0; s < tile_rows * tile1; s++) {
                                sums[s] = _mm256_setzero_ps();
                            }

                            for (int p = 0; p < len; p += segment) {
                                /* 2x2 register block: rows i, i+1 of a against
                                 * rows j, j+1 of b over columns [p, p+segment). */
                                for (int i = tilei1; i < tilei1 + tile_rows; i += 2) {
                                    for (int j = tilej1; j < tilej1 + tile1; j += 2) {
                                        const int ii = i - tilei1;
                                        const int jj = j - tilej1;
                                        const float *a0 = a + i * stride + p;
                                        const float *a1 = a + (i + 1) * stride + p;
                                        const float *b0 = b + j * stride + p;
                                        const float *b1 = b + (j + 1) * stride + p;

                                        __m256 acc00 = sums[ii * tile1 + jj];
                                        __m256 acc01 = sums[ii * tile1 + jj + 1];
                                        __m256 acc10 = sums[(ii + 1) * tile1 + jj];
                                        __m256 acc11 = sums[(ii + 1) * tile1 + jj + 1];

                                        /* Loaded vectors get their own names (va*, vb*) so
                                         * they can never shadow the accumulators. */
                                        for (int k = 0; k < segment; k += 8) {
                                            const __m256 va0 = _mm256_load_ps(a0 + k);
                                            const __m256 va1 = _mm256_load_ps(a1 + k);
                                            const __m256 vb0 = _mm256_load_ps(b0 + k);
                                            const __m256 vb1 = _mm256_load_ps(b1 + k);
                                            acc00 = _mm256_add_ps(acc00, _mm256_mul_ps(va0, vb0));
                                            acc01 = _mm256_add_ps(acc01, _mm256_mul_ps(va0, vb1));
                                            acc10 = _mm256_add_ps(acc10, _mm256_mul_ps(va1, vb0));
                                            acc11 = _mm256_add_ps(acc11, _mm256_mul_ps(va1, vb1));
                                        }

                                        sums[ii * tile1 + jj] = acc00;
                                        sums[ii * tile1 + jj + 1] = acc01;
                                        sums[(ii + 1) * tile1 + jj] = acc10;
                                        sums[(ii + 1) * tile1 + jj + 1] = acc11;
                                    }
                                }
                            }

                            /* Reduce each vector accumulator to its scalar result. */
                            for (int ii = 0; ii < tile_rows; ii++) {
                                for (int jj = 0; jj < tile1; jj++) {
                                    res[(tilei1 + ii) * len + (tilej1 + jj)] =
                                        hadd(sums[ii * tile1 + jj]);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment