Created April 25, 2019 22:07
-
-
Save mobiusklein/808d101ea228a71af02e9b5f4a8a0caa to your computer and use it in GitHub Desktop.
Minimal reproducible example for the error `implement_array_function method already has a docstring`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
cimport cython | |
import numpy as np | |
cimport numpy as np | |
np.import_array() | |
# Not referenced anywhere else in this module; ``cython.floating`` is
# Cython's built-in fused type over ``float`` and ``double``.
ctypedef cython.floating floating_t

# Input types accepted by coerce_data() / make_array(). Cython generates one
# specialization per member; ``object`` is the catch-all, which coerce_data()
# materializes with list() before building the array.
# NOTE(review): the accompanying reproduction script reports that calling
# make_array() with an np.ndarray and letting Cython pick the specialization
# fails ("implement_array_function method already has a docstring"), while
# explicit indexing (make_array[np.ndarray]) works. This is presumably an issue
# in Cython's fused-type dispatcher; confirm against the Cython/NumPy versions
# in use.
ctypedef fused numeric_collection:
    np.ndarray
    list
    tuple
    object

# The float64 dtype that coerce_data() compares against / converts to.
cdef object double_dtype = np.float64
cdef np.ndarray[double] coerce_data(numeric_collection data):
    """Return the contents of *data* as a float64 NumPy array.

    One specialization is compiled per member of ``numeric_collection``:

    - ``object``: materialized with ``list(...)`` first, so arbitrary
      iterables (e.g. generators) are accepted.
    - ``list`` / ``tuple``: passed straight to ``np.array``.
    - ``np.ndarray``: returned as-is (no copy) when its dtype already
      compares equal to ``double_dtype``; otherwise converted with
      ``astype``, which produces a new array.

    NOTE(review): the ``np.ndarray[double]`` return declaration describes a
    buffer of doubles; confirm how Cython treats it for N-D inputs.
    """
    cdef np.ndarray npdata
    # ``numeric_collection is ...`` is resolved at compile time, so each
    # specialization keeps only its own branch. All branches use the shared
    # ``double_dtype`` constant rather than repeating ``np.float64``.
    if numeric_collection is object:
        # Generic fallback: consume the iterable before array construction.
        return np.array(list(data), dtype=double_dtype)
    elif numeric_collection is list or numeric_collection is tuple:
        return np.array(data, dtype=double_dtype)
    elif numeric_collection is np.ndarray:
        npdata = data
        if npdata.dtype == double_dtype:
            # Already float64: hand back the caller's array without copying.
            return npdata
        return npdata.astype(double_dtype)
def make_array(numeric_collection data):
    """Python-visible wrapper returning *data* coerced to a float64 array.

    Because *data* is typed with the fused type ``numeric_collection``, the
    function can be called directly (Cython selects a specialization from the
    argument) or indexed explicitly, e.g. ``make_array[list](values)``.

    NOTE(review): the accompanying reproduction script reports that the
    direct call fails for an np.ndarray argument while
    ``make_array[np.ndarray](...)`` works. This looks like a Cython dispatch
    issue; unconfirmed.
    """
    return coerce_data(data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from setuptools import setup, Extension | |
import numpy as np | |
from Cython.Build import cythonize | |
# One extension module, "numfoo", built from numfoo.pyx as C++ and pointed at
# NumPy's C headers (the .pyx does ``cimport numpy``).
numfoo_extension = Extension(
    name="numfoo",
    sources=["numfoo.pyx"],
    language="c++",
    include_dirs=[np.get_include()],
)

# cythonize() translates the .pyx sources and returns the Extension list that
# setuptools should compile.
ext_modules = cythonize([numfoo_extension])

setup(
    name="numfoo",
    ext_modules=ext_modules,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numfoo | |
import numpy as np | |
# Start from a plain Python list of floats and show it.
values = [1.2, 2.4]
print(values)

# Explicitly selecting the ``list`` specialization of the fused function works.
print(numfoo.make_array[list](values))

# Same numbers, now held in a NumPy array.
values = np.array(values)

# Explicitly selecting the ``np.ndarray`` specialization also works.
print(numfoo.make_array[np.ndarray](values))

# Letting Cython choose the specialization from the argument's type is the
# call that fails ("implement_array_function method already has a docstring").
print(numfoo.make_array(values))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment