Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Numpy MultiIterator
/*NUMPY_API
 * Get MultiIterator from array of Python objects and any additional
 * PyObject *'s passed as trailing variadic arguments.
 *
 * PyObject **mps -- array of PyObjects (only read when n > 0)
 * int n - number of PyObjects in the array
 * int nadd - number of additional PyObject * arguments following nadd
 *            that are also included in the iterator.
 *
 * Every object is converted with PyArray_FROM_O, wrapped in an array
 * iterator, and the iterators are broadcast together.
 *
 * Returns a new reference to a multi-iterator object, or NULL with a
 * Python exception set (ValueError if the total count is outside
 * [1, NPY_MAXARGS]).
 */
NPY_NO_EXPORT PyObject *
PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
{
    va_list va;
    PyArrayMultiIterObject *multi;
    PyObject *current;
    PyObject *arr;
    int i, ntot, err = 0;

    /*
     * Validate each count before adding them so that n + nadd can never
     * overflow, and so negative counts cannot make the loop below read
     * the wrong number of variadic arguments.
     */
    if (n < 0 || nadd < 0 || n > NPY_MAXARGS || nadd > NPY_MAXARGS - n
            || n + nadd < 1) {
        PyErr_Format(PyExc_ValueError,
                     "Need at least 1 and at most %d "
                     "array objects.", NPY_MAXARGS);
        return NULL;
    }
    ntot = n + nadd;

    multi = PyArray_malloc(sizeof(PyArrayMultiIterObject));
    if (multi == NULL) {
        return PyErr_NoMemory();
    }
    PyObject_Init((PyObject *)multi, &PyArrayMultiIter_Type);

    /*
     * NULL-fill the iterator slots before setting numiter, because the
     * error path below Py_DECREFs a possibly partially built object.
     * NOTE(review): assumes the type's dealloc Py_XDECREFs
     * iters[0..numiter) — confirm against the dealloc implementation.
     */
    for (i = 0; i < ntot; i++) {
        multi->iters[i] = NULL;
    }
    multi->numiter = ntot;
    multi->index = 0;

    va_start(va, nadd);
    for (i = 0; i < ntot; i++) {
        /* The first n objects come from mps, the rest from the varargs. */
        if (i < n) {
            current = mps[i];
        }
        else {
            current = va_arg(va, PyObject *);
        }
        arr = PyArray_FROM_O(current);
        if (arr == NULL) {
            err = 1;
            break;
        }
        multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr);
        /*
         * Drop our reference to arr whether or not iterator creation
         * succeeded: on success the iterator keeps its own reference, and
         * on failure nothing else owns arr, so skipping this would leak it.
         */
        Py_DECREF(arr);
        if (multi->iters[i] == NULL) {
            err = 1;
            break;
        }
    }
    va_end(va);

    if (!err && PyArray_Broadcast(multi) < 0) {
        err = 1;
    }
    if (err) {
        Py_DECREF(multi);
        return NULL;
    }
    PyArray_MultiIter_RESET(multi);
    return (PyObject *)multi;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment