
@pv
Last active October 23, 2021 11:34
Gist pv/5437087
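example.pyx:

import numpy as np
import scipy.linalg.blas

cdef extern from "f2pyptr.h":
    void *f2py_pointer(object) except NULL

# Fortran BLAS dgemm: C := alpha*op(A)*op(B) + beta*C.
# Every argument is passed by pointer (Fortran calling convention),
# and dgemm is a SUBROUTINE, so the C return type is void.
ctypedef void dgemm_t(
    char *transa, char *transb,
    int *m, int *n, int *k,
    double *alpha,
    double *a, int *lda,
    double *b, int *ldb,
    double *beta,
    double *c, int *ldc)

# Since SciPy >= 0.12.0: ._cpointer wraps the raw Fortran function pointer
cdef dgemm_t *dgemm = <dgemm_t*>f2py_pointer(scipy.linalg.blas.dgemm._cpointer)

def myfunc():
    # Fortran-ordered (column-major) 2-D memoryviews, as BLAS expects
    cdef double[::1,:] a, b, c
    cdef int m, n, k, lda, ldb, ldc
    cdef double alpha, beta

    a = np.array([[1, 2], [3, 4]], float, order="F")
    b = np.array([[5, 6], [7, 8]], float, order="F")
    c = np.empty((2, 2), float, order="F")

    alpha = 1.0
    beta = 0.0    # c is uninitialized, so it must not contribute
    lda = 2       # leading dimension = number of rows (column-major)
    ldb = 2
    ldc = 2
    m = 2
    n = 2
    k = 2

    dgemm("N", "N", &m, &n, &k, &alpha, &a[0,0], &lda, &b[0,0], &ldb, &beta, &c[0,0], &ldc)

    print(np.asarray(c))
    print(np.dot(a, b))  # reference result for comparison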
f2pyptr.h:

#ifndef F2PYPTR_H_
#define F2PYPTR_H_

#include <Python.h>

/* Extract the raw C pointer wrapped by an f2py ._cpointer object:
   a PyCObject (Python 2) or a PyCapsule (Python 2.7+ and 3). */
void *f2py_pointer(PyObject *obj)
{
#if PY_VERSION_HEX < 0x03000000
    if (PyCObject_Check(obj)) {
        return PyCObject_AsVoidPtr(obj);
    }
#endif
#if PY_VERSION_HEX >= 0x02070000
    if (PyCapsule_CheckExact(obj)) {
        return PyCapsule_GetPointer(obj, NULL);
    }
#endif
    PyErr_SetString(PyExc_ValueError, "Not an object containing a void ptr");
    return NULL;
}

#endif
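
To make the pointer arguments in `myfunc()` concrete: BLAS `dgemm` computes C := alpha*op(A)*op(B) + beta*C on column-major storage, where element (i, j) of a matrix with leading dimension `ld` sits at offset `i + j*ld` from `&x[0,0]`. Below is a minimal pure-Python sketch of those semantics, useful for checking results by hand. The name `ref_dgemm` is hypothetical; this is a slow reference under those assumptions, not SciPy's or any BLAS library's implementation.

```python
def ref_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc):
    """Hypothetical pure-Python reference for BLAS dgemm semantics.

    Computes C := alpha*op(A)*op(B) + beta*C in place, where op(X) is X
    for 'N' and X^T for 'T'. All matrices are flat column-major lists:
    element (i, j) of a matrix with leading dimension ld is x[i + j*ld].
    op(A) is m x k, op(B) is k x n, C is m x n.
    """
    def get(x, ld, trans, i, j):
        # Read element (i, j) of op(X) from column-major storage.
        return x[j + i * ld] if trans in ("T", "t") else x[i + j * ld]

    for j in range(n):
        for i in range(m):
            acc = 0.0
            for p in range(k):
                acc += get(a, lda, transa, i, p) * get(b, ldb, transb, p, j)
            # Like reference BLAS, do not read C at all when beta == 0.
            old = 0.0 if beta == 0.0 else beta * c[i + j * ldc]
            c[i + j * ldc] = alpha * acc + old
    return c


# Same inputs as myfunc(), flattened in Fortran (column-major) order:
# [[1, 2], [3, 4]] -> [1, 3, 2, 4]
a = [1.0, 3.0, 2.0, 4.0]
b = [5.0, 7.0, 6.0, 8.0]
c = [0.0] * 4
ref_dgemm("N", "N", 2, 2, 2, 1.0, a, 2, b, 2, 0.0, c, 2)
print(c)  # [19.0, 43.0, 22.0, 50.0], i.e. [[19, 22], [43, 50]] column-major
```

The printed buffer is the column-major form of `np.dot(a, b)`, which is what the two `print` calls in `myfunc()` compare.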
@prabhuramachandran

Thanks! This is super useful. FWIW, one could roll the f2pyptr.h into the example.pyx with this (it makes it easier to play with one file rather than two):

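from cpython cimport (PY_VERSION_HEX, PyCObject_Check,
    PyCObject_AsVoidPtr, PyCapsule_CheckExact, PyCapsule_GetPointer)

# "except NULL" lets the ValueError propagate to the caller
# (without it, older Cython versions print and swallow the exception).
cdef void* f2py_pointer(obj) except NULL:
    if PY_VERSION_HEX < 0x03000000:
        if PyCObject_Check(obj):
            return PyCObject_AsVoidPtr(obj)
    # A plain "if", not "elif": on Python 2.7 the object may be a capsule.
    if PY_VERSION_HEX >= 0x02070000:
        if PyCapsule_CheckExact(obj):
            return PyCapsule_GetPointer(obj, NULL)
    raise ValueError("Not an object containing a void ptr")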
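
Both versions of `f2py_pointer` do the same thing on Python 3: check that the object is a capsule (a Python object wrapping a raw C `void *`) and unwrap it with `PyCapsule_GetPointer(obj, NULL)`. The sketch below shows that round trip from plain Python via `ctypes.pythonapi`, assuming CPython 3. The helper names `make_capsule` and `capsule_pointer` are hypothetical. Because `PyCapsule_CheckExact` is a C macro that ctypes cannot call, the sketch uses the slightly stricter `PyCapsule_IsValid` instead.

```python
import ctypes

api = ctypes.pythonapi
# Declare the C signatures of the CPython capsule functions used below.
api.PyCapsule_New.restype = ctypes.py_object
api.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
api.PyCapsule_IsValid.restype = ctypes.c_int
api.PyCapsule_IsValid.argtypes = [ctypes.py_object, ctypes.c_char_p]
api.PyCapsule_GetPointer.restype = ctypes.c_void_p
api.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]


def make_capsule(address):
    """Hypothetical helper: wrap a non-NULL C address in an unnamed capsule."""
    return api.PyCapsule_New(address, None, None)


def capsule_pointer(obj):
    """Python analogue of f2py_pointer() for Python 3 capsules."""
    # A NULL name (None) matches capsules created without a name,
    # mirroring PyCapsule_GetPointer(obj, NULL) in f2pyptr.h.
    if api.PyCapsule_IsValid(obj, None):
        return api.PyCapsule_GetPointer(obj, None)
    raise ValueError("Not an object containing a void ptr")


buf = ctypes.create_string_buffer(16)   # some C memory to point at
addr = ctypes.addressof(buf)
cap = make_capsule(addr)
print(capsule_pointer(cap) == addr)     # True
```

With SciPy installed, passing `scipy.linalg.blas.dgemm._cpointer` to `capsule_pointer` should return the same address the gist casts to `dgemm_t*`, assuming f2py creates that capsule without a name (the header's `NULL` name argument relies on this too).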

@VoodooChild83

VoodooChild83 commented Nov 25, 2016

I'm new to this, but what is the difference between the above and the following:

from scipy.linalg.cython_blas cimport dgemm

which avoids having to go through Fortran explicitly?

(I get about the same speed with your example as with dgemm from cython_blas.)

@mrslezak

mrslezak commented Jul 18, 2017

I think this is the more appropriate way to call the built-in cython_blas routines: https://stackoverflow.com/questions/44980665/using-the-scipy-cython-blas-interface-from-cython-not-working-on-vectors-mx1-1xn (never mind all the error handling there about whether the input array is F- or C-contiguous; that could just be removed). The f2pyptr.h approach above ignores SciPy's direct Cython link to the Fortran routines, which avoids the pointer-extraction step. That approach is also what the SciPy examples use to show how to call the interface. Just saying.
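
On the F-or-C contiguity question: a C-contiguous (row-major) m x n array has exactly the memory layout of its n x m transpose stored column-major. For row-major inputs, a column-major gemm can therefore compute C^T = B^T A^T without copying, by passing B before A and swapping m and n. Below is a minimal pure-Python sketch of that identity. `colmajor_gemm` and `rowmajor_matmul` are hypothetical names (alpha = 1, beta = 0, no transposes), not SciPy functions.

```python
def colmajor_gemm(m, n, k, a, lda, b, ldb, c, ldc):
    """Hypothetical column-major C := A*B (alpha=1, beta=0, no transposes).

    A is m x k, B is k x n, C is m x n; element (i, j) of a column-major
    matrix with leading dimension ld is stored at index i + j*ld.
    """
    for j in range(n):
        for i in range(m):
            c[i + j * ldc] = sum(a[i + p * lda] * b[p + j * ldb] for p in range(k))
    return c


def rowmajor_matmul(m, n, k, a, b):
    """Row-major C = A*B (A is m x k, B is k x n) via the column-major kernel.

    Row-major A is the column-major layout of A^T (k x m, leading dim k),
    and row-major B is B^T (n x k, leading dim n). Asking the kernel for
    B^T * A^T writes C^T column-major, which is exactly C row-major.
    """
    c = [0.0] * (m * n)
    # Swap the operands and m <-> n; leading dims are the row lengths.
    return colmajor_gemm(n, m, k, b, n, a, k, c, n)


# A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]] stored row-major (C order)
print(rowmajor_matmul(2, 2, 2, [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]))
# [19.0, 22.0, 43.0, 50.0], i.e. [[19, 22], [43, 50]] in C order
```

In BLAS terms this corresponds to calling `dgemm("N", "N", &n, &m, &k, &alpha, B, &n, A, &k, &beta, C, &n)` on C-ordered arrays. Treat that mapping as a sketch to check against your own shapes and leading dimensions.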
