Created
June 28, 2018 20:55
-
-
Save simonwhitaker/4421156c8baeacc55d1dce8e775d136d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <Accelerate/Accelerate.h>
/*
 * Debug helper: writes the first rows * cols elements of `matrix` to stdout
 * with two decimal places each. Elements are separated by a single space,
 * and a newline is emitted after every `cols`-th element.
 */
void print_matrix(const double *matrix, int rows, int cols) {
    const int total = rows * cols;
    for (int idx = 0; idx < total; ++idx) {
        /* The last element of each group of `cols` ends the line. */
        const int ends_line = (idx % cols == cols - 1);
        printf("%.2f%s", matrix[idx], ends_line ? "\n" : " ");
    }
}
/*
 * Naive triple-loop matrix product C = A * B over flat arrays.
 *
 * The indexing below treats all three matrices as row-major:
 *   A holds ar * ac elements (element (r, k) at A[r * ac + k]),
 *   B holds ac * bc elements (element (k, c) at B[k * bc + c]),
 *   C holds ar * bc elements (element (r, c) at C[r * bc + c]).
 *
 * Every element of C is overwritten; its previous contents are not read.
 * If ac <= 0 each written element of C is 0.0; if ar <= 0 or bc <= 0
 * nothing is written.
 *
 * Index arithmetic is carried out in size_t so that matrices with more
 * than INT_MAX elements do not trigger signed integer overflow.
 */
void dumb_matrix_multiply(const double *A, const double *B, double *C, const int ar, const int ac, const int bc) {
    for (int row = 0; row < ar; row++) {
        for (int col = 0; col < bc; col++) {
            double sum = 0.0;
            for (int k = 0; k < ac; k++) {
                /* row, col and k are non-negative here, and ac/bc are
                 * positive whenever this body runs, so the casts are safe. */
                const size_t a_index = (size_t)row * (size_t)ac + (size_t)k;
                const size_t b_index = (size_t)k * (size_t)bc + (size_t)col;
                sum += A[a_index] * B[b_index];
            }
            C[(size_t)row * (size_t)bc + (size_t)col] = sum;
        }
    }
}
int main(int argc, const char * argv[]) { | |
if (argc < 4) { | |
exit(1); | |
} | |
const int ar = atoi(argv[1]); // rows in A | |
const int ac = atoi(argv[2]); // columns in A | |
const int bc = atoi(argv[3]); // columns in B | |
const int br = ac; // rows in B | |
const int cr = ar; // rows in C | |
const int cc = bc; // columns in C | |
double *const A = (double *)calloc(ar * ac, sizeof(double)); | |
double *const B = (double *)calloc(br * bc, sizeof(double)); | |
double *const C = (double *)calloc(cr * cc, sizeof(double)); | |
// Initialise A and B with incrementing values, A[0] = 0.0, A[1] = 1.0, etc. Same for B. | |
for (int i = 0; i < ar * ac; i++) { A[i] = (double)i; } | |
for (int i = 0; i < br * bc; i++) { B[i] = (double)i; } | |
if (argc == 4) { | |
printf("algorithm: cblas_dgemm\n"); | |
const double alpha = 1.0, beta = 1.0; | |
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, cr, cc, ac, alpha, A, ac, B, bc, beta, C, cc); | |
printf("Done! (%.1f)", C[0]); | |
} else { | |
printf("algorithm: dumb_matrix_multiply\n"); | |
dumb_matrix_multiply(A, B, C, ar, ac, bc); | |
printf("Done! (%.1f)", C[0]); | |
} | |
free(A); | |
free(B); | |
free(C); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment