Skip to content

Instantly share code, notes, and snippets.

@maxiimilian
Last active May 27, 2024 13:27
Show Gist options
  • Save maxiimilian/67113eb1d60a5d8ceca212fbcad100c9 to your computer and use it in GitHub Desktop.
Save maxiimilian/67113eb1d60a5d8ceca212fbcad100c9 to your computer and use it in GitHub Desktop.
Function that finds all roots of a function within a given interval (bracket), based on scipy's `root_scalar`. The sampling resolution `n` only needs to be high enough that every root produces a sign change between two adjacent sample points; `root_scalar` then refines each root precisely.
import warnings
from typing import Callable, Iterable
import numpy as np
from scipy.optimize import root_scalar
def multi_root(f: Callable, bracket: Iterable[float], args: Iterable = (), n: int = 30) -> np.ndarray:
    """Find all roots of `f` in `bracket`, given that resolution `n` covers every sign change.

    `f` is first evaluated on `n` equidistant points of `bracket`. Samples where
    `f` is exactly zero are reported as roots directly. Every pair of adjacent
    samples with strictly opposite signs is refined with
    `scipy.optimize.root_scalar`.

    Parameters
    ----------
    f: Callable
        Function to be evaluated. It is called once as ``f(x, *args)`` with a
        NumPy array of sample points, and then with scalar points by
        `root_scalar`, so it must accept both.
    bracket: Sequence of two floats
        Interval ``(start, stop)`` within which roots are searched.
    args: Iterable, optional
        Extra positional arguments passed to `f`. They are converted to a tuple
        so that the vectorized evaluation and `root_scalar` receive them the
        same way.
    n: int
        Number of points sampled equidistantly from `bracket` to evaluate `f`.
        Resolution has to be high enough to cover the sign changes of all roots,
        but not finer than that. Roots where `f` touches zero without changing
        sign are only found if a sample lands exactly on them.

    Returns
    -------
    roots: np.ndarray
        Sorted array containing all unique roots that were found in `bracket`.

    Warns
    -----
    UserWarning
        If a root finder did not converge, or if the same root was found more
        than once.
    """
    # root_scalar wraps any non-tuple `args` into a 1-tuple, which would call
    # `f` differently from the vectorized evaluation below; normalize once.
    args = tuple(args)

    # Evaluate function on an equidistant grid over the bracket
    x = np.linspace(*bracket, n)
    y = f(x, *args)
    signs = np.sign(y)

    # Samples that hit a root exactly need no refinement. Reporting them
    # directly, instead of bracketing both neighbouring intervals (each of
    # which has a zero endpoint), avoids finding the same root twice.
    exact_roots = x[signs == 0]

    # Refine only intervals whose endpoints have strictly opposite signs.
    # NaN samples compare False here and are therefore skipped.
    sign_changes = np.flatnonzero(signs[:-1] * signs[1:] < 0)

    bracketed = []
    for s in sign_changes:
        result = root_scalar(f=f, args=args, bracket=(x[s], x[s + 1]))
        bracketed.append(result.root if result.converged else np.nan)
    bracketed_roots = np.array(bracketed, dtype=float)

    if np.any(np.isnan(bracketed_roots)):
        warnings.warn("Not all root finders converged for estimated brackets! Maybe increase resolution `n`.")
        bracketed_roots = bracketed_roots[~np.isnan(bracketed_roots)]

    roots = np.concatenate([exact_roots, bracketed_roots])
    roots_unique = np.unique(roots)
    if len(roots_unique) != len(roots):
        warnings.warn("One root was found multiple times. "
                      "Try to increase or decrease resolution `n` to see if this warning disappears.")
    return roots_unique
if __name__ == '__main__':
    def poly1(x):
        """Quartic with known roots at -4, -2, 1 and 5, used as a demo."""
        return (x+4)*(x+2)*(x-1)*(x-5)

    roots = multi_root(poly1, [-5, 6])
    # Show the result; previously it was computed and silently discarded.
    print(roots)
@defencedog
Copy link

Extremely thanks

@LeRobertDePoche
Copy link

Extremely thanks
+1 !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment