Skip to content

Instantly share code, notes, and snippets.

@syrte
Created August 2, 2021 04:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save syrte/9016ec5484fc4cffc6d8529cdab14142 to your computer and use it in GitHub Desktop.
Save syrte/9016ec5484fc4cffc6d8529cdab14142 to your computer and use it in GitHub Desktop.
Pass a SciPy cubic spline to GSL
import numpy as np
from libc.string cimport memcpy
from numpy cimport ndarray
from cython_gsl cimport *
# Re-declarations of GSL's internal structs.  cython_gsl does not provide
# complete struct member info for gsl_spline / gsl_interp, so these mirrors
# let cspline_sp2gsl reach the spline internals after a pointer cast.
# NOTE(review): field order and types must match the GSL version actually
# linked (gsl_spline.h, gsl_interp.h and the private cspline_state_t in
# GSL's interp/cspline.c) -- a mismatch silently corrupts memory.
#
# Declared innermost-first so that every struct is defined before another
# struct refers to it (Cython resolves type names in declaration order).

cdef struct cspline_state_t:
    # State of a gsl_interp_cspline interpolator.
    double *c        # per-knot quadratic coefficients (written by cspline_sp2gsl)
    double *g        # presumably solver work arrays for GSL's tridiagonal
    double *diag     # system; not touched by this module -- TODO confirm
    double *offdiag

cdef struct gsl_interp_t:
    const gsl_interp_type *type
    double xmin      # first knot, x[0]
    double xmax      # last knot, x[n - 1]
    size_t size      # presumably the number of knots passed to gsl_spline_alloc
    # NOTE(review): GSL declares this as an untyped state pointer; it is
    # typed here as cspline_state_t, so it is only valid for
    # gsl_interp_cspline splines.
    cspline_state_t *state

cdef struct gsl_spline_t:
    gsl_interp_t *interp   # underlying interpolator (bounds + cspline state)
    double *x              # spline's own copy of the knot abscissae
    double *y              # spline's own copy of the knot ordinates
    size_t size            # number of knots
cdef gsl_spline *cspline_sp2gsl(double[:] x, double[:] y, double[:] c=None) except NULL nogil:
    '''
    Build a GSL cubic spline (``gsl_interp_cspline``) through the knots ``x, y``.

    Parameters
    ----------
    x, y : double[:]
        Knot abscissae and ordinates.  Same length, at least 3 points, ``x``
        strictly increasing, both C-contiguous (they are read as raw C arrays).
    c : double[:], optional
        Quadratic coefficients of an existing spline, one per interval
        (length ``len(x) - 1``), e.g.
        ``c = scipy.interpolate.CubicSpline(x, y).c[1]``.  When given, the
        coefficients are written directly into GSL's internal state, so the GSL
        spline reproduces that spline.  When omitted, GSL computes its own
        spline with ``gsl_spline_init``.

    Returns
    -------
    gsl_spline *
        A newly allocated spline.  The caller owns it and must release it with
        ``gsl_spline_free``.

    Raises
    ------
    ValueError
        Raised for too few points, mismatched lengths, non-contiguous ``x``/``y``,
        or ``x`` not strictly increasing.  These are checked here because GSL's
        default error handler aborts the process instead of raising.
    MemoryError
        If ``gsl_spline_alloc`` returns NULL.
    '''
    cdef:
        Py_ssize_t i, n = x.shape[0]
        gsl_spline *spline
        gsl_spline_t *spl
        double *state_c
        double h_lo, h_hi, slope_lo, slope_hi

    # Validate everything before allocating, so no cleanup is needed on error.
    if n < 3:  # minimum number of points for gsl_interp_cspline
        with gil:
            raise ValueError("x must contain at least 3 points")
    if y.shape[0] != n:
        with gil:
            raise ValueError("x and y must have the same length")
    if c is not None and c.shape[0] != n - 1:
        with gil:
            raise ValueError("c must have length len(x) - 1")
    # memcpy and gsl_spline_init below take &x[0] / &y[0] as plain C arrays.
    if (x.strides[0] != <Py_ssize_t> sizeof(double)
            or y.strides[0] != <Py_ssize_t> sizeof(double)):
        with gil:
            raise ValueError("x and y must be C-contiguous")
    for i in range(n - 1):
        # `not (a < b)` also rejects NaN knots.
        if not (x[i] < x[i + 1]):
            with gil:
                raise ValueError("x values must be strictly increasing")

    spline = gsl_spline_alloc(gsl_interp_cspline, n)
    if spline == NULL:
        with gil:
            raise MemoryError("gsl_spline_alloc failed")

    if c is None:
        gsl_spline_init(spline, &x[0], &y[0], n)
        return spline

    # need this conversion because cython_gsl does not provide complete
    # struct member info
    spl = <gsl_spline_t *> spline
    spl.interp.xmin = x[0]
    spl.interp.xmax = x[n - 1]
    memcpy(spl.x, &x[0], n * sizeof(double))
    memcpy(spl.y, &y[0], n * sizeof(double))

    # GSL keeps one quadratic coefficient per *knot* (n values) and on
    # interval i uses both c[i] and c[i + 1].  SciPy supplies one per
    # *interval* (n - 1 values).  The value at the last knot is 0 only for
    # natural end conditions, so recover it from first-derivative continuity
    # at x[n-2]:
    #   h_lo*c[n-3] + 2*(h_lo + h_hi)*c[n-2] + h_hi*c[n-1] = 3*(slope_hi - slope_lo)
    state_c = spl.interp.state.c
    for i in range(n - 1):  # element-wise copy also works for strided `c`
        state_c[i] = c[i]
    h_lo = x[n - 2] - x[n - 3]
    h_hi = x[n - 1] - x[n - 2]
    slope_lo = (y[n - 2] - y[n - 3]) / h_lo
    slope_hi = (y[n - 1] - y[n - 2]) / h_hi
    state_c[n - 1] = (3.0 * (slope_hi - slope_lo)
                      - h_lo * c[n - 3]
                      - 2.0 * (h_lo + h_hi) * c[n - 2]) / h_hi
    return spline
cpdef ndarray gsl_cspline(double[:] xi, double[:] x, double[:] y, double[:] c=None):
    '''
    Evaluate a GSL cubic spline through ``(x, y)`` at the points ``xi``.

    Parameters
    ----------
    xi : double[:]
        Evaluation points.
    x, y : double[:]
        Spline knots.  ``x`` must be strictly increasing with at least 3
        points, and ``y`` must have the same length.
    c : double[:], optional
        Per-interval quadratic coefficients of an existing spline, e.g.
        ``scipy.interpolate.CubicSpline(x, y).c[1]``.  If omitted, GSL builds
        its own spline from ``x, y``.

    Returns
    -------
    ndarray
        float64 array with one value per entry of ``xi``.  Entries whose ``xi``
        lies outside ``[x[0], x[-1]]``, or is NaN, are NaN.  GSL does not
        extrapolate, and evaluating there would invoke the GSL error handler,
        which aborts the process by default.

    Example
    -------
    import numpy as np
    from scipy.interpolate import CubicSpline
    x = np.linspace(0, 10, 100)
    y = np.sin(x)
    f = CubicSpline(x, y, bc_type='not-a-knot')
    # f = CubicSpline(x, y, bc_type='natural')
    xi = np.linspace(0, 10, 71)
    y0 = f(xi)
    y1 = pyx.gsl_cspline(xi, x, y, f.c[1])
    y2 = pyx.gsl_cspline(xi, x, y)
    print(abs(y1 - y0).max(), abs(y2 - y0).max(), abs(y1 - y2).max())
    '''
    cdef:
        Py_ssize_t i, n = xi.shape[0]
        double xmin, xmax, xv
        double nan_value = np.nan
        double[:] yi = np.empty(n, dtype='f8')
        gsl_spline *spline

    # The spline builder reads x and y as raw C arrays; make them contiguous
    # (a no-op for arrays that already are).
    x = np.ascontiguousarray(x, dtype='f8')
    y = np.ascontiguousarray(y, dtype='f8')

    spline = cspline_sp2gsl(x, y, c)
    try:
        xmin = x[0]
        xmax = x[x.shape[0] - 1]
        for i in range(n):
            xv = xi[i]
            if xmin <= xv <= xmax:
                yi[i] = gsl_spline_eval(spline, xv, NULL)
            else:
                # Same value GSL returns for out-of-range x, without going
                # through its (process-aborting) error handler.
                yi[i] = nan_value
    finally:
        # Free the spline even if evaluation is interrupted by an exception.
        gsl_spline_free(spline)
    return yi.base
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment