Skip to content

Instantly share code, notes, and snippets.

@zachjweiner
Created March 4, 2023 19:38
Show Gist options
  • Save zachjweiner/3e62a7392b3e3981a858a752e0595712 to your computer and use it in GitHub Desktop.
Save zachjweiner/3e62a7392b3e3981a858a752e0595712 to your computer and use it in GitHub Desktop.
Example implementation of callbacks to control step acceptance in a scipy.integrate.solve_ivp method
# this file contains code adapted from SciPy, namely, from
# https://github.com/scipy/scipy/blob/main/scipy/integrate/_ivp/rk.py
# whose copyright and license is specified in
# https://github.com/scipy/scipy/blob/main/LICENSE.txt
import numpy as np
from scipy.integrate._ivp.rk import rk_step, SAFETY, MIN_FACTOR, MAX_FACTOR
from scipy.integrate._ivp.rk import DOP853 as DOP853_
class DOP853(DOP853_):
    """SciPy's ``DOP853`` integrator with an optional step-acceptance callback.

    Takes every constructor argument of the parent ``DOP853`` plus one extra
    keyword argument:

    callback : callable or None, optional
        Called as ``callback(t_new, y_new, f_new)`` for each trial step that
        has already passed the embedded error test. ``t_new`` is exactly the
        time the solver will move to if the step is accepted, ``y_new`` the
        proposed solution there and ``f_new`` the right-hand side evaluated
        at ``(t_new, y_new)`` by ``rk_step``.

        A truthy return value accepts the step. A falsy return value rejects
        it, and the step is retried from the same ``t`` with the same step
        size. The callback is therefore expected to change whatever state
        makes the retry differ. A callback that never returns a truthy
        value keeps the step loop retrying indefinitely.

        ``None`` (the default) leaves the stepping logic identical to the
        SciPy implementation this was adapted from.
    """

    def __init__(self, *args, **kwargs):
        # Remove the extra option before the parent constructor sees it, so
        # the parent never receives it as an unrecognized option.
        self.callback = kwargs.pop("callback", None)
        super().__init__(*args, **kwargs)

    def _step_impl(self):
        """Attempt a single integration step.

        Follows SciPy's ``RungeKutta._step_impl``. The ``callback`` veto is
        applied only to trial steps that pass the error test.

        Returns
        -------
        (True, None)
            On an accepted step, after updating ``t``, ``y``, ``f``,
            ``h_abs``, ``h_previous`` and ``y_old``.
        (False, self.TOO_SMALL_STEP)
            If the step size drops below ``min_step``.
        """
        t = self.t
        y = self.y

        max_step = self.max_step
        rtol = self.rtol
        atol = self.atol

        # Ten times the spacing between t and the next float in the
        # integration direction: anything smaller barely moves t.
        min_step = 10 * np.abs(np.nextafter(t, self.direction * np.inf) - t)

        if self.h_abs > max_step:
            h_abs = max_step
        elif self.h_abs < min_step:
            h_abs = min_step
        else:
            h_abs = self.h_abs

        step_accepted = False
        step_rejected = False

        while not step_accepted:
            if h_abs < min_step:
                return False, self.TOO_SMALL_STEP

            h = h_abs * self.direction
            t_new = t + h

            # Never step past the integration bound.
            if self.direction * (t_new - self.t_bound) > 0:
                t_new = self.t_bound

            h = t_new - t
            h_abs = np.abs(h)

            y_new, f_new = rk_step(self.fun, t, y, self.f, h, self.A,
                                   self.B, self.C, self.K)
            scale = atol + np.maximum(np.abs(y), np.abs(y_new)) * rtol
            error_norm = self._estimate_error_norm(self.K, h, scale)

            # Written as ``not < 1`` so that a NaN error norm is a failure.
            if not error_norm < 1:
                # Error test failed: shrink the step as SciPy does. The
                # callback is not shown this inaccurate trial state.
                h_abs *= max(MIN_FACTOR,
                             SAFETY * error_norm ** self.error_exponent)
                step_rejected = True
                continue

            callback = self.callback
            # Pass t_new (the time stored on acceptance), not t + h, which
            # can differ from it by rounding after clipping to t_bound.
            if callback is not None and not callback(t_new, y_new, f_new):
                # Vetoed an otherwise acceptable step: retry with the same
                # step size, relying on the callback to have made any
                # required changes.
                continue

            if error_norm == 0:
                factor = MAX_FACTOR
            else:
                factor = min(MAX_FACTOR,
                             SAFETY * error_norm ** self.error_exponent)

            # After an error rejection in this call, do not grow the step.
            if step_rejected:
                factor = min(1, factor)

            h_abs *= factor
            step_accepted = True

        self.h_previous = h
        self.y_old = y

        self.t = t_new
        self.y = y_new

        self.h_abs = h_abs
        self.f = f_new

        return True, None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment