Skip to content

Instantly share code, notes, and snippets.

@kevinronan
Last active January 12, 2023 06:43
Show Gist options
  • Save kevinronan/24367280d31d5bd753b5d279e565e2c5 to your computer and use it in GitHub Desktop.
Save kevinronan/24367280d31d5bd753b5d279e565e2c5 to your computer and use it in GitHub Desktop.
Call a C++ function that takes a double** on NumPy array data, using Cython, pybind11, and CFFI

Cython

populate_.h

#ifndef POPULATE_IMPL_H_
#define POPULATE_IMPL_H_

/**
    Overwrite a len_x-by-len_y grid of doubles in place.

    @param arr    len_x row pointers, each addressing at least len_y doubles.
    @param len_x  Number of rows to write.
    @param len_y  Number of elements to write in each row.
*/
void populate(double ** arr, int len_x, int len_y);

#endif  // POPULATE_IMPL_H_

populate_.cpp

#include <iostream>
#include "populate_.h"

using namespace std;


/**
    Overwrite every cell of a row-pointer grid with a fixed pattern.

    @param arr    len_x row pointers; each must address at least len_y doubles.
    @param len_x  Number of rows to write (nothing is written if <= 0).
    @param len_y  Number of cells to write per row (nothing is written if <= 0).

    Cell (row, col) receives 10 + row + 2 * col.
*/
void populate(double ** arr, int len_x, int len_y){
    for(int row = 0; row < len_x; ++row)
    {
        for(int col = 0; col < len_y; ++col)
        {
            const int value = 10 + (row + 2 * col);
            arr[row][col] = value;
        }
    }
}

populate.pyx

""" Interface to C++ code """
from libc.stdlib cimport malloc, free

cimport numpy as np

cdef extern from "populate_.h":
    # Implemented in populate_.cpp, which is compiled as C++ with this module.
    void populate(double ** arr, int len_x, int len_y)

def populate_array(np.ndarray[double, ndim=2, mode="c"] arr not None):
    """ Populate the array with some data

    Overwrites every element of ``arr`` in place. The C++ ``populate``
    function receives a ``double **`` whose entries point at the rows of
    ``arr``; no data is copied.

    :param arr: 2D, C-contiguous array of doubles; modified in place.
    :raises MemoryError: if the temporary row-pointer array cannot be allocated.
    :raises OverflowError: if a dimension does not fit in a C ``int``.
    """
    # Get a memory view of the array
    cdef double[:, :] cython_view = arr
    cdef Py_ssize_t n_rows = arr.shape[0]
    cdef Py_ssize_t n_cols = arr.shape[1]
    cdef Py_ssize_t idx
    cdef double **ptr

    # populate() takes int sizes; refuse instead of silently truncating.
    if <int>n_rows != n_rows or <int>n_cols != n_cols:
        raise OverflowError("array dimensions do not fit in a C int")

    # Nothing to write. Returning early also avoids malloc(0), which may
    # legally return NULL (and would be misreported as MemoryError below).
    # It also avoids taking &cython_view[idx, 0] of a row that has no columns.
    if n_rows == 0 or n_cols == 0:
        return

    # Create a double ** view of the data
    ptr = <double **>malloc(n_rows * sizeof(double *))
    if not ptr:
        raise MemoryError()
    try:
        for idx in range(n_rows):
            ptr[idx] = &cython_view[idx, 0]
        populate(ptr, <int>n_rows, <int>n_cols)
    finally:
        free(ptr)

setup.py

"""
Use this to compile pyx to a shared object
python setup.py build_ext --inplace
"""
import numpy
from Cython.Distutils import build_ext
from distutils.core import setup
from distutils.extension import Extension


# The extension must share its name with the .pyx file. Cython translates
# populate.pyx into populate.cpp, so the hand-written C++ source is named
# populate_.cpp to avoid being overwritten by the generated file.
populate_extension = Extension(
    "populate",
    sources=["populate.pyx", "populate_.cpp"],
    language="c++",
    # numpy's headers are required by `cimport numpy` in the .pyx file.
    include_dirs=[numpy.get_include()],
    # C++11, and no optimization so the build stays easy to debug.
    extra_compile_args=["-std=c++11", "-O0"],
)
ext_modules = [populate_extension]

setup(
    name="populate",
    cmdclass={"build_ext": build_ext},
    ext_modules=ext_modules,
)

To compile:

# Compile populate.pyx + populate_.cpp into the `populate` extension described
# in setup.py; --inplace puts the built module in the current directory.
python setup.py build_ext --inplace

To run:

import numpy as np
import populate

# populate_array overwrites the array in place: element (i, j) becomes
# 10 + i + 2*j, so the rows printed are [10, 12], [11, 13], [12, 14].
arr = np.array([[1.0, 2.0],
                [3.0, 4.0],
                [5.0, 6.0]])
populate.populate_array(arr)
print(arr)

Pybind11

populate.cpp

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <iostream>

using namespace std;

namespace py = pybind11;


/**
    Fill a grid addressed through row pointers, in place.

    @param arr    len_x row pointers; each must address at least len_y doubles.
    @param len_x  Number of rows (no writes when <= 0).
    @param len_y  Number of elements per row (no writes when <= 0).

    The element at row i, column j is set to 10 + i + 2 * j.
*/
void add_data(double ** arr, int len_x, int len_y){
    int i = 0;
    while(i < len_x){
        int j = 0;
        while(j < len_y){
            *(arr[i] + j) = 10 + (i + 2 * j);
            ++j;
        }
        ++i;
    }
}


/**
    Get a double ** pointing to the rows of a 2D numpy array of doubles.
    Don't forget to delete[] the returned pointer!

    No data is copied. Writing through the returned pointers modifies the
    array, so they are only valid while the array's buffer is alive.

    @param input  2D array of doubles. The distance between consecutive rows
                  may be arbitrary (a multiple of the item size). The elements
                  inside a row must be adjacent, because callers index each
                  row as a plain C array (`ptr[i][j]`).
    @return new[]-allocated array of shape[0] row pointers; the caller owns
            it and must release it with `delete[]`.
    @throws std::runtime_error if the array is not 2D, its rows are not
            contiguous, or its row stride is not a multiple of the item size.
            request(true) also throws if the array is read-only.
*/
double ** double_array(py::array_t<double, 8> input) {
    // Request a writable buffer: callers write through the returned pointers.
    const auto buf = input.request(true);

    if (buf.ndim != 2)
        throw std::runtime_error("Number of dimensions must be 2");

    // Each row is handed out as a bare `double *`, so element j of a row is
    // assumed to sit j items after element 0. Views such as arr[:, ::2] or
    // arr.T break that assumption, and the wrong elements would be written
    // silently. The stride of a length <= 1 dimension is never used.
    if (buf.shape[1] > 1 && buf.strides[1] != buf.itemsize)
        throw std::runtime_error("Rows of the array must be contiguous");

    // Row starts are computed in whole items below; a byte stride that is not
    // a multiple of the item size would be truncated to the wrong address.
    if (buf.strides[0] % buf.itemsize != 0)
        throw std::runtime_error("Row stride must be a multiple of the item size");

    auto len_x = buf.shape[0];  // kept in pybind11's size type: no narrowing
    const auto row_step = buf.strides[0] / buf.itemsize;

    double * const first_elem = static_cast<double *>(buf.ptr);
    double ** ptr = new double *[len_x];
    for (decltype(len_x) i = 0; i < len_x; ++i) {
        ptr[i] = first_elem + i * row_step;
    }
    return ptr;
}

/**
    Python module `populate`, exposing `populate_array(input)`.

    populate_array overwrites the elements of a 2D float64 numpy array in
    place. py::array_t refers to the array's own buffer when no conversion is
    needed, so no opaque-type declaration is required to share memory.

    `noconvert()` makes pybind11 reject arguments it would otherwise convert
    into a temporary array, e.g. Python lists or integer arrays. The data
    would be written into that temporary and silently lost; the caller now
    gets a TypeError instead.
*/
PYBIND11_MODULE(populate, m) {
    m.def("populate_array",
         [](py::array_t<double, 8> input) {
            // double_array validates the array (2D, contiguous rows,
            // writable) before the shape is read below.
            double ** ptr = double_array(input);
            add_data(ptr, static_cast<int>(input.shape(0)),
                     static_cast<int>(input.shape(1)));
            delete[] ptr;
         }, "Replace data in a 2D Numpy array", py::arg("input").noconvert());
}

To compile:

# Directory that contains pybind11/pybind11.h and pybind11/numpy.h
# (both included by populate.cpp).
export PYBIND_PATH=/path/to/pybind/includes/
# Build populate.cpp into the extension module populate.so, importable as
# `populate`.
# NOTE(review): `python-config` must belong to the same interpreter that will
# import the module (it may be python3-config on some systems); confirm.
c++ -O3 -shared -std=c++11 -I $PYBIND_PATH `python-config --cflags --ldflags` populate.cpp -o populate.so -fPIC

To run:

import numpy as np
import populate

# The binding writes into arr's own buffer: element (i, j) becomes
# 10 + i + 2*j, so the printed rows are [10, 12], [11, 13], [12, 14].
arr = np.array([[1.0, 2.0],
                [3.0, 4.0],
                [5.0, 6.0]])
populate.populate_array(arr)
print(arr)

CFFI

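ABI level, in-line mode (wrapper.py loads the already-compiled shared library at run time with `FFI.dlopen`; cffi itself compiles nothing)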

populate.h

#ifndef POPULATE_H_
#define POPULATE_H_

/**
    Overwrite a len_x-by-len_y grid of doubles in place (see populate.cpp).
    Declared with C linkage so the exported symbol is not name-mangled and
    can be loaded from Python with cffi.

    Keep only plain C declarations inside the extern block below:
    wrapper.py extracts those lines and passes them to cffi's cdef().
*/
extern "C" {
void populate(double ** arr, int len_x, int len_y);
}

#endif  // POPULATE_H_

populate.cpp

#include <iostream>
#include "populate.h"

using namespace std;


/**
    Write a fixed pattern into every element of a grid given as row pointers.

    @param arr    Row pointers (len_x of them), each to at least len_y doubles.
    @param len_x  How many rows to fill; non-positive means none.
    @param len_y  How many elements to fill per row; non-positive means none.

    Row i, column j is assigned 10 + i + 2 * j.
*/
void populate(double ** arr, int len_x, int len_y){
    for(int i = 0; i < len_x; i++)
        for(int j = 0; j < len_y; j++)
            arr[i][j] = static_cast<double>(10 + (i + 2 * j));
}

wrapper.py

""" Python CFFI wrapper to populate.so """
import os

import cffi
# Single cffi FFI instance for this module. It receives the C declarations
# (cdef) and loads the shared library (dlopen). It also allocates and casts
# the row pointers handed to the C function.
FFI_ = cffi.FFI()

def get_header_declarations(header_path):
    """ Return extern C declarations from the header file.

    Collects the lines between an ``extern "C" {`` line and the next line that
    is only ``}``, so they can be passed to ``cffi.FFI.cdef``. Whitespace is
    ignored when matching those two marker lines, so ``extern "C"{`` is also
    recognised.

    :param header_path: path of the C/C++ header to read.
    :returns: the collected declaration lines, unchanged and in order, as one
        string (``''`` if the header has no such block).
    """
    with open(header_path) as header_file:
        lines = header_file.readlines()
    extern_decl = []
    extern = False
    for line in lines:
        # Drop all whitespace so spacing variants of the markers still match.
        compact = ''.join(line.split())
        if compact == 'extern"C"{':
            extern = True
        elif extern:
            if compact == '}':
                extern = False
            else:
                extern_decl.append(line)
    # readlines() keeps each line's newline, so join without adding more
    # (joining with '\n' doubled every line break).
    return ''.join(extern_decl)

# Locate the header and the compiled library next to this file rather than
# in the current working directory, so the wrapper can be imported from
# anywhere (not only from the directory that contains populate.so).
_HERE = os.path.dirname(os.path.abspath(__file__))
FFI_.cdef(get_header_declarations(os.path.join(_HERE, 'populate.h')))
LIB = FFI_.dlopen(os.path.join(_HERE, 'populate.so'))


def get_ptr_from_2d_numpy(arr):
    """
        Create CFFI pointers to the rows of a 2D numpy array.

        The C function writes ``arr[i][j]`` as doubles through these pointers.
        So every row must be a run of adjacent, writable, native float64
        values. Anything else would be written with the wrong element size or
        layout; an int32 array, for example, would be overrun past the end of
        its buffer.

        :param arr: writable 2D float64 numpy array whose rows are contiguous
            (the distance between rows may be arbitrary).
        :returns: A cffi ``double *[]`` with one pointer per row. The pointers
            borrow ``arr``'s buffer, so keep ``arr`` alive while using them.
        :raises ValueError: if ``arr`` is not 2D, its rows are not contiguous,
            or it is read-only.
        :raises TypeError: if ``arr`` is not native-byte-order float64.
    """
    if arr.ndim != 2:
        raise ValueError("Number of dimensions must be 2")
    if arr.dtype != 'float64':
        raise TypeError("Array dtype must be float64, got %s" % arr.dtype)
    # The stride of a length <= 1 dimension is never used, so skip the check.
    if arr.shape[1] > 1 and arr.strides[1] != arr.itemsize:
        raise ValueError("Rows of the array must be contiguous")
    if not arr.flags.writeable:
        raise ValueError("Array must be writeable")
    values_c = FFI_.new("double* []", arr.shape[0])
    for idx in range(arr.shape[0]):
        values_c[idx] = FFI_.cast("double *", arr[idx].ctypes.data)
    return values_c


def populate_array(arr):
    """ Overwrite, in place, the data of a 2D Numpy array via the C library.

    :param arr: the 2D array whose rows are handed to ``populate``.
    """
    # Build the row pointers first; they must stay referenced for the call.
    row_ptrs = get_ptr_from_2d_numpy(arr)
    n_rows = arr.shape[0]
    n_cols = arr.shape[1]
    LIB.populate(row_ptrs, n_rows, n_cols)

To compile:

# Build a plain shared library (no Python headers needed) for wrapper.py to
# load with cffi; populate() has C linkage through populate.h.
c++ -O3 -shared -std=c++11 populate.cpp -o populate.so -fPIC

To run:

import numpy as np
import wrapper as populate

# The cffi wrapper passes row pointers into arr, so it is modified in place:
# element (i, j) becomes 10 + i + 2*j, i.e. rows [10, 12], [11, 13], [12, 14].
arr = np.array([[1.0, 2.0],
                [3.0, 4.0],
                [5.0, 6.0]])
populate.populate_array(arr)
print(arr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment