-
-
Save kevans91/d7213535353b5249705c6fddf199e67b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* Human-readable description of this dgemm implementation (presumably
 * reported by an external benchmark harness -- not visible here). */
char const *dgemm_desc = "Simple blocked dgemm.";
#include "immintrin.h" | |
#include "stdio.h" | |
#include "string.h" | |
#include "stdlib.h" | |
/* Smaller of two values (returns b when they compare equal).  Each argument
 * may be evaluated twice, so avoid passing expressions with side effects.
 * Not referenced by the code visible in this file. */
#define min(a, b) ((a) < (b) ? (a) : (b))
/* Scalar macros for blocks of at most 4x4x4 elements (C += A * B, all
 * column-major).  They expand in the caller's scope and use these names,
 * which must be declared there:
 *   lda        - leading dimension (column stride) of A, B and C
 *   N, K       - column count of C and inner (depth) dimension; any value
 *                >= 4 is treated as 4, values <= 0 contribute nothing
 *   A, B, C    - origins of the current blocks
 *   Ai, Bj, Ci - scratch pointers, assigned by NEXT_C / NEXT_CI
 * Macro arguments are parenthesized, so any integer expression may be
 * passed as a row or column index. */

/* NEXT_CIJn(): cij += A(i,0)*B(0,j) + ... + A(i,n-1)*B(n-1,j).  Ai walks
 * row i of A (stride lda); Bj walks column j of B (stride 1).  The products
 * are summed first and the sum is then added to cij. */
#define NEXT_CIJ1() do {\
        cij += *Ai * *Bj;\
    } while (0)
#define NEXT_CIJ2() do {\
        cij += *Ai * *Bj\
             + *(Ai + lda) * *(Bj + 1);\
    } while (0)
#define NEXT_CIJ3() do {\
        cij += *Ai * *Bj\
             + *(Ai + lda) * *(Bj + 1)\
             + *(Ai + lda + lda) * *(Bj + 2);\
    } while (0)
#define NEXT_CIJ4() do {\
        cij += *Ai * *Bj\
             + *(Ai + lda) * *(Bj + 1)\
             + *(Ai + lda + lda) * *(Bj + 2)\
             + *(Ai + lda + lda + lda) * *(Bj + 3);\
    } while (0)

/* C(i,j) += A(i,0..K-1) . B(0..K-1,j) for the row i that NEXT_C selected
 * (Ci and Ai must already point at row i). */
#define NEXT_CI(j) do {\
        register int jlda = (j) * lda;\
        register double cij = *(Ci + jlda);\
        Bj = (B + jlda);\
        if (K <= 1) {\
            if (K == 1) {\
                NEXT_CIJ1();\
            }\
        } else if (K == 2) {\
            NEXT_CIJ2();\
        } else if (K == 3) {\
            NEXT_CIJ3();\
        } else {\
            NEXT_CIJ4();\
        }\
        *(Ci + jlda) = cij;\
    } while (0)

/* Row i of C: apply NEXT_CI to columns 0..N-1 (N capped at 4). */
#define NEXT_C(i) do {\
        Ci = &C[(i)];\
        Ai = &A[(i)];\
        if (N <= 1) {\
            if (N == 1) {\
                NEXT_CI(0);\
            }\
        } else if (N == 2) {\
            NEXT_CI(0);\
            NEXT_CI(1);\
        } else if (N == 3) {\
            NEXT_CI(0);\
            NEXT_CI(1);\
            NEXT_CI(2);\
        } else {\
            NEXT_CI(0);\
            NEXT_CI(1);\
            NEXT_CI(2);\
            NEXT_CI(3);\
        }\
    } while (0)
/* This auxiliary subroutine performs a smaller dgemm operation
 *     C := C + A * B
 * where C is M-by-N, A is M-by-K, and B is K-by-N, all stored column-major
 * with leading dimension (column stride) lda.  Any M, N, K >= 0 is handled;
 * non-positive dimensions make the call a no-op.  A and B are only read.
 *
 * Each C(i,j) update forms the full dot product first and adds it to C(i,j)
 * once, which keeps the floating-point summation order of the earlier
 * hand-unrolled (<= 4) version. */
static inline void do_block(int lda, int M, int N, int K,
                            const double *A, const double *B, double *C) {
    if (K <= 0) {
        return; /* empty inner dimension: nothing to accumulate */
    }
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            const double *Bj = B + j * lda; /* column j of B */
            double dot = A[i] * Bj[0];
            for (int k = 1; k < K; ++k) {
                dot += A[i + k * lda] * Bj[k];
            }
            C[i + j * lda] += dot;
        }
    }
}
/* AVX building blocks.  These macros expand in the caller's scope and use:
 *   n                - leading dimension (column stride) of A, B and C
 *   A, B, C          - block origins (column-major)
 *   Acol, Bcol, Ccol - __m256d scratch values
 * Each vector holds four consecutive rows of one column (unaligned
 * loads/stores).  Arguments are parenthesized so any integer expression may
 * be passed. */

/* Ccol += A(0..3, j) * B(j, i): column j of A scaled by the scalar B(j, i). */
#define COL(i, j) do {\
        Acol = _mm256_loadu_pd(A + ((j) * n));\
        Bcol = _mm256_broadcast_sd(B + (j) + ((i) * n));\
        Ccol = _mm256_add_pd(Ccol, _mm256_mul_pd(Acol, Bcol));\
    } while (0)

/* C(0..3, i) += A(0..3, 0..3) * B(0..3, i). */
#define SQ(i) do {\
        register double * Cin = C + ((i) * n);\
        Ccol = _mm256_loadu_pd(Cin);\
        COL((i), 0);\
        COL((i), 1);\
        COL((i), 2);\
        COL((i), 3);\
        _mm256_storeu_pd(Cin, Ccol);\
    } while (0)

/* C(0..3, i) += A(0..3, 0..d-1) * B(0..d-1, i) for depth d = 1, 2, 3. */
#define ONE_WIDE(i) do {\
        register double * Cin = C + ((i) * n);\
        Ccol = _mm256_loadu_pd(Cin);\
        COL((i), 0);\
        _mm256_storeu_pd(Cin, Ccol);\
    } while (0)
#define TWO_WIDE(i) do {\
        register double * Cin = C + ((i) * n);\
        Ccol = _mm256_loadu_pd(Cin);\
        COL((i), 0);\
        COL((i), 1);\
        _mm256_storeu_pd(Cin, Ccol);\
    } while (0)
#define THREE_WIDE(i) do {\
        register double * Cin = C + ((i) * n);\
        Ccol = _mm256_loadu_pd(Cin);\
        COL((i), 0);\
        COL((i), 1);\
        COL((i), 2);\
        _mm256_storeu_pd(Cin, Ccol);\
    } while (0)

/* Apply ONE_WIDE / TWO_WIDE / THREE_WIDE (selected by token pasting) to
 * columns 0..3 of C. */
#define DO_WIDE(type) do {\
        type##_WIDE(0);\
        type##_WIDE(1);\
        type##_WIDE(2);\
        type##_WIDE(3);\
    } while (0)
/* Full 4x4x4 AVX kernel:
 *     C(0..3, 0..3) += A(0..3, 0..3) * B(0..3, 0..3)
 * with all three blocks column-major and leading dimension n.  Each SQ(col)
 * keeps one column of C in a vector while the four columns of A, scaled by
 * B(k, col), are accumulated into it.  A and B are only read. */
static inline void do_blockMNK(int n, const double *A, const double *B, double *C) {
    register __m256d Acol;
    register __m256d Bcol;
    register __m256d Ccol;
    SQ(0);
    SQ(1);
    SQ(2);
    SQ(3);
}
inline static void do_blockMN(int n, int K, double*A,double*B, double*C) { | |
register __m256d Acol; | |
register __m256d Bcol; | |
register __m256d Ccol; | |
if (K <= 1) { | |
if (K == 1) { | |
DO_WIDE(ONE); | |
} | |
} else if (K == 2) { | |
DO_WIDE(TWO); | |
} else { | |
DO_WIDE(THREE); | |
} | |
} | |
/* 4-row strip of C with only N leftover columns:
 *     C(0..3, 0..N-1) += A(0..3, 0..3) * B(0..3, 0..N-1)
 * N is the leftover column count (square_dgemm passes lda % 4).  N <= 0
 * does nothing and N > 3 is treated as 3.  A and B are only read. */
static inline void do_blockMK(int n, int N, const double *A, const double *B, double *C) {
    register __m256d Acol;
    register __m256d Bcol;
    register __m256d Ccol;
    /* Columns are processed in order 0, 1, 2 up to N. */
    if (N >= 1) {
        SQ(0);
    }
    if (N >= 2) {
        SQ(1);
    }
    if (N >= 3) {
        SQ(2);
    }
}
/* Right-hand corner of a 4-row strip, where both the column count and the
 * inner dimension are the leftover size X:
 *     C(0..3, 0..X-1) += A(0..3, 0..X-1) * B(0..X-1, 0..X-1)
 * square_dgemm passes X = lda % 4.  X <= 0 does nothing and X > 3 is
 * treated as 3.  A and B are only read. */
static inline void do_blockM(int n, int X, const double *A, const double *B, double *C) {
    register __m256d Acol;
    register __m256d Bcol;
    register __m256d Ccol;
    const int dim = X < 3 ? X : 3;
    for (int col = 0; col < dim; ++col) {
        double *const Cin = C + col * n;
        Ccol = _mm256_loadu_pd(Cin);
        for (int k = 0; k < dim; ++k) {
            COL(col, k); /* Ccol += A(0..3, k) * B(k, col) */
        }
        _mm256_storeu_pd(Cin, Ccol);
    }
}
/* This routine performs a dgemm operation
 *     C := C + A * B
 * where A, B, and C are lda-by-lda matrices stored in column-major format.
 * On exit, A and B maintain their input values.
 *
 * The matrices are processed in 4x4 tiles with the AVX kernels.  When lda
 * is not a multiple of 4, the smallDim = lda % 4 leftover columns are
 * handled by do_blockMN / do_blockMK / do_blockM, and the leftover rows by
 * the scalar do_block.
 *
 * A and B are first copied into private buffers.  If that allocation fails,
 * the inputs are read in place instead; this gives the same result as long
 * as C does not overlap A or B.  Element offsets are computed in int, so
 * lda * lda is assumed to fit in an int.  lda <= 0 is a no-op. */
void square_dgemm (int lda, double* _A, double* _B, double* _C) {
    if (lda <= 0) {
        return;
    }

    /* Byte count in size_t: lda * lda * sizeof(double) overflows int once
     * lda exceeds ~16k. */
    const size_t bytes = (size_t)lda * (size_t)lda * sizeof(double);
    double *copyA = malloc(bytes);
    double *copyB = malloc(bytes);
    double *A = _A;
    double *B = _B;
    if (copyA != NULL && copyB != NULL) {
        memcpy(copyA, _A, bytes);
        memcpy(copyB, _B, bytes);
        A = copyA;
        B = copyB;
    }
    double *const C = _C;

    const int sqLda = lda * lda;
    const int smallDim = lda & 3;              /* lda % 4 (lda > 0) */
    const int lastOffset = lda - smallDim;     /* first leftover row/column */
    const int lastTIMESlda = lastOffset * lda; /* offset of column lastOffset */
    const int JInc = 4 * lda;                  /* offset step of 4 columns */
    const int JBound = sqLda - JInc;           /* offset of last full column block */

    /* Rows 0..lastOffset-1, one 4-row strip at a time. */
    for (int i = 0; i <= lda - 4; i += 4) {
        double *const Ai = A + i;
        double *const Ci = C + i;
        for (int jlda = 0; jlda <= JBound; jlda += JInc) {
            for (int k = 0, klda = 0; k <= lda - 4; k += 4, klda += JInc) {
                do_blockMNK(lda, Ai + klda, B + k + jlda, Ci + jlda);
            }
            if (smallDim != 0) {
                /* Leftover inner indices k = lastOffset..lda-1. */
                do_blockMN(lda, smallDim, Ai + lastTIMESlda, B + lastOffset + jlda, Ci + jlda);
            }
        }
        if (smallDim != 0) {
            /* Leftover columns j = lastOffset..lda-1 of this strip. */
            for (int k = 0, klda = 0; k <= lda - 4; k += 4, klda += JInc) {
                do_blockMK(lda, smallDim, Ai + klda, B + lastTIMESlda + k, Ci + lastTIMESlda);
            }
            do_blockM(lda, smallDim, Ai + lastTIMESlda, B + lastOffset + lastTIMESlda,
                      Ci + lastTIMESlda);
        }
    }

    /* Leftover rows i = lastOffset..lda-1 (scalar code).  The edge pointers
     * are only formed here, where they are guaranteed to stay in bounds. */
    if (smallDim != 0) {
        const int lastPLUS = lastOffset + lastTIMESlda; /* offset of (lastOffset, lastOffset) */
        double *const ALast = A + lastOffset;
        double *const APlus = A + lastPLUS;
        double *const BLast = B + lastOffset;
        double *const BLastLDA = B + lastTIMESlda;
        double *const BPlus = B + lastPLUS;
        double *const CLast = C + lastOffset;
        double *const CPlus = C + lastPLUS;
        for (int jlda = 0; jlda <= JBound; jlda += JInc) {
            for (int k = 0, klda = 0; k <= lda - 4; k += 4, klda += JInc) {
                do_block(lda, smallDim, 4, 4, ALast + klda, B + k + jlda, CLast + jlda);
            }
            do_block(lda, smallDim, 4, smallDim, APlus, BLast + jlda, CLast + jlda);
        }
        for (int k = 0, klda = 0; k <= lda - 4; k += 4, klda += JInc) {
            do_block(lda, smallDim, smallDim, 4, ALast + klda, BLastLDA + k, CPlus);
        }
        do_block(lda, smallDim, smallDim, smallDim, APlus, BPlus, CPlus);
    }

    free(copyA);
    free(copyB);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment