Created July 27, 2018 21:36
-
-
Save max-nova/056572be1a0126beefc0633793a560e2 to your computer and use it in GitHub Desktop.
Cardano's solution for a depressed cubic
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numpy.lib.scimath import sqrt as csqrt | |
def get_roots(A, B, C):
    """
    Return all three roots of the depressed cubic A*x^3 + B*x + C = 0.

    Uses Cardano's formula: https://en.wikipedia.org/wiki/Cubic_function

    A, B and C may be scalars or broadcastable arrays; A must be non-zero.
    The result is a complex array of shape (3, *broadcast_shape), one root
    per entry along axis 0.  Real roots may carry a rounding-level imaginary
    part.
    """
    # first, divide out A to get a cubic in the form:
    # t^3 + p * t + q = 0
    p = B / A
    q = C / A
    # primitive cube root of unity, exp(2*pi*i/3)
    xi = -0.5 + 0.5j * 3 ** 0.5
    neg_half_q = -0.5 * q
    discriminant = (q ** 2 / 4) + (p ** 3 / 27)
    # scimath sqrt yields a complex result wherever discriminant < 0
    inner_sqrt = csqrt(discriminant)
    if np.iscomplexobj(inner_sqrt):
        # At least one element has three distinct real roots (casus
        # irreducibilis).  np.cbrt only accepts real input, so for those
        # elements take the principal complex cube root u of the radicand and
        # pair it with v = conj(u): the two radicands are complex conjugates,
        # and u * conj(u) = |radicand|^(2/3) = -p/3 as Cardano requires.
        three_real = discriminant < 0
        radicand = neg_half_q + inner_sqrt
        principal = np.abs(radicand) ** (1 / 3) * np.exp(1j * np.angle(radicand) / 3)
        # Elements with discriminant >= 0 keep the real cube roots.
        term_1 = np.where(three_real, principal, np.cbrt(np.real(radicand)))
        term_2 = np.where(
            three_real,
            np.conj(principal),
            np.cbrt(np.real(neg_half_q - inner_sqrt)),
        )
    else:
        # Real radicands: real cube roots already satisfy term_1*term_2 = -p/3.
        term_1 = np.cbrt(neg_half_q + inner_sqrt)
        term_2 = np.cbrt(neg_half_q - inner_sqrt)
    return np.array([xi ** k * term_1 + xi ** (2 * k) * term_2 for k in range(3)])
def get_closest_non_negative_root(A, B, C, current_val, rel_tol=1e-9):
    """
    Return the real root of A*x^3 + B*x + C = 0 nearest to current_val,
    with negative real roots clamped to 0.

    Works element-wise on scalars or broadcastable arrays: get_roots stacks
    the three roots along axis 0, and that axis is reduced here.

    Parameters
    ----------
    A, B, C : float or array_like
        Coefficients of the depressed cubic (A must be non-zero).
    current_val : float or array_like
        Reference value; the candidate closest to it is returned.
    rel_tol : float, optional
        A root counts as real when |imag| <= rel_tol * max(1, |root|).
        This replaces np.real_if_close, which only snaps imaginary parts
        when every entry of the whole array is near-real, and then only
        against an absolute tolerance.

    Returns
    -------
    float or ndarray
        The closest candidate; when several candidates tie, the largest.
        NaN where no root is classified as real.
    """
    roots = get_roots(A, B, C)
    # Decide realness per root, relative to its magnitude, so that a genuinely
    # complex root elsewhere in a vectorised call (or a large root carrying
    # rounding-level imaginary noise) does not disqualify the real roots.
    scale = np.maximum(1.0, np.abs(roots))
    is_real = np.abs(np.imag(roots)) <= rel_tol * scale
    candidates = np.where(is_real, np.real(roots), np.nan)
    # Negative real roots are replaced with 0, so 0 itself becomes a candidate.
    candidates = np.where(candidates < 0, 0.0, candidates)
    dist = np.abs(candidates - current_val)
    min_dist = np.nanmin(dist, axis=0)
    # Keep only candidates at the minimum distance; NaN entries never compare equal.
    closest = np.where(dist == min_dist, candidates, np.nan)
    # Normally a single candidate remains; take the max to break ties.
    return np.nanmax(closest, axis=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.