Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
#include "Python.h"
#include "math.h"
#include "numpy/ndarraytypes.h"
#include "numpy/ufuncobject.h"
#include "numpy/halffloat.h"
/*
 * Computes, element-wise, out = in1 + (in2 - c) * (in2 - c).
 * c is currently hard-coded to 2.0; a setter still needs to be added.
 */
/*
 * Method table passed to module creation.  The module exposes no plain
 * functions (the ufunc is inserted into the module dict at init time), so
 * the table contains only the terminating all-NULL entry.
 */
static PyMethodDef SumSqDiffMethods[] = {
    { NULL, NULL, 0, NULL }  /* end-of-table sentinel */
};
/*
 * Inner loop of the sum_sq_diff ufunc (double, double -> double):
 *
 *     out[i] = (in2[i] - c) * (in2[i] - c) + in1[i]
 *
 * args       - base pointers of in1, in2 and out (in that order)
 * dimensions - dimensions[0] is the number of elements to process
 * steps      - byte strides of in1, in2 and out (in that order)
 * data       - points to the double constant c
 *
 * The dimension/step pointers are const-qualified to match the
 * PyUFuncGenericFunction typedef of current NumPy releases; with the old
 * non-const signature, storing this function in a PyUFuncGenericFunction
 * table is an incompatible function-pointer conversion, which recent
 * compilers (e.g. GCC 14) reject by default.
 */
static void double_sum_sq_diff(char **args, npy_intp const *dimensions,
                               npy_intp const *steps, void *data)
{
    npy_intp i;
    const npy_intp n = dimensions[0];
    char *in1 = args[0];
    char *in2 = args[1];
    char *out = args[2];
    const npy_intp in1_step = steps[0];
    const npy_intp in2_step = steps[1];
    const npy_intp out_step = steps[2];
    const double c = *(const double *)data;

    if (in1 == out && in1_step == 0 && out_step == 0) {
        /*
         * Reduce case: in1 and out are the same single accumulator.  Keep
         * the running total in a local and store it once at the end rather
         * than reloading/storing *out on every iteration; the additions
         * happen in the same order, so the result is unchanged.
         */
        double acc = *(double *)out;
        for (i = 0; i < n; i++) {
            const double diff = *(const double *)in2 - c;
            acc += diff * diff;
            in2 += in2_step;
        }
        *(double *)out = acc;
        return;
    }

    /* General strided case. */
    for (i = 0; i < n; i++) {
        const double diff = *(const double *)in2 - c;
        *(double *)out = diff * diff + *(const double *)in1;
        in1 += in1_step;
        in2 += in2_step;
        out += out_step;
    }
}
/*
 * Inner-loop table for the ufunc: a single entry, the double loop above.
 * `static` keeps this generically named table out of the extension's
 * exported symbol namespace.
 */
static PyUFuncGenericFunction funcs[1] = {&double_sum_sq_diff};
/* Dtypes for the single loop: the two inputs followed by the one output. */
static char types[3] = {NPY_DOUBLE, NPY_DOUBLE,
                        NPY_DOUBLE};
/*
 * The constant c handed to the loop through `data`; currently fixed at 2.0.
 * Left non-const because the data table stores a plain (non-const) void *.
 */
static double temp = 2.0;
static void *data[1] = {&temp};
/*
 * Create the sum_sq_diff ufunc from the loop tables above and bind it as
 * `sum_sq_diff` in `module`'s namespace.  Shared by the Python 3 and
 * Python 2 init functions below.
 *
 * Returns 0 on success, or -1 with a Python exception set.  The caller keeps
 * ownership of `module` in both cases.
 */
static int add_sum_sq_diff_ufunc(PyObject *module)
{
    PyObject *ufunc;
    int status;

    /* One loop (ntypes = 1), two inputs, one output; PyUFunc_Zero declares
     * 0 as the ufunc's identity. */
    ufunc = PyUFunc_FromFuncAndData(funcs, data, types, 1, 2, 1,
                                    PyUFunc_Zero, "sum_sq_diff",
                                    "sum_sq_diff_docstring", 0);
    if (ufunc == NULL) {
        return -1;
    }
    /* PyModule_GetDict returns a borrowed reference.  The dict takes its own
     * reference to the ufunc on success, so ours is released either way. */
    status = PyDict_SetItemString(PyModule_GetDict(module), "sum_sq_diff",
                                  ufunc);
    Py_DECREF(ufunc);
    return status;
}

#if PY_VERSION_HEX >= 0x03000000
static struct PyModuleDef moduledef = {
    PyModuleDef_HEAD_INIT,
    "npufunc",
    NULL,
    -1,
    SumSqDiffMethods,
    NULL,
    NULL,
    NULL,
    NULL
};

/*
 * Python 3 entry point.  Returns a new reference to the module, or NULL
 * with an exception set.
 */
PyMODINIT_FUNC PyInit_npufunc(void)
{
    PyObject *m;

    /* Load NumPy's C APIs first: these macros may return early from this
     * function on failure, and nothing has been allocated yet that would
     * leak. */
    import_array();
    import_umath();

    m = PyModule_Create(&moduledef);
    if (m == NULL) {
        return NULL;
    }
    if (add_sum_sq_diff_ufunc(m) < 0) {
        Py_DECREF(m);
        return NULL;
    }
    return m;
}
#else
/*
 * Python 2 entry point.  Failures are reported by leaving an exception set.
 */
PyMODINIT_FUNC initnpufunc(void)
{
    PyObject *m;

    import_array();
    import_umath();

    /* Py_InitModule returns a borrowed reference, so m is not DECREF'd. */
    m = Py_InitModule("npufunc", SumSqDiffMethods);
    if (m == NULL) {
        return;
    }
    /* On failure the exception is left set for the importer to report. */
    (void)add_sum_sq_diff_ufunc(m);
}
#endif
# Benchmark: reduce with the custom ufunc vs. the equivalent NumPy expression,
# both over 10**7 evenly spaced samples in [0, 1].
python -m timeit -s "import numpy as np" -s "x=np.linspace(0,1,int(1e7))" -s "import npufunc" "npufunc.sum_sq_diff.reduce(x)"
python -m timeit -s "import numpy as np" -s "x=np.linspace(0,1,int(1e7))" "np.sum(np.square(x-2.))"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment