Skip to content

Instantly share code, notes, and snippets.

@JonathanRaiman
Last active August 29, 2015 14:08
Show Gist options
  • Save JonathanRaiman/07046b897709fffb49e5 to your computer and use it in GitHub Desktop.
Save JonathanRaiman/07046b897709fffb49e5 to your computer and use it in GitHub Desktop.
Cython & BLAS gemm
# How to get BLAS gemm to work in Cython
# 1. Suppose your data is Fortran-contiguous (column-major, which is
#    not NumPy's default layout); then BLAS supports it out of the box:
%%cython
cimport numpy as np
import numpy as np
from cpython cimport PyCapsule_GetPointer # PyCObject_AsVoidPtr
from scipy.linalg.blas import fblas
from cython_lstm import vector_outer_product
# C-level scalar constants. They are not referenced by the code visible in
# this file; kept because other cells/modules may rely on them.
cdef int ONE = 1
cdef float ONE_f = 1.0
cdef float ZERO_f = 0.0
# BLAS transpose flags. dot() takes the address of these two names, so they
# must exist at module level or the module does not compile.
cdef char trans = b't'
cdef char n_trans = b'n'
# Python-level dtype and the matching C element type used for every array here.
REAL = np.float32
ctypedef np.float32_t REAL_t
# Check for Fortran (column-major) contiguity here.
# The declaration has to sit inside the extern block; it is callable
# without the GIL.
cdef extern from "numpy/arrayobject.h":
    cdef bint PyArray_IS_F_CONTIGUOUS(np.ndarray) nogil
# Create the pointer to the BLAS function.
# The type mirrors the Fortran calling convention used by sgemm: every
# argument, scalars included, is passed by pointer. It is declared ``nogil``
# because dot() invokes the pointer inside a ``with nogil:`` section.
ctypedef void (*sgemm_ptr) (char *transA, char *transB,
                            int *m, int *n, int *k,
                            float *alpha,
                            float *a, int *lda,
                            float *b, int *ldb,
                            float *beta,
                            float *c, int *ldc) nogil
# Unwrap the raw C function pointer from the capsule exposed as
# ``_cpointer`` on the Python-level sgemm wrapper.
cdef sgemm_ptr sgemm=<sgemm_ptr>PyCapsule_GetPointer(fblas.sgemm._cpointer, NULL)
# with those pointers we can now wrap a cython function for these
def dot(np.ndarray[REAL_t, ndim=2] _a, np.ndarray[REAL_t, ndim=2] _b,
        REAL_t alpha=1., REAL_t beta=0.):
    """Multiply two 2-D float32 matrices with BLAS ``sgemm``.

    Parameters
    ----------
    _a : 2-D array of REAL_t, logical shape (m, k).
    _b : 2-D array of REAL_t, logical shape (k, n).
    alpha : scale factor applied to the product by ``sgemm``.
    beta : forwarded to ``sgemm`` as the scale of the output buffer. The
        buffer is freshly zero-filled here, so for finite values it does
        not change the result.

    Returns
    -------
    A new (m, n) array of dtype REAL in Fortran order.

    Raises
    ------
    ValueError
        If the inner dimensions of ``_a`` and ``_b`` differ.
    """
    cdef int m, n, k, lda, ldb, ldc
    # Local transpose flags so this function does not depend on module-level
    # names; BLAS accepts either letter case.
    cdef char no_trans_flag = b'N'
    cdef char trans_flag = b'T'
    cdef char * transA = &trans_flag
    cdef char * transB = &trans_flag
    cdef REAL_t * a
    cdef REAL_t * b
    cdef REAL_t * c
    cdef np.ndarray[REAL_t, ndim=2] _c

    # sgemm trusts the sizes it is given, so a mismatch would read past the
    # end of a buffer instead of failing cleanly.
    if _a.shape[1] != _b.shape[0]:
        raise ValueError(
            "shapes (%d,%d) and (%d,%d) not aligned"
            % (_a.shape[0], _a.shape[1], _b.shape[0], _b.shape[1]))

    # Logical problem size, taken from the inputs *before* any transposition
    # so it is independent of each operand's memory layout.
    m = <int>_a.shape[0]
    k = <int>_a.shape[1]
    n = <int>_b.shape[1]

    _c = np.zeros((m, n), dtype=REAL, order="F")
    if m == 0 or n == 0 or k == 0:
        # Nothing to accumulate; also avoids handing BLAS a leading
        # dimension of 0, which it rejects.
        return _c

    if PyArray_IS_F_CONTIGUOUS(_a):
        transA = &no_trans_flag
    else:
        # The transpose of a C-contiguous array is a Fortran-contiguous view
        # of the same memory, which sgemm then un-transposes via the 'T'
        # flag. ascontiguousarray is a no-op for C-contiguous input and
        # copies only strided/sliced input.
        _a = np.ascontiguousarray(_a).T
    if PyArray_IS_F_CONTIGUOUS(_b):
        transB = &no_trans_flag
    else:
        _b = np.ascontiguousarray(_b).T

    a = <REAL_t *>np.PyArray_DATA(_a)
    b = <REAL_t *>np.PyArray_DATA(_b)
    c = <REAL_t *>np.PyArray_DATA(_c)
    # Leading dimension = number of rows of each array as stored in memory
    # (i.e. after the optional transposition above).
    lda = <int>_a.shape[0]
    ldb = <int>_b.shape[0]
    ldc = <int>_c.shape[0]
    with nogil:
        # Only C-level values are touched from here on, so the GIL can be
        # released for the duration of the BLAS call.
        sgemm(transA, transB, &m, &n, &k, &alpha, a, &lda, b, &ldb,
              &beta, c, &ldc)
    return _c
# we can test this:
def fortran_arrays(x, y):
    """Return a random float32 pair of shapes (x, y) and (y, x) in Fortran order.

    The two shapes are compatible for a matrix product of the first by the
    second.
    """
    return (np.asfortranarray(np.random.randn(x, y).astype(np.float32)),
            np.asfortranarray(np.random.randn(y, x).astype(np.float32)))
def ordinary_arrays(x, y):
    """Return a random float32 pair of shapes (x, y) and (y, x).

    The arrays keep the layout produced by ``np.random.randn(...).astype``
    (NumPy's default C order), and are compatible for a matrix product of
    the first by the second.
    """
    return (np.random.randn(x, y).astype(np.float32),
            np.random.randn(y, x).astype(np.float32))
def test_dot():
    """Compare dot() against np.dot for every shape pair with sides 1..9.

    Each (i, j) is tried once with Fortran-ordered inputs and once with
    ordinary inputs.

    Returns
    -------
    tuple
        ``(fortran_matches, fortran_mismatches, ordinary_matches,
        ordinary_mismatches)`` as counts.
    """
    fworks = []
    fdoesnt = 0
    works = []
    doesnt = 0
    for i in range(1, 10):
        for j in range(1, 10):
            a, b = fortran_arrays(i, j)
            c = dot(a, b)
            d = np.dot(a, b)
            try:
                if np.allclose(c, d):
                    fworks.append((i, j))
                else:
                    fdoesnt += 1
            except (TypeError, ValueError):
                # A result that cannot even be compared counts as a failure
                # of the Fortran-layout case, not the ordinary one.
                fdoesnt += 1
            a, b = ordinary_arrays(i, j)
            c = dot(a, b)
            d = np.dot(a, b)
            try:
                if np.allclose(c, d):
                    works.append((i, j))
                else:
                    doesnt += 1
            except (TypeError, ValueError):
                doesnt += 1
    return (len(fworks), fdoesnt, len(works), doesnt)
# Run the comparison; the tuple is (fortran matches, fortran mismatches,
# ordinary matches, ordinary mismatches) over the 9 x 9 = 81 shape pairs.
test_dot()
# => (81, 0, 81, 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment