Created
March 19, 2015 13:27
-
-
Save ewmoore/e88ee9dd84c1d9d58892 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#define NPY_NO_DEPRECATED_API NPY_API_VERSION | |
#include <Python.h> | |
#include <numpy/arrayobject.h> | |
#include "numpy/npy_3kcompat.h" | |
PyObject* sum_and_prod(PyArrayObject *arr, int *axis, int naxis) | |
{ | |
NpyIter *iter_outer, *iter_inner; | |
NpyIter_IterNextFunc *iternext_outer, *iternext_inner; | |
char **dataptr_outer, **dataptr_inner; | |
npy_intp *strideptr_outer, *strideptr_inner; | |
npy_intp *innersizeptr_outer, *innersizeptr_inner; | |
PyArrayObject *ops[2]; | |
PyObject *out; | |
npy_uint32 outer_flags = 0; | |
npy_uint32 op_flags_outer[2]; | |
int oa_ndim_outer; | |
int *oa_axes_outer[2]; | |
int arr_axes_outer[NPY_MAXDIMS]; | |
int out_axes_outer[NPY_MAXDIMS]; | |
//PyArrayDescr *dtype_outer[2]; | |
npy_intp out_shape[NPY_MAXDIMS]; | |
int i,j,k,m,n,flag; | |
npy_uint32 inner_flags = NPY_ITER_EXTERNAL_LOOP | NPY_ITER_BUFFERED | NPY_ITER_GROWINNER; | |
npy_uint32 op_flags_inner[1]; | |
int oa_ndim_inner; | |
int *oa_axes_inner[1]; | |
int arr_axes_inner[NPY_MAXDIMS]; | |
PyArray_Descr *dtype_inner[1]; | |
PyArray_Descr *double_dtype = PyArray_DescrFromType(NPY_DOUBLE); | |
ops[0] = arr; | |
op_flags_outer[0] = NPY_ITER_READONLY; | |
op_flags_inner[0] = NPY_ITER_READONLY | NPY_ITER_NBO; | |
//dtype_outer[0] = PyArray_DESCR(arr); | |
dtype_inner[0] = double_dtype; | |
op_flags_outer[1] = NPY_ITER_WRITEONLY; | |
//dtype_outer[1] = double_dtype; | |
if (naxis <= 0) { | |
PyErr_SetString(PyExc_ValueError, "Need at at least 1 axis"); | |
Py_DECREF(double_dtype); | |
} | |
if (naxis > PyArray_NDIM(arr)) { | |
PyErr_SetString(PyExc_ValueError, "Too many axes"); | |
Py_DECREF(double_dtype); | |
} | |
for (i = 0; i < naxis; ++i) { | |
if (axis[i] < 0) { | |
axis[i] += PyArray_NDIM(arr); | |
} | |
if (axis[i] >= naxis) { | |
PyErr_SetString(PyExc_ValueError, "axis specified does not exist"); | |
Py_DECREF(double_dtype); | |
return 0; | |
} | |
} | |
oa_ndim_outer = PyArray_NDIM(arr) - naxis; | |
oa_ndim_inner = naxis; | |
i = 0; | |
j = 0; | |
k = 0; | |
for (m = 0; m < PyArray_NDIM(arr); ++m) { | |
flag = 0; | |
for (n = 0; n < naxis; ++n) { | |
if (m == axis[n]) { | |
flag = 1; | |
break; | |
} | |
} | |
if (flag) { | |
arr_axes_inner[i++] = m; | |
} else { | |
arr_axes_outer[j++] = m; | |
out_axes_outer[k] = k; | |
out_shape[k++] = PyArray_DIM(arr, m); | |
} | |
} | |
out_shape[k] = 2; | |
out = PyArray_Zeros(++k, out_shape, double_dtype, 0); | |
if (!out) { | |
Py_DECREF(double_dtype); | |
return 0; | |
} | |
ops[1] = (PyArrayObject*)out; | |
/* | |
for (k = 0; k < PyArray_NDIM(arr); ++k) { | |
printf("%d %d %d\n", arr_axes_inner[k], arr_axes_outer[k], out_axes_outer[k]); | |
} | |
*/ | |
oa_axes_outer[0] = arr_axes_outer; | |
oa_axes_outer[1] = out_axes_outer; | |
iter_outer = NpyIter_AdvancedNew(2, ops, outer_flags, NPY_KEEPORDER, | |
NPY_NO_CASTING, op_flags_outer, NULL, | |
oa_ndim_outer, oa_axes_outer, NULL, -1); | |
if (!iter_outer) { | |
Py_DECREF(double_dtype); | |
Py_DECREF(out); | |
return 0; | |
} | |
iternext_outer = NpyIter_GetIterNext(iter_outer, NULL); | |
if (!iternext_outer) { | |
NpyIter_Deallocate(iter_outer); | |
Py_DECREF(double_dtype); | |
Py_DECREF(out); | |
} | |
dataptr_outer = NpyIter_GetDataPtrArray(iter_outer); | |
strideptr_outer = NpyIter_GetInnerStrideArray(iter_outer); | |
innersizeptr_outer = NpyIter_GetInnerLoopSizePtr(iter_outer); | |
oa_axes_inner[0] = arr_axes_inner; | |
iter_inner = NpyIter_AdvancedNew(1, ops, inner_flags, NPY_KEEPORDER, | |
NPY_UNSAFE_CASTING, op_flags_inner, | |
dtype_inner, oa_ndim_inner, | |
oa_axes_inner, NULL, -1); | |
if (!iter_inner) { | |
NpyIter_Deallocate(iter_outer); | |
Py_DECREF(double_dtype); | |
Py_DECREF(out); | |
} | |
iternext_inner = NpyIter_GetIterNext(iter_inner, NULL); | |
if (!iternext_inner) { | |
NpyIter_Deallocate(iter_outer); | |
NpyIter_Deallocate(iter_inner); | |
Py_DECREF(double_dtype); | |
Py_DECREF(out); | |
} | |
dataptr_inner = NpyIter_GetDataPtrArray(iter_inner); | |
strideptr_inner = NpyIter_GetInnerStrideArray(iter_inner); | |
innersizeptr_inner = NpyIter_GetInnerLoopSizePtr(iter_inner); | |
do { | |
double *sum_ptr = (double*)dataptr_outer[1]; | |
sum_ptr[0] = 0; | |
//double *prod_ptr = (double*)(dataptr_outer[1] + strideptr_outer[1]); | |
// this is clearly cheating... | |
double *prod_ptr = (double*)(dataptr_outer[1] + PyArray_STRIDE(out, PyArray_NDIM(out)-1)); | |
prod_ptr[0] = 1; | |
//printf("%ld, %ld\n", (long)sum_ptr, (long)prod_ptr); | |
NpyIter_ResetBasePointers(iter_inner, dataptr_outer, NULL); | |
do { | |
for (k = 0; k < *innersizeptr_inner; ++k) { | |
const double in = *((double*)(dataptr_inner[0] + k*strideptr_inner[0])); | |
sum_ptr[0] += in; | |
prod_ptr[0] *= in; | |
} | |
} while(iternext_inner(iter_inner)); | |
} while(iternext_outer(iter_outer)); | |
NpyIter_Deallocate(iter_outer); | |
NpyIter_Deallocate(iter_inner); | |
Py_DECREF(double_dtype); | |
return out; | |
} | |
/*
 * Python entry point: iter_test(arr, axes).
 *
 * `arr` must be a NumPy array and `axes` a tuple of integers naming the axes
 * to reduce.  Returns the result of sum_and_prod() for those axes, or NULL
 * with an exception set:
 *   TypeError  - wrong number of positional arguments, or an axis entry that
 *                PyInt_AsLong cannot convert;
 *   ValueError - `arr` is not an array, `axes` is not a tuple, or the axes
 *                are invalid (reported by sum_and_prod).
 */
static PyObject* iter_test(PyObject *self, PyObject* args)
{
    PyObject *o_arr;
    PyObject *axes;
    Py_ssize_t naxes;
    Py_ssize_t k;
    int axis[NPY_MAXDIMS];

    if (PyTuple_GET_SIZE(args) != 2) {
        PyErr_SetString(PyExc_TypeError,
                        "wrong # args, expected 2");
        return NULL;
    }
    o_arr = PyTuple_GET_ITEM(args, 0);
    if (!PyArray_Check(o_arr)) {
        PyErr_SetString(PyExc_ValueError,
                        "Expected an array");
        return NULL;
    }
    axes = PyTuple_GET_ITEM(args, 1);
    if (!PyTuple_Check(axes)) {
        PyErr_SetString(PyExc_ValueError,
                        "Expected a tuple");
        return NULL;
    }

    naxes = PyTuple_GET_SIZE(axes);
    /* An array never has more than NPY_MAXDIMS axes, so a longer tuple can
       never be valid.  Rejecting it here keeps the fixed-size buffer and the
       Py_ssize_t -> int narrowing below safe. */
    if (naxes > NPY_MAXDIMS) {
        PyErr_SetString(PyExc_ValueError, "Too many axes");
        return NULL;
    }
    for (k = 0; k < naxes; ++k) {
        long value = PyInt_AsLong(PyTuple_GET_ITEM(axes, k));
        if (value == -1 && PyErr_Occurred()) {
            return NULL;
        }
        /* Saturate to values that are out of range for every array, so a
           huge index is reported by sum_and_prod as a missing axis instead
           of wrapping onto a valid one when narrowed to int. */
        if (value > NPY_MAXDIMS) {
            value = NPY_MAXDIMS;
        }
        else if (value < -NPY_MAXDIMS - 1) {
            value = -NPY_MAXDIMS - 1;
        }
        axis[k] = (int)value;
    }
    return sum_and_prod((PyArrayObject *)o_arr, axis, (int)naxes);
}
PyMethodDef module_methods[] = { | |
{"iter_test", &iter_test, METH_VARARGS, ""}, | |
{0} /* sentinel */ | |
}; | |
#if defined(NPY_PY3K)
/* Python 3 module definition for "ndit": no docstring, m_size of -1, and the
   method table above.  All remaining slots are left NULL. */
static struct PyModuleDef moduledef = {
    .m_base = PyModuleDef_HEAD_INIT,
    .m_name = "ndit",
    .m_doc = NULL,
    .m_size = -1,
    .m_methods = module_methods,
};
#endif
#if defined(NPY_PY3K) | |
PyMODINIT_FUNC PyInit_ndit(void) { | |
#else | |
PyMODINIT_FUNC initndit(void) { | |
#endif | |
PyObject *m; | |
import_array(); | |
if (PyErr_Occurred()) { | |
return; | |
} | |
/* Create module */ | |
#if defined(NPY_PY3K) | |
m = PyModule_Create(&moduledef); | |
#else | |
m = Py_InitModule("ndit", module_methods); | |
#endif | |
if (!m) { | |
return; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build script for the `ndit` C extension (compiled from ndit.c), which needs
# NumPy's C headers on the include path.
#
# NOTE(review): distutils was deprecated by PEP 632 and removed in Python 3.12;
# on newer interpreters this import fails and setuptools (which provides the
# same `setup`/`Extension` names) would be needed — confirm target Python.
from distutils.core import setup, Extension
import numpy as np

ndit_extension = Extension('ndit', sources=['ndit.c'])

setup(
    name='ndit',
    version='1.0',
    include_dirs=[np.get_include()],
    ext_modules=[ndit_extension],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment