Skip to content

Instantly share code, notes, and snippets.

@jcrist
Created April 3, 2020 00:08
Show Gist options
  • Save jcrist/069154f8bd7bbc1c1cd81a926c533fa4 to your computer and use it in GitHub Desktop.
Save jcrist/069154f8bd7bbc1c1cd81a926c533fa4 to your computer and use it in GitHub Desktop.
Overload fresnel
import ctypes
import numpy as np
from numba.extending import get_cython_function_address, overload, intrinsic
from numba.errors import TypingError
from numba import cgutils, types
import scipy.special
@intrinsic
def ref(typingctx, data):
    """Place ``data`` in a freshly allocated slot and return a pointer to it.

    Usable only from nopython-compiled code.  Typing is
    ``data -> CPointer(data)``; the returned pointer can be handed to a C
    function that writes through it, and read back with a pointer load.

    The slot comes from ``cgutils.alloca_once_value`` (an LLVM ``alloca``),
    so the pointer must not be used after the compiled function that called
    ``ref`` has returned.
    """
    pointer_type = types.CPointer(data)

    def codegen(context, builder, signature, args):
        (value,) = args
        # Reserve the slot and store the initial value into it.
        return cgutils.alloca_once_value(builder, value)

    return pointer_type(data), codegen
@intrinsic
def deref(typingctx, data):
    """Load and return the value stored behind a typed pointer.

    Counterpart of ``ref`` for nopython code: typing is ``CPointer(T) -> T``.

    Arguments that are not a ``types.CPointer`` are rejected during typing by
    returning ``None``.  Numba then reports an ordinary "no matching
    signature" typing error instead of crashing on ``data.dtype`` with an
    ``AttributeError``.
    """
    if not isinstance(data, types.CPointer):
        return None

    def impl(context, builder, signature, args):
        # args[0] is the LLVM pointer; a plain load yields the pointee value.
        return builder.load(args[0])

    sig = data.dtype(data)
    return sig, impl
@overload(scipy.special.fresnel)
def overload_fresnel(z):
    """Numba overload of ``scipy.special.fresnel`` for real inputs.

    Calls the ``fresnel`` routine exported by
    ``scipy.special.cython_special`` through ctypes.  Its C signature is
    declared below as ``void(double z, double *S, double *C)``.

    Parameters
    ----------
    z : integer/float scalar, or array with integer/float dtype
        Argument of the Fresnel integrals.

    Returns
    -------
    (S, C)
        For a scalar ``z``: two float64 scalars.  For an array ``z``: two new
        float64 arrays with the same shape as ``z``.  Integer and float32
        inputs are evaluated in double precision and also yield float64.

    Raises
    ------
    TypingError
        If ``z`` (or the array's dtype) is not an integer or float type.
    """
    c_double = ctypes.c_double
    c_double_p = ctypes.POINTER(c_double)
    # "__pyx_fuse_1" presumably selects the double specialisation of the
    # fused cython_special function; the CFUNCTYPE below assumes exactly
    # void(double, double*, double*).
    cephes_fresnel = ctypes.CFUNCTYPE(None, c_double, c_double_p, c_double_p)(
        get_cython_function_address("scipy.special.cython_special", "__pyx_fuse_1fresnel")
    )

    if isinstance(z, types.Array):
        if not isinstance(z.dtype, (types.Integer, types.Float)):
            raise TypingError("z must be either integer or float")

        def fresnel(z):
            # The C routine writes doubles, so the outputs must be float64
            # regardless of z's dtype (empty_like would copy an int/float32
            # dtype that cannot receive the results).
            S = np.empty(z.shape, dtype=np.float64)
            C = np.empty(z.shape, dtype=np.float64)
            # One reusable scratch double per output; no per-element slicing.
            s_ptr = ref(np.float64(0.0))
            c_ptr = ref(np.float64(0.0))
            # ndindex visits every element, so any number of dimensions works.
            for idx in np.ndindex(z.shape):
                cephes_fresnel(z[idx], s_ptr, c_ptr)
                S[idx] = deref(s_ptr)
                C[idx] = deref(c_ptr)
            return S, C
    else:
        if not isinstance(z, (types.Integer, types.Float)):
            raise TypingError("z must be either integer or float")

        def fresnel(z):
            # Out-parameters for the C call, read back after it returns.
            S = ref(np.float64(0.0))
            C = ref(np.float64(0.0))
            cephes_fresnel(z, S, C)
            return deref(S), deref(C)

    return fresnel
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment