Skip to content

Instantly share code, notes, and snippets.

@ChadFulton
Created June 4, 2014 16:31
Show Gist options
Save ChadFulton/36aac8be257f6855e54c to your computer and use it in GitHub Desktop.
# Typical imports
from cpython cimport PyCObject_AsVoidPtr
import numpy as np
cimport numpy as np
cimport cython
# BLAS / LAPACK functions
from blas_lapack cimport *
from scipy.linalg import blas, lapack
# Fused element type for the BLAS wrappers below. Each specialization is
# routed by ``gemv`` to one ?gemv routine:
#   np.float32_t -> sgemv, double -> dgemv,
#   np.complex64_t -> cgemv, complex -> zgemv.
# NOTE(review): declared as ``cpdef fused`` rather than the more common
# ``ctypedef fused``; presumably so the specializations are visible from
# Python as well -- confirm the Cython version in use accepts this spelling.
cpdef fused blas_numeric:
    np.float32_t
    double
    np.complex64_t
    complex
# Raw C function pointers to SciPy's Fortran BLAS ?gemv routines, obtained
# from the ``_cpointer`` attribute of the ``scipy.linalg.blas`` wrappers.
# The ``?gemv_t`` typedefs come from the local ``blas_lapack`` pxd.
# NOTE(review): PyCObject_AsVoidPtr exists only on Python 2; on Python 3 the
# ``_cpointer`` object is a PyCapsule and this will not build/work. With
# SciPy >= 0.16 consider ``cimport scipy.linalg.cython_blas`` instead --
# confirm the targeted Python/SciPy versions.
cdef sgemv_t *sgemv = <sgemv_t*>PyCObject_AsVoidPtr(blas.sgemv._cpointer)
cdef dgemv_t *dgemv = <dgemv_t*>PyCObject_AsVoidPtr(blas.dgemv._cpointer)
cdef cgemv_t *cgemv = <cgemv_t*>PyCObject_AsVoidPtr(blas.cgemv._cpointer)
cdef zgemv_t *zgemv = <zgemv_t*>PyCObject_AsVoidPtr(blas.zgemv._cpointer)
cdef void gemv(char *TRANS, int *M, int *N,
               blas_numeric *ALPHA, blas_numeric *A, int *LDA,
               blas_numeric *X, int *INCX,
               blas_numeric *BETA, blas_numeric *Y, int *INCY):
    """Type-generic front end for BLAS ``?gemv`` (y := alpha*op(A)*x + beta*y).

    All arguments are Fortran-style pointers and are forwarded untouched.
    The ``blas_numeric is ...`` tests are resolved by Cython per fused
    specialization, so each compiled version contains exactly one call:

        double         -> dgemv
        np.float32_t   -> sgemv
        complex        -> zgemv
        np.complex64_t -> cgemv
    """
    if blas_numeric is double:
        dgemv(TRANS, M, N, ALPHA, A, LDA, X, INCX, BETA, Y, INCY)
    elif blas_numeric is np.float32_t:
        sgemv(TRANS, M, N, ALPHA, A, LDA, X, INCX, BETA, Y, INCY)
    elif blas_numeric is complex:
        zgemv(TRANS, M, N, ALPHA, A, LDA, X, INCX, BETA, Y, INCY)
    elif blas_numeric is np.complex64_t:
        cgemv(TRANS, M, N, ALPHA, A, LDA, X, INCX, BETA, Y, INCY)
cpdef test(blas_numeric [::1,:] A, blas_numeric [:] x, int iterations):
    """Compute ``y = A @ x`` through the fused ``gemv`` wrapper, ``iterations`` times.

    Each call uses ``alpha = 1`` and ``beta = 0``, so ``y`` is overwritten on
    every iteration and the result equals a single matrix-vector product.

    Parameters
    ----------
    A : Fortran-contiguous 2-D memoryview of shape (n, k).
    x : 1-D memoryview of length k with the same element type as ``A``;
        may be a strided view (it is made contiguous before the BLAS call).
    iterations : number of times the BLAS call is repeated.

    Returns
    -------
    numpy.ndarray of shape (n,) with the element type of ``x``. All zeros when
    ``iterations <= 0`` or when ``A`` has a zero-length dimension.

    Raises
    ------
    ValueError
        If ``x.shape[0] != A.shape[1]``.
    """
    cdef int n, k
    cdef int inc = 1, i
    cdef char trans = "N"
    cdef blas_numeric alpha = 1.0, beta = 0.0
    cdef blas_numeric [::1] xc
    cdef blas_numeric [:] y

    n = A.shape[0]
    k = A.shape[1]
    if x.shape[0] != k:
        raise ValueError("x must have length A.shape[1]")

    # BLAS is called with INCX=1, so x must really be unit-stride;
    # ascontiguousarray copies only when x is a strided view.
    x_arr = np.ascontiguousarray(x)
    # Allocate y in this specialization's element type: a hard-coded
    # float64 buffer cannot bind to a float32/complex memoryview.
    y = np.zeros((n,), dtype=x_arr.dtype)
    if n == 0 or k == 0:
        # Nothing to multiply; also avoids indexing an empty view and
        # handing BLAS an invalid LDA of 0.
        return np.array(y)
    xc = x_arr

    for i in range(iterations):
        gemv(&trans, &n, &k, &alpha, &A[0, 0], &n, &xc[0], &inc,
             &beta, &y[0], &inc)
    return np.array(y)
cpdef test2(double [::1,:] A, double [:] x, int iterations):
    """Double-only counterpart of ``test`` that calls ``dgemv`` directly.

    Apparently intended as a baseline for timing against the fused-type
    dispatcher. Each call uses ``alpha = 1`` and ``beta = 0``, so the result
    equals a single product ``A @ x``.

    Parameters
    ----------
    A : Fortran-contiguous float64 memoryview of shape (n, k).
    x : float64 memoryview of length k; may be strided (it is made
        contiguous before the BLAS call).
    iterations : number of times ``dgemv`` is repeated.

    Returns
    -------
    numpy.ndarray of float64 with shape (n,). All zeros when
    ``iterations <= 0`` or when ``A`` has a zero-length dimension.

    Raises
    ------
    ValueError
        If ``x.shape[0] != A.shape[1]``.
    """
    cdef int n, k
    cdef int inc = 1, i
    cdef char trans = "N"
    cdef double alpha = 1.0, beta = 0.0
    cdef double [::1] xc
    cdef double [:] y

    n = A.shape[0]
    k = A.shape[1]
    if x.shape[0] != k:
        raise ValueError("x must have length A.shape[1]")

    y = np.zeros((n,), dtype=float)
    if n == 0 or k == 0:
        # Nothing to multiply; also avoids indexing an empty view and
        # handing BLAS an invalid LDA of 0.
        return np.array(y)
    # dgemv reads x with INCX=1; copy only when x is a strided view.
    xc = np.ascontiguousarray(x)

    for i in range(iterations):
        dgemv(&trans, &n, &k, &alpha, &A[0, 0], &n, &xc[0], &inc,
              &beta, &y[0], &inc)
    return np.array(y)
def main():
    """Print ``A @ x`` for a small Fortran-ordered 2x2 example.

    The typed ``test`` / ``test2`` routines are meant to be benchmarked from
    IPython with the same inputs, e.g.::

        %timeit test(A, x, 100000)
        %timeit test2(A, x, 100000)
    """
    A = np.array([[1., 1],
                  [2, 2]], order="F")
    x = np.array([1., 2])
    # Function-call form works under both language_level=2 and 3
    # (the bare ``print expr`` statement is Python-2 only).
    print(np.dot(A, x))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment