Last active
January 1, 2017 18:25
-
-
Save se4u/e44f631b249e0be03c21c6c898059176 to your computer and use it in GitHub Desktop.
Add support for a 32-bit float (float32) data type to numpy.random normal sampling (standard_normal, for the randn call)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/mtrand/mtrand.pyx b/mtrand/mtrand.pyx | |
index 65c1954..79d0d70 100755 | |
--- a/mtrand/mtrand.pyx | |
+++ b/mtrand/mtrand.pyx | |
@@ -71,6 +71,7 @@ cdef extern from "randomkit.h": | |
rk_error rk_altfill(void *buffer, size_t size, int strong, | |
rk_state *state) nogil | |
double rk_gauss(rk_state *state) nogil | |
+ float rk_gauss_32bit(rk_state *state) nogil | |
void rk_random_uint64(npy_uint64 off, npy_uint64 rng, npy_intp cnt, | |
npy_uint64 *out, rk_state *state) nogil | |
void rk_random_uint32(npy_uint32 off, npy_uint32 rng, npy_intp cnt, | |
@@ -124,6 +125,7 @@ cdef extern from "distributions.h": | |
long rk_logseries(rk_state *state, double p) nogil | |
ctypedef double (* rk_cont0)(rk_state *state) nogil | |
+ctypedef float (* rk_cont0_32bit)(rk_state *state) nogil | |
ctypedef double (* rk_cont1)(rk_state *state, double a) nogil | |
ctypedef double (* rk_cont2)(rk_state *state, double a, double b) nogil | |
ctypedef double (* rk_cont3)(rk_state *state, double a, double b, double c) nogil | |
@@ -172,6 +174,26 @@ cdef object cont0_array(rk_state *state, rk_cont0 func, object size, | |
array_data[i] = func(state) | |
return array | |
+# Single-precision counterpart of cont0_array: draws from a sampler that takes
+# only the generator state and yields a C float.
+#   state -- generator state handed to every call of `func`
+#   func  -- sampler of type rk_cont0_32bit
+#   size  -- None for one scalar draw, otherwise passed to np.empty as the shape
+#   lock  -- context manager held while `func` is being called
+# Returns the single drawn value when size is None, else a float32 ndarray.
+cdef object cont0_array_32bit(rk_state *state, rk_cont0_32bit func, object size,
+                              object lock):
+    cdef float *out_data
+    cdef ndarray out "arrayObject"
+    cdef npy_intp n_items
+    cdef npy_intp k
+    cdef float sample
+    # Scalar request: one draw with the lock held and the GIL released.
+    if size is None:
+        with lock, nogil:
+            sample = func(state)
+        return sample
+    # Array request: allocate the float32 result first, then fill its flat
+    # buffer in a single locked, GIL-free pass.
+    out = <ndarray>np.empty(size, np.float32)
+    n_items = PyArray_SIZE(out)
+    out_data = <float *>PyArray_DATA(out)
+    with lock, nogil:
+        for k from 0 <= k < n_items:
+            out_data[k] = func(state)
+    return out
+ | |
cdef object cont1_array_sc(rk_state *state, rk_cont1 func, object size, double a, | |
object lock): | |
@@ -1500,7 +1522,7 @@ cdef class RandomState: | |
# Complicated, continuous distributions: | |
- def standard_normal(self, size=None): | |
+ def standard_normal(self, size=None, dtype=np.double): | |
""" | |
standard_normal(size=None) | |
@@ -1531,7 +1553,10 @@ cdef class RandomState: | |
(3, 4, 2) | |
""" | |
- return cont0_array(self.internal_state, rk_gauss, size, self.lock) | |
+ if dtype==np.double: | |
+ return cont0_array(self.internal_state, rk_gauss, size, self.lock) | |
+ else: | |
+ return cont0_array_32bit(self.internal_state, rk_gauss_32bit, size, self.lock) | |
def normal(self, loc=0.0, scale=1.0, size=None): | |
""" | |
diff --git a/mtrand/randomkit.c b/mtrand/randomkit.c | |
index 3a95efe..fcc20c9 100755 | |
--- a/mtrand/randomkit.c | |
+++ b/mtrand/randomkit.c | |
@@ -622,3 +622,9 @@ rk_gauss(rk_state *state) | |
return f*x2; | |
} | |
} | |
+ | |
+/*
+ * Single-precision variant of rk_gauss: takes one draw from rk_gauss and
+ * narrows the double result to float.
+ */
+float
+rk_gauss_32bit(rk_state *state)
+{
+    const double deviate = rk_gauss(state);
+    return (float)deviate;
+}
diff --git a/mtrand/randomkit.h b/mtrand/randomkit.h | |
index fcdd606..05d2f5f 100755 | |
--- a/mtrand/randomkit.h | |
+++ b/mtrand/randomkit.h | |
@@ -218,6 +218,7 @@ extern rk_error rk_altfill(void *buffer, size_t size, int strong, | |
* return a random gaussian deviate with variance unity and zero mean. | |
*/ | |
extern double rk_gauss(rk_state *state); | |
+extern float rk_gauss_32bit(rk_state *state); | |
#ifdef __cplusplus | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment