Created
January 24, 2012 18:04
-
-
Save anonymous/1671571 to your computer and use it in GitHub Desktop.
Test program illustrating problem with cblas_sgemm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ gcc -g -c test_cblas_sgemm.c
test_cblas_sgemm.c:20:3: warning: initialization from incompatible pointer type [enabled by default]
test_cblas_sgemm.c:20:3: warning: (near initialization for ‘GemmFuncs[1]’) [enabled by default]
test_cblas_sgemm.c:25:1: warning: excess elements in array initializer [enabled by default]
test_cblas_sgemm.c:25:1: warning: (near initialization for ‘GemmFuncs’) [enabled by default]
$ gcc test_cblas_sgemm.o -L/usr/local/atlas/lib/ -lcblas -llapack -latlas -lpthread -lrt -ldl -lcrypt -lm -lc
$ ./a.out
ldc must be >= MAX(N,1): ldc=0 N=2Parameter 14 to routine cblas_sgemm was incorrect
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdlib.h> | |
#include <cblas.h> | |
#include <assert.h> | |
/* Element type of the matrices built in main(). */
typedef float decimal;

/*
 * Element-type tags. Each tag is also the index of the matching GEMM
 * routine in GemmFuncs. NM_TYPES is the number of tags, not a real type.
 */
enum DTypes {
  NM_NONE       = 0,
  NM_FLOAT32    = 1,
  NM_FLOAT64    = 2,
  NM_COMPLEX64  = 3,
  NM_COMPLEX128 = 4,
  NM_TYPES      = 5
};
typedef void (*nm_gemm_t[NM_TYPES])(); | |
nm_gemm_t GemmFuncs = { // by NM_TYPES | |
NULL, | |
cblas_sgemm, // NM_FLOAT32 <--- LINE 20: incompatible pointer type | |
cblas_dgemm, // NM_FLOAT64 | |
cblas_cgemm, // NM_COMPLEX64 | |
cblas_zgemm, // NM_COMPLEX128 | |
NULL // NM_TYPES | |
}; | |
int main() { | |
size_t* shape_A = malloc(2 * sizeof(size_t)); | |
size_t* shape_B = malloc(2 * sizeof(size_t)); | |
size_t* shape_C = malloc(2 * sizeof(size_t)); | |
int dtype = NM_FLOAT32; // works fine as NM_FLOAT64 | |
decimal *A = malloc(12 * sizeof(decimal)); | |
decimal *B = malloc(6 * sizeof(decimal)); | |
decimal *C = malloc(8 * sizeof(decimal)); | |
shape_A[0] = 4; shape_A[1] = 3; | |
shape_B[0] = 3; shape_B[1] = 2; | |
shape_C[0] = shape_A[0]; shape_C[1] = shape_B[1]; | |
A[0] = 14.0; A[1] = 9.0; A[2] = 3.0; // 4x3 | |
A[3] = 2.0; A[4] = 11.0; A[5] = 15.0; | |
A[6] = 0.0; A[7] = 12.0; A[8] = 17.0; | |
A[9] = 5.0; A[10]= 2.0; A[11]= 3.0; | |
B[0] = 12.0; B[1] = 25.0; // 3x2 | |
B[2] = 9.0; B[3] = 10.0; | |
B[4] = 8.0; B[5] = 5.0; | |
GemmFuncs[dtype]( | |
CblasRowMajor, CblasNoTrans, CblasNoTrans, | |
shape_C[0], shape_C[1], shape_A[1], | |
1.0, | |
A, shape_A[1], | |
B, shape_B[1], | |
0.0, | |
C, | |
shape_C[1] // parameter 14 | |
); | |
assert( C[0] == 273.0 ); assert( C[1] == 455.0 ); | |
assert( C[2] == 243.0 ); assert( C[3] == 235.0 ); | |
assert( C[4] == 244.0 ); assert( C[5] == 205.0 ); | |
assert( C[6] == 102.0 ); assert( C[7] == 160.0 ); | |
free(A); free(B); free(C); | |
free(shape_A); free(shape_B); free(shape_C); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment