Skip to content

Instantly share code, notes, and snippets.

@kevans91
Created May 11, 2016 13:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kevans91/d7213535353b5249705c6fddf199e67b to your computer and use it in GitHub Desktop.
Save kevans91/d7213535353b5249705c6fddf199e67b to your computer and use it in GitHub Desktop.
/* Human-readable label for this dgemm implementation. It is not referenced
 * anywhere else in this file; presumably an external driver prints it. */
const char *dgemm_desc = "Simple blocked dgemm.";
#include "immintrin.h"
#include "stdio.h"
#include "string.h"
#include "stdlib.h"
/* Smaller of two values. This is a function-like macro, so the argument that
 * wins the comparison is evaluated twice: pass only side-effect-free
 * expressions (never something like min(i++, j)). */
#define min(a, b) ((a) < (b) ? (a) : (b))
/* This auxiliary subroutine performs a smaller dgemm operation
 *   C := C + A * B
 * where C is M-by-N, A is M-by-K, and B is K-by-N. All three blocks are
 * addressed column-major with the same column stride: element (r, c) of a
 * block X is X[r + c * lda].
 *
 * Parameters:
 *   lda     - distance, in doubles, between consecutive columns of A, B and C
 *   M, N, K - block dimensions; if any of them is <= 0 the call changes nothing
 *   A, B    - input blocks, only read
 *   C       - output block, updated in place
 *
 * For every C(i, j) the K products are summed left to right and the total is
 * then added to C(i, j) in one step. Dimensions larger than 4 are handled in
 * full rather than being cut off at 4. */
static inline void do_block(int lda, int M, int N, int K, double* A, double* B, double* C) {
    /* Empty inner dimension: there is nothing to accumulate, leave C alone. */
    if (K <= 0) {
        return;
    }

    for (int i = 0; i < M; ++i) {
        const double *a_row = A + i;            /* A(i, 0); A(i, k) is a_row[k * lda] */

        for (int j = 0; j < N; ++j) {
            const double *b_col = B + j * lda;  /* B(0, j); B(k, j) is b_col[k] */
            double *c_ij = C + i + j * lda;

            /* Seed with the k = 0 product so that no extra "+ 0.0" is
             * introduced into the sum. */
            double dot = a_row[0] * b_col[0];
            for (int k = 1; k < K; ++k) {
                dot += a_row[k * lda] * b_col[k];
            }

            *c_ij += dot;
        }
    }
}
/* The macros below are building blocks for the AVX kernels. They are not
 * self-contained: every expansion site must have the names n, A, B, C and the
 * __m256d variables Acol, Bcol and Ccol in scope. Macro arguments are pasted
 * into arithmetic without parentheses, so pass plain literals or identifiers
 * only (an argument such as "k+1" would be mis-grouped). */

/* One rank-1 step on a 4-row column accumulator:
 *   Acol = the 4 doubles starting at A + j*n      (unaligned load)
 *   Bcol = the double at B + j + i*n, copied into all 4 lanes
 *   Ccol = Ccol + Acol * Bcol                     (lane-wise) */
#define COL(i,j) do {\
Acol = _mm256_loadu_pd(A + (j*n));\
Bcol = _mm256_broadcast_sd(B+j+(i*n));\
Ccol = _mm256_add_pd(Ccol,_mm256_mul_pd(Acol, Bcol));\
} while(0)
/* Updates the 4 doubles at C + i*n with four COL steps (j = 0..3): loads them
 * into Ccol, accumulates, and stores the result back to the same address. */
#define SQ(i) do {\
register double * Cin = C + (i * n);\
Ccol = _mm256_loadu_pd(Cin);\
COL(i,0);\
COL(i,1);\
COL(i,2);\
COL(i,3);\
_mm256_storeu_pd(Cin, Ccol);\
} while(0)
/* Same load / accumulate / store sequence as SQ(i), with a single COL step
 * (j = 0). */
#define ONE_WIDE(i) do {\
register double * Cin = C + (i * n);\
Ccol = _mm256_loadu_pd(Cin);\
COL(i,0);\
_mm256_storeu_pd(Cin, Ccol);\
} while(0)
/* Same load / accumulate / store sequence as SQ(i), with two COL steps
 * (j = 0, 1). */
#define TWO_WIDE(i) do {\
register double * Cin = C + (i * n);\
Ccol = _mm256_loadu_pd(Cin);\
COL(i,0);\
COL(i,1);\
_mm256_storeu_pd(Cin, Ccol);\
} while(0)
/* Same load / accumulate / store sequence as SQ(i), with three COL steps
 * (j = 0, 1, 2). */
#define THREE_WIDE(i) do {\
register double * Cin = C + (i * n);\
Ccol = _mm256_loadu_pd(Cin);\
COL(i,0);\
COL(i,1);\
COL(i,2);\
_mm256_storeu_pd(Cin, Ccol);\
} while(0)
/* Token-pastes `type` onto _WIDE (type must be ONE, TWO or THREE) and expands
 * the resulting macro for i = 0, 1, 2, 3. */
#define DO_WIDE(type) do {\
type##_WIDE(0);\
type##_WIDE(1);\
type##_WIDE(2);\
type##_WIDE(3);\
} while(0)
/* Full 4x4x4 AVX kernel: adds the product of the 4x4 block at A and the 4x4
 * block at B to the 4x4 block at C. Each block is addressed with column
 * stride n (element (r, c) at X[r + c * n]); loads and stores are unaligned.
 * Every SQ(c) expansion updates one 4-row column c of the C block. */
static inline void do_blockMNK(int n, double *A, double *B, double *C) {
    /* Scratch vectors. The SQ/COL macros refer to these, and to n, A, B and C,
     * by name, so the identifiers must stay exactly as they are. */
    __m256d Acol, Bcol, Ccol;

    SQ(0);
    SQ(1);
    SQ(2);
    SQ(3);
}
/* AVX kernel for a 4x4 block of C when the inner dimension is short:
 * C(0:3, 0:3) += A(0:3, 0:K-1) * B(0:K-1, 0:3), blocks addressed with column
 * stride n.
 *
 * K <= 0 does nothing; K == 1 and K == 2 use exactly that many inner steps;
 * any K >= 3 is processed as 3 inner steps. */
static inline void do_blockMN(int n, int K, double *A, double *B, double *C) {
    /* Scratch vectors named as the *_WIDE / COL macros expect. */
    __m256d Acol, Bcol, Ccol;

    switch (K) {
    case 1:
        DO_WIDE(ONE);
        break;
    case 2:
        DO_WIDE(TWO);
        break;
    default:
        /* Everything above 2 takes the three-step path; zero and negative
         * values fall through without touching C. */
        if (K > 2) {
            DO_WIDE(THREE);
        }
        break;
    }
}
/* AVX kernel for a 4-row strip of C that is fewer than four columns wide:
 * C(0:3, 0:N-1) += A(0:3, 0:3) * B(0:3, 0:N-1), blocks addressed with column
 * stride n.
 *
 * N <= 0 does nothing; N == 1 and N == 2 update that many columns; any N >= 3
 * updates three columns. */
static inline void do_blockMK(int n, int N, double *A, double *B, double *C) {
    /* Scratch vectors named as the SQ / COL macros expect. */
    __m256d Acol, Bcol, Ccol;

    /* Columns are handled in order; stop as soon as N is exhausted. */
    if (N < 1) {
        return;
    }
    SQ(0);

    if (N < 2) {
        return;
    }
    SQ(1);

    if (N < 3) {
        return;
    }
    SQ(2);
}
/* AVX kernel for the corner where both the column count and the inner
 * dimension are short: C(0:3, 0:X-1) += A(0:3, 0:X-1) * B(0:X-1, 0:X-1),
 * blocks addressed with column stride n.
 *
 * X <= 0 does nothing; X == 1 and X == 2 are handled exactly; any X >= 3 is
 * processed as 3. */
static inline void do_blockM(int n, int X, double *A, double *B, double *C) {
    /* Scratch vectors named as the COL macro expects. */
    __m256d Acol, Bcol, Ccol;

    /* Both the number of columns and the number of inner steps are X, capped
     * at three. */
    const int dim = (X > 3) ? 3 : X;

    for (int col = 0; col < dim; ++col) {
        double *c_col = C + col * n;

        Ccol = _mm256_loadu_pd(c_col);
        COL(col, 0);
        if (dim > 1) {
            COL(col, 1);
        }
        if (dim > 2) {
            COL(col, 2);
        }
        _mm256_storeu_pd(c_col, Ccol);
    }
}
/* This routine performs a dgemm operation
 *   C := C + A * B
 * where A, B, and C are lda-by-lda matrices stored in column-major format.
 * On exit, A and B maintain their input values.
 *
 * Parameters:
 *   lda - order of the matrices; a value <= 0 returns without touching C
 *   _A  - left operand, lda * lda doubles, only read
 *   _B  - right operand, lda * lda doubles, only read
 *   _C  - matrix updated in place, lda * lda doubles
 *
 * The matrix is tiled with 4x4 blocks handled by the AVX kernels; the
 * lda % 4 leftover rows go through the scalar do_block, and leftover columns
 * and inner-dimension steps through the narrower AVX kernels.
 *
 * The kernels read private copies of A and B. If a copy cannot be allocated
 * the caller's matrix is read directly instead; the kernels never write to A
 * or B, so the result is the same unless C overlaps that input.
 * NOTE(review): the copies look unnecessary apart from that overlap case --
 * confirm before removing them. */
void square_dgemm (int lda, double* _A, double* _B, double* _C) {
    /* With lda == 0 every column stride is 0 and the column loops would never
     * advance; a negative order has no meaning. */
    if (lda <= 0) {
        return;
    }

    /* Offsets are kept in size_t so that lda * lda and the byte count cannot
     * overflow int for large matrices. */
    const size_t ld = (size_t)lda;
    const int rem = lda & 3;                      /* lda % 4 */
    const int full = lda - rem;                   /* extent covered by whole 4x4 tiles */
    const size_t edge_col = (size_t)full * ld;    /* offset of column `full` */

    double *A_copy = NULL;
    double *B_copy = NULL;
    /* Only attempt the copies when ld * ld * sizeof(double) fits in size_t. */
    if (ld <= ((size_t)-1 / sizeof(double)) / ld) {
        const size_t bytes = ld * ld * sizeof(double);

        A_copy = malloc(bytes);
        if (A_copy != NULL) {
            memcpy(A_copy, _A, bytes);
        }
        B_copy = malloc(bytes);
        if (B_copy != NULL) {
            memcpy(B_copy, _B, bytes);
        }
    }
    double *A = (A_copy != NULL) ? A_copy : _A;
    double *B = (B_copy != NULL) ? B_copy : _B;
    double *C = _C;

    /* Rows covered by whole tiles, four at a time. */
    for (int i = 0; i < full; i += 4) {
        double *Ai = A + i;
        double *Ci = C + i;

        /* Whole 4x4 tiles of C. */
        for (int j = 0; j < full; j += 4) {
            const size_t jcol = (size_t)j * ld;

            for (int k = 0; k < full; k += 4) {
                do_blockMNK(lda, Ai + (size_t)k * ld, B + k + jcol, Ci + jcol);
            }
            /* Leftover inner-dimension steps for this tile. */
            if (rem != 0) {
                do_blockMN(lda, rem, Ai + edge_col, B + full + jcol, Ci + jcol);
            }
        }

        /* Leftover columns of C for this row strip. Edge pointers are formed
         * only when there is an edge, so they always stay inside the
         * matrices. */
        if (rem != 0) {
            for (int k = 0; k < full; k += 4) {
                do_blockMK(lda, rem, Ai + (size_t)k * ld, B + edge_col + k, Ci + edge_col);
            }
            do_blockM(lda, rem, Ai + edge_col, B + edge_col + full, Ci + edge_col);
        }
    }

    /* Leftover rows (fewer than four) are handled by the scalar kernel. */
    if (rem != 0) {
        double *A_bottom = A + full;
        double *C_bottom = C + full;

        for (int j = 0; j < full; j += 4) {
            const size_t jcol = (size_t)j * ld;

            for (int k = 0; k < full; k += 4) {
                do_block(lda, rem, 4, 4, A_bottom + (size_t)k * ld, B + k + jcol, C_bottom + jcol);
            }
            do_block(lda, rem, 4, rem, A_bottom + edge_col, B + full + jcol, C_bottom + jcol);
        }

        /* Bottom-right corner: leftover rows and leftover columns. */
        for (int k = 0; k < full; k += 4) {
            do_block(lda, rem, rem, 4, A_bottom + (size_t)k * ld, B + edge_col + k, C_bottom + edge_col);
        }
        do_block(lda, rem, rem, rem, A_bottom + edge_col, B + edge_col + full, C_bottom + edge_col);
    }

    free(A_copy);
    free(B_copy);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment