Skip to content

Instantly share code, notes, and snippets.

@max-nova
Created July 27, 2018 21:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save max-nova/056572be1a0126beefc0633793a560e2 to your computer and use it in GitHub Desktop.
Save max-nova/056572be1a0126beefc0633793a560e2 to your computer and use it in GitHub Desktop.
Cardano's solution for a depressed cubic
import numpy as np
from numpy.lib.scimath import sqrt as csqrt
def get_roots(A, B, C):
    """
    Return all three roots of the depressed cubic A(x^3) + Bx + C.

    A, B and C may be real scalars or broadcast-compatible real arrays; A
    must be non-zero (ints are fine: `/` is true division). The result is a
    complex array with one root per row along axis 0.

    Cardano's formula (https://en.wikipedia.org/wiki/Cubic_function) is used
    wherever the discriminant term q^2/4 + p^3/27 is non-negative. Where it
    is negative the cubic has three distinct real roots and Cardano's
    formula would need cube roots of complex numbers, which np.cbrt does
    not support, so the trigonometric form of the solution is used instead.
    Every root from the trigonometric branch, and the first root (k = 0)
    from Cardano's branch, has an imaginary part of exactly 0.
    """
    # First, divide out A to get a cubic in the form t^3 + p*t + q = 0.
    p = np.asarray(B / A)
    q = np.asarray(C / A)
    disc = q**2 / 4 + p**3 / 27
    three_real = disc < 0

    # Cardano branch (disc >= 0): both radicands are real, so the real cube
    # roots returned by np.cbrt are correctly paired (term_1 * term_2 == -p/3).
    xi = (-0.5 + 0.5j * 3**0.5)  # primitive cube root of unity
    inner_sqrt = np.sqrt(np.where(three_real, 0.0, disc))
    neg_half_q = -0.5 * q
    term_1 = np.cbrt(neg_half_q + inner_sqrt)
    term_2 = np.cbrt(neg_half_q - inner_sqrt)
    cardano = [xi**k * term_1 + xi**(2*k) * term_2 for k in range(3)]

    # Trigonometric branch (disc < 0, which implies p < 0): three distinct
    # real roots. p is replaced by -1 in lanes outside this branch purely so
    # those unused lanes do not divide by zero or take sqrt of a negative.
    safe_p = np.where(three_real, p, -1.0)
    amplitude = 2 * np.sqrt(-safe_p / 3)
    # clip guards arccos against rounding slightly outside [-1, 1]
    cos_arg = np.clip(1.5 * q / safe_p * np.sqrt(-3 / safe_p), -1.0, 1.0)
    phi = np.arccos(cos_arg) / 3
    trig = [amplitude * np.cos(phi - 2 * np.pi * k / 3) for k in range(3)]

    return np.array([np.where(three_real, t, c) for c, t in zip(cardano, trig)])
def get_closest_non_negative_root(A, B, C, current_val):
    """
    Return the real root of A(x^3) + Bx + C nearest to current_val, after
    clamping negative roots to 0.

    The roots come from get_roots, stacked along axis 0, and the reduction
    is done over that axis. Array-valued coefficients (and current_val) are
    therefore handled element-wise.

    A root counts as real when its imaginary part is smaller than 100
    machine epsilons, the default tolerance of np.real_if_close. Other roots
    are ignored. If two candidates are equally close, the larger one is
    returned.
    """
    roots = np.asarray(get_roots(A, B, C))
    # Decide realness per root rather than with np.real_if_close, which only
    # strips imaginary parts when *every* entry of the whole array is close
    # to real. With it, a nearly-real root could be discarded because of an
    # unrelated complex root (or another element of a vectorised call).
    imag_tol = 100 * np.finfo(float).eps
    is_real = np.abs(np.imag(roots)) < imag_tol
    candidates = np.where(is_real, np.real(roots), np.nan)
    # Negative roots are clamped to 0, so 0 becomes the candidate instead.
    candidates = np.where(candidates < 0, 0.0, candidates)
    dist = np.abs(candidates - current_val)
    min_dist = np.nanmin(dist, axis=0)
    # Null out everything except the closest candidate(s).
    closest = np.where(dist == min_dist, candidates, np.nan)
    # Normally one candidate is left; on a tie, take the larger.
    return np.nanmax(closest, axis=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment