Created
November 11, 2016 16:17
-
-
Save mattharrigan/6f678b3d6df5efd236fc23bfb59fd3bd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "Python.h" | |
#include "math.h" | |
#include "numpy/ndarraytypes.h" | |
#include "numpy/ufuncobject.h" | |
#include "numpy/halffloat.h" | |
/*
 * sum_sq_diff ufunc: computes out = in1 + (in2 - c) * (in2 - c).
 * c is currently hard-coded to 2.0; a setter still needs to be added.
 */
static PyMethodDef SumSqDiffMethods[] = { | |
{NULL, NULL, 0, NULL} | |
}; | |
/*
 * Inner loop of the sum_sq_diff ufunc for doubles:
 *
 *     out[i] = in1[i] + (in2[i] - c) * (in2[i] - c)
 *
 * args[0] and args[1] point at the two inputs and args[2] at the output;
 * steps[0..2] are the matching byte strides, and dimensions[0] is the number
 * of elements to process. `data` points at the double used as c.
 *
 * NOTE(review): newer NumPy headers declare PyUFuncGenericFunction with
 * `npy_intp const *` for dimensions/steps; this signature may need `const`
 * to match the NumPy version being built against -- confirm.
 */
static void double_sum_sq_diff(char **args, npy_intp *dimensions,
                               npy_intp* steps, void* data)
{
    npy_intp i;
    npy_intp n = dimensions[0];
    char *in1 = args[0], *in2 = args[1];
    char *out = args[2];
    npy_intp in1_step = steps[0], in2_step = steps[1];
    npy_intp out_step = steps[2];
    double c = *(double *)data;
    double diff;

    if (in1 == out && in1_step == 0 && out_step == 0) {
        /*
         * Reduce case: the accumulator is passed as both in1 and out with a
         * zero stride. Sum into a local and store it once at the end. Writing
         * through `out` every iteration forces a store/reload per element,
         * because the compiler cannot prove `out` and `in2` do not alias.
         * The additions happen in the same order, so the result is unchanged.
         */
        double acc = *(double *)out;
        for (i = 0; i < n; i++) {
            diff = *(double *)in2 - c;
            acc += diff * diff;
            in2 += in2_step;
        }
        *(double *)out = acc;
    }
    else {
        /* General strided case: one output element per input pair. */
        for (i = 0; i < n; i++) {
            diff = *(double *)in2 - c;
            *(double *)out = diff * diff + *(double *)in1;
            in1 += in1_step;
            in2 += in2_step;
            out += out_step;
        }
    }
}
/*This a pointer to the above function*/ | |
PyUFuncGenericFunction funcs[1] = {&double_sum_sq_diff}; | |
/* These are the input and return dtypes of sum_sq_diff.*/ | |
static char types[3] = {NPY_DOUBLE, NPY_DOUBLE, | |
NPY_DOUBLE}; | |
static double temp = 2.0; | |
static void *data[1] = {&temp}; | |
#if PY_VERSION_HEX >= 0x03000000 | |
static struct PyModuleDef moduledef = { | |
PyModuleDef_HEAD_INIT, | |
"npufunc", | |
NULL, | |
-1, | |
SumSqDiffMethods, | |
NULL, | |
NULL, | |
NULL, | |
NULL | |
}; | |
PyMODINIT_FUNC PyInit_npufunc(void) | |
{ | |
PyObject *m, *sum_sq_diff, *d; | |
m = PyModule_Create(&moduledef); | |
if (!m) { | |
return NULL; | |
} | |
import_array(); | |
import_umath(); | |
sum_sq_diff = PyUFunc_FromFuncAndData(funcs, data, types, 1, 2, 1, | |
PyUFunc_Zero, "sum_sq_diff", | |
"sum_sq_diff_docstring", 0); | |
d = PyModule_GetDict(m); | |
PyDict_SetItemString(d, "sum_sq_diff", sum_sq_diff); | |
Py_DECREF(sum_sq_diff); | |
return m; | |
} | |
#else | |
PyMODINIT_FUNC initnpufunc(void) | |
{ | |
PyObject *m, *sum_sq_diff, *d; | |
m = Py_InitModule("npufunc", SumSqDiffMethods); | |
if (m == NULL) { | |
return; | |
} | |
import_array(); | |
import_umath(); | |
sum_sq_diff = PyUFunc_FromFuncAndData(funcs, data, types, 1, 2, 1, | |
PyUFunc_Zero, "sum_sq_diff", | |
"sum_sq_diff_docstring", 0); | |
d = PyModule_GetDict(m); | |
PyDict_SetItemString(d, "sum_sq_diff", sum_sq_diff); | |
Py_DECREF(sum_sq_diff); | |
} | |
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Benchmark: reduce with the custom ufunc vs. the equivalent pure-NumPy expression
# (c = 2.0 in the ufunc, matching the `x-2.` below). Requires npufunc to be built.
python -m timeit -s "import numpy as np;x=np.linspace(0,1,int(1e7));import npufunc" "npufunc.sum_sq_diff.reduce(x)"
python -m timeit -s "import numpy as np;x=np.linspace(0,1,int(1e7))" "np.sum(np.square(x-2.))"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment