# Pass a scipy cubic spline to GSL.
#
# Source: GitHub gist syrte/9016ec5484fc4cffc6d8529cdab14142, created
# August 2, 2021.
import numpy as np

from libc.math cimport NAN
from libc.string cimport memcpy
from numpy cimport ndarray
from cython_gsl cimport *
# Mirror of GSL's `gsl_spline` (gsl/gsl_spline.h).  cython_gsl declares the
# type without its members, so a `gsl_spline *` is cast to this struct to reach
# its internals.  Field order and types must match the C definition exactly,
# because the cast reinterprets GSL's memory with this layout.
cdef struct gsl_spline_t:
    gsl_interp_t *interp  # underlying interpolation object
    double *x             # knot abscissae buffer allocated by gsl_spline_alloc
    double *y             # knot ordinates buffer allocated by gsl_spline_alloc
    size_t size           # number of knots
# Mirror of GSL's `gsl_interp` (gsl/gsl_interp.h), used to reach members that
# cython_gsl does not declare.  Field order and types must match the C
# definition exactly.  GSL declares `state` as `void *`; it is typed here as
# `cspline_state_t *`, which is only valid for objects allocated with
# `gsl_interp_cspline`.
cdef struct gsl_interp_t:
    const gsl_interp_type *type  # interpolation method (e.g. gsl_interp_cspline)
    double xmin                  # x of the first knot
    double xmax                  # x of the last knot
    size_t size                  # number of knots
    cspline_state_t *state       # method-specific state (cspline only, see above)
# Mirror of the private `cspline_state_t` from GSL's interpolation/cspline.c.
# It is not part of GSL's public headers -- NOTE(review): confirm it matches
# the installed GSL version.  Field order must match the C definition.
cdef struct cspline_state_t:
    double *c        # per-knot quadratic coefficients, S''(x_i) / 2 (length n)
    double *g        # GSL workspace for the coefficient solve
    double *diag     # GSL workspace: diagonal of the tridiagonal system
    double *offdiag  # GSL workspace: off-diagonal of the tridiagonal system
cdef gsl_spline *cspline_sp2gsl(double[:] x, double[:] y, double[:] c=None) except NULL nogil:
    '''
    Build a GSL cubic spline (`gsl_interp_cspline`) through the knots (x, y).

    Parameters
    ----------
    x, y : double[:]
        Knots: equal lengths n >= 3, x strictly increasing.  Any stride is
        accepted because the data are copied element by element.
    c : double[:], optional
        Quadratic coefficients of a scipy spline on the same knots, i.e.
        ``c = scipy.interpolate.CubicSpline(x, y).c[1]`` (length n - 1).
        A length-n array that also holds S''(x[-1]) / 2 is accepted too.
        When given, GSL's own coefficient solve is bypassed and the result
        reproduces the scipy spline whatever its `bc_type`.  When None, GSL
        solves for its own cubic spline from (x, y).

    Returns
    -------
    gsl_spline *
        Newly allocated spline; the caller must release it with
        `gsl_spline_free`.

    Raises
    ------
    ValueError
        If the sizes are inconsistent, x is not strictly increasing, or
        GSL's initialisation reports an error.
    MemoryError
        If `gsl_spline_alloc` returns NULL.
    '''
    cdef:
        Py_ssize_t i, k, nc = 0, n = x.shape[0]
        int status
        double h0, h1, slope
        double *cc
        gsl_spline *spline
        gsl_spline_t *spl

    # Validate before allocating so that no error path can leak the spline.
    # A plain `assert` would be stripped under -O, and without an `except`
    # clause the exception could not stop this function anyway.
    if n < 3:
        with gil:
            raise ValueError('a cubic spline needs at least 3 knots, got %d' % n)
    if y.shape[0] != n:
        with gil:
            raise ValueError('x and y must have the same length')
    for i in range(n - 1):
        if not (x[i] < x[i + 1]):  # written this way so NaN is rejected too
            with gil:
                raise ValueError('x must be strictly increasing')
    if c is not None:
        nc = c.shape[0]
        if nc != n - 1 and nc != n:
            with gil:
                raise ValueError('c must have length len(x) - 1 or len(x)')

    spline = gsl_spline_alloc(gsl_interp_cspline, n)
    if spline == NULL:
        with gil:
            raise MemoryError('gsl_spline_alloc failed')
    # Cast needed because cython_gsl does not expose gsl_spline's members.
    spl = <gsl_spline_t *> spline

    # Element-wise copy: correct for strided memoryviews, where a memcpy
    # starting at &x[0] would read the wrong elements.
    for i in range(n):
        spl.x[i] = x[i]
        spl.y[i] = y[i]

    if c is None:
        # This is what gsl_spline_init does after its memcpy: GSL sets
        # xmin/xmax and solves for the coefficients from the copied knots.
        status = gsl_interp_init(<gsl_interp *> spl.interp, spl.x, spl.y, n)
        if status != 0:  # 0 == GSL_SUCCESS
            gsl_spline_free(spline)
            with gil:
                raise ValueError('gsl_interp_init failed with GSL status %d' % status)
        return spline

    spl.interp.xmin = x[0]
    spl.interp.xmax = x[n - 1]
    cc = spl.interp.state.c
    for i in range(nc):
        cc[i] = c[i]
    if nc == n - 1:
        # scipy stores no coefficient for the last knot, but GSL needs
        # cc[n-1] = S''(x[n-1]) / 2 to build the last interval.  Setting it to
        # 0 is only right for natural splines.  Instead, choose the value that
        # makes GSL's last-interval slope at x[n-2] equal the slope S'(x[n-2])
        # given by the previous interval.  GSL's coefficient formulas are:
        #   S'(right end of interval i) = dy/h + h * (2*c[i+1] + c[i]) / 3
        #   b (slope at left end of i)  = dy/h - h * (c[i+1] + 2*c[i]) / 3
        k = n - 1
        h0 = spl.x[k - 1] - spl.x[k - 2]
        h1 = spl.x[k] - spl.x[k - 1]
        slope = (spl.y[k - 1] - spl.y[k - 2]) / h0 + h0 * (2 * cc[k - 1] + cc[k - 2]) / 3
        cc[k] = 3 * ((spl.y[k] - spl.y[k - 1]) / h1 - slope) / h1 - 2 * cc[k - 1]
    return spline
cpdef ndarray gsl_cspline(double[:] xi, double[:] x, double[:] y, double[:] c=None):
    '''
    Evaluate a cubic spline through the knots (x, y) at the points xi using GSL.

    Parameters
    ----------
    xi : double[:]
        Evaluation points.
    x, y : double[:]
        Spline knots: same length, x strictly increasing.
    c : double[:], optional
        Quadratic coefficients of a scipy spline, ``f.c[1]``; when given the
        GSL spline is built from them.  When None, GSL builds its own cubic
        spline from (x, y).

    Returns
    -------
    ndarray
        float64 array of length ``len(xi)``.  Points outside [x[0], x[-1]]
        and NaN points give NaN.  They are not passed to GSL because
        `gsl_spline_eval` signals a domain error for them, and GSL's default
        error handler aborts the whole process.

    Example
    -------
    from scipy.interpolate import CubicSpline
    x = np.linspace(0, 10, 100)
    y = np.sin(x)
    f = CubicSpline(x, y, bc_type='not-a-knot')
    # f = CubicSpline(x, y, bc_type='natural')
    xi = np.linspace(0, 10, 71)
    y0 = f(xi)
    y1 = pyx.gsl_cspline(xi, x, y, f.c[1])
    y2 = pyx.gsl_cspline(xi, x, y)
    print(abs(y1 - y0).max(), abs(y2 - y0).max(), abs(y1 - y2).max())
    '''
    cdef:
        Py_ssize_t i, n = xi.shape[0]
        double lo, hi, xv
        ndarray out = np.empty(n, dtype='f8')  # every element is written below
        double[:] yi = out
        gsl_spline *spline = cspline_sp2gsl(x, y, c)
    if spline == NULL:
        raise MemoryError('could not build the GSL spline')
    try:
        lo = x[0]
        hi = x[x.shape[0] - 1]
        for i in range(n):
            xv = xi[i]
            # Same bounds test as GSL's evaluator; NaN fails it as well.
            if lo <= xv <= hi:
                yi[i] = gsl_spline_eval(spline, xv, NULL)
            else:
                yi[i] = NAN
    finally:
        # Free the GSL spline even if indexing above raises.
        gsl_spline_free(spline)
    return out
# Discussion of this code takes place in the comments on the GitHub gist
# (a free GitHub account is required to comment).