Skip to content

Instantly share code, notes, and snippets.

@ewmoore
Created March 19, 2015 13:27
Show Gist options
  • Save ewmoore/e88ee9dd84c1d9d58892 to your computer and use it in GitHub Desktop.
Save ewmoore/e88ee9dd84c1d9d58892 to your computer and use it in GitHub Desktop.
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#include <Python.h>
#include <numpy/arrayobject.h>
#include "numpy/npy_3kcompat.h"
/*
 * Reduce `arr` over the axes listed in `axis` (length `naxis`), computing both
 * the sum and the product of the reduced elements as doubles.
 *
 * Result: a new float64 array created with PyArray_Zeros (C order) whose
 * leading dimensions are the non-reduced axes of `arr`, in their original
 * order. Its last dimension has length 2: index 0 holds the sum and index 1
 * the product.
 *
 * Negative entries of `axis` count from the end and are normalized IN PLACE.
 * Input elements are cast to double (NPY_UNSAFE_CASTING) through a buffered
 * inner iterator.
 *
 * Returns a new reference, or NULL with a Python exception set:
 *   ValueError if naxis <= 0, naxis > ndim, an axis is out of range, or an
 *   axis is repeated. Iterator construction/iteration errors are propagated.
 */
PyObject* sum_and_prod(PyArrayObject *arr, int *axis, int naxis)
{
    const int ndim = PyArray_NDIM(arr);
    const npy_uint32 inner_flags =
        NPY_ITER_EXTERNAL_LOOP | NPY_ITER_BUFFERED | NPY_ITER_GROWINNER;
    npy_uint32 op_flags_outer[2] = {NPY_ITER_READONLY, NPY_ITER_WRITEONLY};
    npy_uint32 op_flags_inner[1] = {NPY_ITER_READONLY | NPY_ITER_NBO};
    NpyIter *iter_outer = NULL, *iter_inner = NULL;
    NpyIter_IterNextFunc *iternext_outer, *iternext_inner;
    char **dataptr_outer, **dataptr_inner;
    npy_intp *strideptr_inner, *innersizeptr_inner;
    PyArrayObject *ops[2];
    PyArrayObject *out = NULL;
    PyArray_Descr *double_dtype;
    PyArray_Descr *dtype_inner[1];
    int *oa_axes_outer[2];
    int *oa_axes_inner[1];
    int arr_axes_outer[NPY_MAXDIMS];
    int out_axes_outer[NPY_MAXDIMS];
    int arr_axes_inner[NPY_MAXDIMS];
    npy_intp out_shape[NPY_MAXDIMS];
    char is_reduced[NPY_MAXDIMS] = {0};
    npy_intp prod_offset;
    int i;
    int n_inner = 0;
    int n_outer = 0;

    /* Validate the axis list before acquiring any resources. */
    if (naxis <= 0) {
        PyErr_SetString(PyExc_ValueError, "Need at least 1 axis");
        return NULL;
    }
    if (naxis > ndim) {
        PyErr_SetString(PyExc_ValueError, "Too many axes");
        return NULL;
    }
    for (i = 0; i < naxis; ++i) {
        if (axis[i] < 0) {
            axis[i] += ndim;
        }
        if (axis[i] < 0 || axis[i] >= ndim) {
            PyErr_SetString(PyExc_ValueError, "axis specified does not exist");
            return NULL;
        }
        /* A repeated axis would leave arr_axes_inner partly unfilled. */
        if (is_reduced[axis[i]]) {
            PyErr_SetString(PyExc_ValueError, "duplicate value in axis");
            return NULL;
        }
        is_reduced[axis[i]] = 1;
    }

    /*
     * Split arr's axes into reduced (inner) and kept (outer) axes, both in
     * ascending order. Kept axis j of arr becomes axis j of the output.
     */
    for (i = 0; i < ndim; ++i) {
        if (is_reduced[i]) {
            arr_axes_inner[n_inner++] = i;
        } else {
            arr_axes_outer[n_outer] = i;
            out_axes_outer[n_outer] = n_outer;
            out_shape[n_outer] = PyArray_DIM(arr, i);
            ++n_outer;
        }
    }
    /* Trailing axis: slot 0 = sum, slot 1 = product. */
    out_shape[n_outer] = 2;

    double_dtype = PyArray_DescrFromType(NPY_DOUBLE);
    if (double_dtype == NULL) {
        return NULL;
    }
    /*
     * PyArray_Zeros steals a reference to its dtype (even on failure). We still
     * need our own reference for the inner iterator, so hand it an extra one.
     */
    Py_INCREF(double_dtype);
    out = (PyArrayObject *)PyArray_Zeros(n_outer + 1, out_shape, double_dtype, 0);
    if (out == NULL) {
        goto fail;
    }
    ops[0] = arr;
    ops[1] = out;

    /*
     * Outer iterator: walks the kept axes of arr and the matching leading axes
     * of out. The trailing length-2 axis of out is not iterated, so
     * dataptr_outer[1] points at the "sum" slot of the current output cell.
     * NOTE(review): when every axis is reduced, n_outer is 0; NumPy >= 1.13
     * treats oa_ndim == 0 as a single 0-d step. Confirm the minimum NumPy version.
     */
    oa_axes_outer[0] = arr_axes_outer;
    oa_axes_outer[1] = out_axes_outer;
    iter_outer = NpyIter_AdvancedNew(2, ops, 0, NPY_KEEPORDER,
                                     NPY_NO_CASTING, op_flags_outer, NULL,
                                     n_outer, oa_axes_outer, NULL, -1);
    if (iter_outer == NULL) {
        goto fail;
    }
    iternext_outer = NpyIter_GetIterNext(iter_outer, NULL);
    if (iternext_outer == NULL) {
        goto fail;
    }
    dataptr_outer = NpyIter_GetDataPtrArray(iter_outer);

    /* Inner iterator: walks the reduced axes of arr, buffered and cast to double. */
    dtype_inner[0] = double_dtype;
    oa_axes_inner[0] = arr_axes_inner;
    iter_inner = NpyIter_AdvancedNew(1, ops, inner_flags, NPY_KEEPORDER,
                                     NPY_UNSAFE_CASTING, op_flags_inner,
                                     dtype_inner, n_inner,
                                     oa_axes_inner, NULL, -1);
    if (iter_inner == NULL) {
        goto fail;
    }
    iternext_inner = NpyIter_GetIterNext(iter_inner, NULL);
    if (iternext_inner == NULL) {
        goto fail;
    }
    dataptr_inner = NpyIter_GetDataPtrArray(iter_inner);
    strideptr_inner = NpyIter_GetInnerStrideArray(iter_inner);
    innersizeptr_inner = NpyIter_GetInnerLoopSizePtr(iter_inner);

    /* The product slot is one step along out's last axis from the sum slot. */
    prod_offset = PyArray_STRIDE(out, n_outer);

    do {
        double *sum_ptr = (double *)dataptr_outer[1];
        double *prod_ptr = (double *)(dataptr_outer[1] + prod_offset);

        *sum_ptr = 0.0;
        *prod_ptr = 1.0;
        /* Re-base the inner iterator at the current outer position of arr. */
        if (NpyIter_ResetBasePointers(iter_inner, dataptr_outer, NULL) != NPY_SUCCEED) {
            goto fail;
        }
        do {
            const char *data = dataptr_inner[0];
            const npy_intp stride = strideptr_inner[0];
            const npy_intp count = *innersizeptr_inner;
            npy_intp k;

            for (k = 0; k < count; ++k) {
                const double value = *(const double *)(data + k * stride);
                *sum_ptr += value;
                *prod_ptr *= value;
            }
        } while (iternext_inner(iter_inner));
        /* A buffered iternext also returns 0 on error (e.g. a failed cast). */
        if (PyErr_Occurred()) {
            goto fail;
        }
    } while (iternext_outer(iter_outer));

    NpyIter_Deallocate(iter_inner);
    NpyIter_Deallocate(iter_outer);
    Py_DECREF(double_dtype);
    return (PyObject *)out;

fail:
    if (iter_inner != NULL) {
        NpyIter_Deallocate(iter_inner);
    }
    if (iter_outer != NULL) {
        NpyIter_Deallocate(iter_outer);
    }
    Py_XDECREF(out);
    Py_DECREF(double_dtype);
    return NULL;
}
/*
 * Python entry point: iter_test(arr, axes).
 *
 * `arr` must be a numpy.ndarray and `axes` a tuple of integers naming the axes
 * to reduce. The work is delegated to sum_and_prod().
 *
 * Returns a new reference, or NULL with an exception set:
 *   TypeError  - wrong number of positional arguments;
 *   ValueError - `arr` is not an array, `axes` is not a tuple, too many axes,
 *                or an axis value outside the int range;
 *   any error raised while converting an axis element to an integer.
 */
static PyObject* iter_test(PyObject *self, PyObject* args)
{
    PyObject *o_arr;
    PyObject *axes;
    Py_ssize_t naxes;
    Py_ssize_t k;
    int axis[NPY_MAXDIMS];

    (void)self;
    if (PyTuple_GET_SIZE(args) != 2) {
        PyErr_SetString(PyExc_TypeError,
                        "wrong # args, expected 2");
        return NULL;
    }
    o_arr = PyTuple_GET_ITEM(args, 0);
    if (!PyArray_Check(o_arr)) {
        PyErr_SetString(PyExc_ValueError,
                        "Expected an array");
        return NULL;
    }
    axes = PyTuple_GET_ITEM(args, 1);
    if (!PyTuple_Check(axes)) {
        PyErr_SetString(PyExc_ValueError,
                        "Expected a tuple");
        return NULL;
    }
    naxes = PyTuple_GET_SIZE(axes);
    /*
     * No array has more than NPY_MAXDIMS axes, so a longer tuple can never be
     * valid. Rejecting it here also bounds the fixed-size axis buffer.
     */
    if (naxes > NPY_MAXDIMS) {
        PyErr_SetString(PyExc_ValueError, "Too many axes");
        return NULL;
    }
    for (k = 0; k < naxes; ++k) {
        long value = PyInt_AsLong(PyTuple_GET_ITEM(axes, k));
        if (value == -1 && PyErr_Occurred()) {
            return NULL;
        }
        /* Refuse values that a cast to int would silently wrap into a valid axis. */
        if (value < INT_MIN || value > INT_MAX) {
            PyErr_SetString(PyExc_ValueError, "axis specified does not exist");
            return NULL;
        }
        axis[k] = (int)value;
    }
    return sum_and_prod((PyArrayObject *)o_arr, axis, (int)naxes);
}
PyMethodDef module_methods[] = {
{"iter_test", &iter_test, METH_VARARGS, ""},
{0} /* sentinel */
};
#if defined(NPY_PY3K)
/*
 * Python 3 module definition for "ndit". The fields are named explicitly;
 * every field not listed (slots/reload, traverse, clear, free) is
 * zero-initialized, i.e. NULL. m_size = -1 means the module keeps no
 * per-module state (see the PyModuleDef documentation).
 */
static struct PyModuleDef moduledef = {
    .m_base = PyModuleDef_HEAD_INIT,
    .m_name = "ndit",
    .m_doc = NULL,
    .m_size = -1,
    .m_methods = module_methods,
};
#endif
#if defined(NPY_PY3K)
/*
 * Python 3 module initializer. It must return the new module object (or NULL
 * with an exception set). NumPy's import_array() macro itself returns NULL
 * from this function if the NumPy C API cannot be loaded, so no extra
 * PyErr_Occurred() check is needed.
 */
PyMODINIT_FUNC PyInit_ndit(void)
{
    import_array();
    return PyModule_Create(&moduledef);
}
#else
/*
 * Python 2 module initializer (returns void). import_array() returns early
 * with an exception set on failure; Py_InitModule's result is borrowed.
 */
PyMODINIT_FUNC initndit(void)
{
    import_array();
    (void)Py_InitModule("ndit", module_methods);
}
#endif
"""Build script for the ``ndit`` NumPy C extension (compiled from ndit.c)."""
try:
    # distutils is deprecated (PEP 632) and removed in Python 3.12;
    # setuptools provides the same setup()/Extension interface.
    from setuptools import setup, Extension
except ImportError:  # fall back for old environments without setuptools
    from distutils.core import setup, Extension

import numpy as np

ext_modules = [
    Extension(
        'ndit',
        sources=['ndit.c'],
        # numpy/arrayobject.h and friends are not on the default include path.
        include_dirs=[np.get_include()],
    )
]

setup(
    name='ndit',
    version='1.0',
    ext_modules=ext_modules,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment