Skip to content

Instantly share code, notes, and snippets.

@syrte
Created January 21, 2022 18:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save syrte/1ee14d18d520bc4b89126e70bfc10c3a to your computer and use it in GitHub Desktop.
Save syrte/1ee14d18d520bc4b89126e70bfc10c3a to your computer and use it in GitHub Desktop.
class CubicSplineExtrap(CubicSpline):
    """Cubic spline with a configurable extrapolation rule on each side.

    Inside ``[x[0], x[-1]]`` this is the plain
    :class:`scipy.interpolate.CubicSpline` through the data.  Outside, each
    side is replaced or kept according to ``extrapolate``.

    Parameters
    ----------
    x, y : array_like
        Data points, forwarded to ``CubicSpline``.  ``y`` may have extra
        dimensions; interpolation runs along ``axis``.
    bc_type : str or 2-tuple, optional
        Boundary condition forwarded to ``CubicSpline``.
    extrapolate : option or 2-tuple of options, optional
        A single option applies to both sides; a tuple is ``(left, right)``.
        Options:

        - ``'linear'`` (default): straight line matching the spline's
          value and slope at the end point.
        - ``'const'``: hold the end-point data value.
        - a float: constant equal to that number.
        - ``'cubic'`` or ``True``: keep the spline's own cubic extrapolation.
        - ``False`` (scalar only): no extrapolation (NaN outside the range).
        - ``'periodic'`` (scalar only): periodic extrapolation, as in
          ``CubicSpline``.

        Inside a tuple ``False`` is not special: it is passed to ``float``
        and therefore yields the constant 0.0.
    axis : int, optional
        Axis of ``y`` along which to interpolate (forwarded to
        ``CubicSpline``).  Default 0.

    Raises
    ------
    ValueError
        If an extrapolation option is not recognised.

    Notes
    -----
    Each extended side is stored as an extra polynomial piece on a repeated
    (zero-width) breakpoint at that end of the data, so evaluation and
    derivatives use the inherited ``PPoly`` machinery.  Evaluating exactly at
    ``x[-1]`` may use the right-hand extension piece: the value matches the
    data, but derivatives there are one-sided.

    Example
    -------
    ::

        import numpy as np
        import matplotlib.pyplot as plt
        from scipy.interpolate import PchipInterpolator, CubicSpline

        x = np.linspace(-0.7, 1, 11)
        a = np.linspace(-1.5, 2, 100)
        y = np.sin(x * np.pi)
        f0 = CubicSplineExtrap(x, y, extrapolate=('linear', 'const'))
        f1 = CubicSpline(x, y)
        f2 = PchipInterpolator(x, y)
        plt.figure(figsize=(8, 4))
        plt.subplot(121)
        plt.scatter(x, y)
        for i, f in enumerate([f0, f1, f2]):
            plt.plot(a, f(a), ls=['-', '--', ':'][i])
        plt.ylim(-2, 2)
        plt.subplot(122)
        for i, f in enumerate([f0, f1, f2]):
            plt.plot(a, f(a, nu=1) / np.pi, ls=['-', '--', ':'][i])
        plt.ylim(-2, 2)
    """

    def __init__(self, x, y, bc_type='not-a-knot', extrapolate='linear',
                 axis=0):
        """Fit the spline, then patch the requested extrapolation pieces."""
        handled_natively = extrapolate is False or (
            isinstance(extrapolate, str) and extrapolate == 'periodic')
        if handled_natively:
            # CubicSpline implements these modes itself; nothing to patch.
            super().__init__(x, y, axis=axis, bc_type=bc_type,
                             extrapolate=extrapolate)
            return

        # Fit with native cubic extrapolation; requested sides are then
        # overridden by prepending/appending extra polynomial pieces.
        super().__init__(x, y, axis=axis, bc_type=bc_type, extrapolate=True)
        if np.isscalar(extrapolate):
            extrapolate = (extrapolate, extrapolate)

        # End values come from the data itself so that 'const'/'linear'
        # reproduce y[0] and y[-1] exactly (works for N-D y and any axis).
        values = np.asarray(y)
        ends = ((self.x[0], np.take(values, 0, axis=axis)),
                (self.x[-1], np.take(values, -1, axis=axis)))
        # All pieces are built before self.x / self.c change, so 'linear'
        # slopes are taken from the unmodified spline.
        pieces = [self._extension_piece(ext, xb, yb)
                  for ext, (xb, yb) in zip(extrapolate[:2], ends)]
        left, right = (pieces + [None, None])[:2]

        xs, cs = [self.x], [self.c]
        if left is not None:
            xs.insert(0, self.x[:1])   # repeated x[0]: zero-width piece
            cs.insert(0, left)
        if right is not None:
            xs.append(self.x[-1:])     # repeated x[-1]: zero-width piece
            cs.append(right)
        if len(xs) > 1:
            self.x = np.concatenate(xs)
            self.c = np.concatenate(cs, axis=1)

    def _extension_piece(self, ext, xb, yb):
        """Return the coefficient block for one extrapolated side.

        The block has shape ``(4, 1) + self.c.shape[2:]`` and describes the
        polynomial used beyond end point ``xb`` (data value ``yb``).  Returns
        None when the native cubic extrapolation should be kept.
        """
        if isinstance(ext, (bool, np.bool_)) and ext:
            return None  # True means "extrapolate as CubicSpline does"
        if isinstance(ext, str) and ext == 'cubic':
            return None

        piece = np.zeros((4, 1) + self.c.shape[2:], dtype=self.c.dtype)
        # Row 3 holds the constant term, row 2 the linear term in (t - xb).
        if isinstance(ext, str) and ext in ('const', 'linear'):
            piece[3, 0] = yb
            if ext == 'linear':
                piece[2, 0] = self(xb, nu=1)
            return piece

        try:
            piece[3, 0] = float(ext)
        except ValueError as err:
            raise ValueError(
                f"invalid extrapolate option {ext!r}; expected a float, "
                f"True, 'const', 'linear' or 'cubic' per side, or False / "
                f"'periodic' as a scalar") from err
        return piece
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment