Created
July 25, 2015 06:01
-
-
Save sterin/50dd65256de093de1bb3 to your computer and use it in GitHub Desktop.
Example for Stack Overflow question 31581722: passing a Python list and a NumPy matrix to a Python function from an embedding C++ application. http://stackoverflow.com/questions/31581722/passing-a-list-and-numpy-matrix-to-a-python-function-from-a-c-application
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CMake >= 3.18 is needed for FindPython3's Development.Embed component; CMake
# 4.x also refuses projects that declare compatibility with versions < 3.5.
cmake_minimum_required(VERSION 3.18)
project(so_31581722)

# Locate the Python 3 interpreter, the libpython used for embedding, and the
# NumPy C headers required by <numpy/ndarrayobject.h>. The deprecated
# FindPythonLibs module does not report the NumPy include directory.
find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Embed NumPy)

set(SOURCE_FILES main.cpp)
add_executable(so_31581722 ${SOURCE_FILES})
target_include_directories(so_31581722 PRIVATE ${Python3_NumPy_INCLUDE_DIRS})
# The imported Python3::Python target carries Python's include directories and
# the libpython link settings.
target_link_libraries(so_31581722 PRIVATE Python3::Python)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Author: Baruch Sterin <baruchs@gmail.com> | |
// Python headers | |
#include <Python.h> | |
#include <abstract.h> | |
// NumPy C/API headers | |
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION // remove warnings | |
#include <numpy/ndarrayobject.h> | |
#include <vector> | |
int main() | |
{ | |
// initialize python | |
Py_InitializeEx(1); | |
// import our test module | |
PyObject* numpy_test_module = PyImport_ImportModule("numpy_test"); | |
// retrieve 'print_matrix(); from our module | |
PyObject* print_matrix = PyObject_GetAttrString(numpy_test_module, "print_matrix"); | |
// retrieve 'some_function' from our module | |
PyObject* transform_matrix = PyObject_GetAttrString(numpy_test_module, "transform_matrix"); | |
// no longer need to reference the module directly | |
Py_XDECREF(numpy_test_module); | |
// initialize numpy array library | |
import_array1(-1); // returns -1 on failure | |
// create a new numpy array | |
// array dimensions | |
npy_intp dim[] = {5, 5}; | |
// array data | |
std::vector<double> buffer(25, 1.0); | |
// create a new array using 'buffer' | |
PyObject* array_2d = PyArray_SimpleNewFromData(2, dim, NPY_DOUBLE, &buffer[0]); | |
// print the array by calling 'print_matrix' | |
PyObject* return_value1 = PyObject_CallFunction(print_matrix, "O", array_2d); | |
// we don't need the return value, release the reference | |
Py_XDECREF(return_value1); | |
// create list | |
PyObject* list = PyList_New(3); | |
PyList_SetItem(list, 0, PyLong_FromLong(2)); | |
PyList_SetItem(list, 1, PyLong_FromLong(3)); | |
PyList_SetItem(list, 2, PyLong_FromLong(4)); | |
// call the function with the array as its parameter | |
PyObject* transformed_matrix = PyObject_CallFunction(transform_matrix, "OO", array_2d, list); | |
// no longer need the list, free the reference | |
Py_XDECREF(list); | |
// print the returned array by calling 'print_matrix' | |
PyObject* return_value2 = PyObject_CallFunction(print_matrix, "O", transformed_matrix); | |
// no longer need the 'return_value2', release the reference | |
Py_XDECREF(return_value2); | |
// no longer need 'transformed_matrix' | |
Py_XDECREF(transformed_matrix); | |
// no longer need the array | |
Py_XDECREF(array_2d); | |
// no longer need the reference to transform_matrix | |
Py_XDECREF(transform_matrix); | |
// no longer need the reference to 'print_matrix' | |
Py_XDECREF(print_matrix); | |
// clean up python | |
Py_Finalize(); | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy | |
def print_matrix(M):
    """Write ``M`` to standard output using the built-in ``print``.

    The C++ driver passes NumPy arrays here, but any printable object works.
    Returns ``None``.
    """
    print(M)
def transform_matrix(M, L):
    """Multiply ``M`` by each item of ``L`` in turn and return the result.

    The factors are applied one at a time, left to right, so the result is
    ``((M * L0) * L1) * ...``. When ``L`` is empty, ``M`` itself is returned
    (the same object, not a copy).
    """
    product = M
    for factor in L:
        product = product * factor
    return product
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment