Skip to content

Instantly share code, notes, and snippets.

@madmann91
Last active May 15, 2024 13:06
Show Gist options
  • Save madmann91/2ae76df7b49cdd5a11c0f94c1aeb9ee6 to your computer and use it in GitHub Desktop.
Save madmann91/2ae76df7b49cdd5a11c0f94c1aeb9ee6 to your computer and use it in GitHub Desktop.
Solver for quartic and cubic polynomials
#include <stdio.h>
#include <tgmath.h>
#include <stdint.h>
#include <stdbool.h>
#include <assert.h>
#include <string.h>
// Pi rounded to float precision; used for the 2*pi/3 angle offsets of the trigonometric
// branch of solve_cubic().
#define PI (3.141592653589793f)
// Starting half-width of the bracket that local_search() grows (by doubling) around its
// initial estimate.
#define SEARCH_RADIUS (1.e-5f)
/* Square of x. */
static inline float sq(float x) {
    const float squared = x * x;
    return squared;
}
/* Cube of x, evaluated as (x * x) * x. */
static inline float cb(float x) {
    const float x2 = x * x;
    return x2 * x;
}
static inline int solve_linear(float a0, float* z) {
    /* Root of the monic linear polynomial x + a0: there is always exactly one, at -a0.
     * It is stored in z[0] and the root count (1) is returned. */
    *z = -a0;
    return 1;
}
static inline int solve_quadratic(float a1, float a0, float* z) {
// Solves x^2 + a1 * x + a0 = 0
const float d = sq(a1) - 4.f * a0;
if (d < 0.f)
return 0;
const float s = sqrt(d);
z[0] = 0.5f * (-a1 + s);
z[1] = 0.5f * (-a1 - s);
return d == 0.f ? 1 : 2;
}
static inline int solve_cubic(float a2, float a1, float a0, float* z) {
    // Solves x^3 + a2 * x^2 + a1 * x + a0 = 0 for real x.
    // Inspired from "Practical Algorithm for Solving the Cubic Equation", D. J. Wolters, 2021
    //
    // Writes either one root (z[0]) or three roots (z[0..2]) and returns that count, so z
    // must have room for 3 floats. In the three-root case repeated roots are listed once per
    // multiplicity (e.g. a double root appears twice).
    const float q = (3.f * a1 - sq(a2)) / 9.f;
    const float r = (9.f * a1 * a2 - 27.f * a0 - 2.f * cb(a2)) / 54.f;
    const float disc = sq(r) + cb(q);
    if (disc > 0.f) {
        // Single real root. a > 0 because |r| + sqrt(disc) > 0 in this branch.
        const float a = cbrt(fabs(r) + sqrt(disc));
        const float t = a - q / a;
        const float t1 = r < 0 ? -t : t;
        z[0] = t1 - a2 / 3.f;
        return 1;
    }
    // Three real roots (trigonometric form). disc <= 0 forces q <= 0 in exact arithmetic;
    // the fmax guards keep the square roots real if a tiny positive q had its cube
    // underflow to zero.
    const float s = sqrt(fmax(cb(-q), 0.f));
    // |r / s| <= 1 mathematically, but near repeated roots rounding can push the ratio
    // slightly past +-1, where acos returns NaN and all three roots would become NaN, so
    // clamp it. When s == 0 (q == 0, or -q^3 underflowed) k1 below is (nearly) 0 and all
    // roots collapse to -a2 / 3, so theta does not matter; this also avoids dividing by 0.
    const float theta = s > 0.f ? acos(fmin(fmax(r / s, -1.f), 1.f)) : 0.f;
    const float phi1 = theta / 3.f;
    const float phi2 = phi1 - 2.f * PI / 3.f;
    const float phi3 = phi1 + 2.f * PI / 3.f;
    const float k1 = 2.f * sqrt(fmax(-q, 0.f));
    const float k2 = a2 / 3.f;
    z[0] = k1 * cos(phi1) - k2;
    z[1] = k1 * cos(phi2) - k2;
    z[2] = k1 * cos(phi3) - k2;
    return 3;
}
static inline float find_largest_cubic_root(float a2, float a1, float a0) {
    // Returns the largest real root of x^3 + a2 * x^2 + a1 * x + a0, clamped below at 0:
    // the running maximum starts at 0, so if every root is negative the result is 0.
    float roots[3];
    const int count = solve_cubic(a2, a1, a0, roots);
    float largest = 0.f;
    for (int i = 0; i < count; ++i)
        largest = fmax(largest, roots[i]);
    return largest;
}
static inline int solve_quartic(float a3, float a2, float a1, float a0, float* z) {
    // Solves x^4 + a3 * x^3 + a2 * x^2 + a1 * x + a0 = 0 for real x.
    // Inspired from "Practical Algorithms for Solving the Quartic Equation", D. J. Wolters, 2020
    //
    // Writes 0, 2 or 4 real roots to z[0..] (z must hold 4 floats) and returns how many were
    // written. Roots come out in pairs, so a double root is reported twice.
    //
    // The substitution x = y - c with c = a3 / 4 gives the depressed quartic
    // y^4 + b2 * y^2 + b1 * y + b0 = 0.
    const float c = a3 * 0.25f;
    const float b2 = a2 - 6.f * sq(c);
    const float b1 = a1 + c * (-2.f * a2 + 8.f * sq(c));
    const float b0 = a0 + c * (-a1 + c * (a2 - 3.f * sq(c)));
    // m: root of the resolvent cubic m^3 + b2 m^2 + (b2^2/4 - b0) m - b1^2/8. Its value at
    // m = 0 is -b1^2/8 <= 0, so a root >= 0 always exists; find_largest_cubic_root also clamps
    // at 0, which keeps sqrt(m / 2) below real.
    const float m = find_largest_cubic_root(b2, sq(b2) * 0.25f - b0, sq(b1) * -0.125f);
    // In exact arithmetic m * (m^2 + b2 m + b2^2/4 - b0) = b1^2 / 8 >= 0, so this radicand
    // is non-negative. In float it can round slightly below 0 (e.g. biquadratics, b1 == 0),
    // and sqrt would then return NaN, making every root disappear. Clamp it.
    const float r1 = sqrt(fmax(sq(m) + b2 * m + sq(b2) * 0.25f - b0, 0.f));
    const float r = b1 > 0.f ? r1 : -r1;
    const float l = sqrt(m * 0.5f);
    const float k = m * -0.5f - b2 * 0.5f;
    int n = 0;
    // First quadratic factor: y = l +- sqrt(k - r).
    if (k - r >= 0.f) {
        const float k12 = sqrt(k - r);
        z[n++] = l - c + k12;
        z[n++] = l - c - k12;
    }
    // Second quadratic factor: y = -l +- sqrt(k + r).
    if (k + r >= 0.f) {
        const float k34 = sqrt(k + r);
        z[n++] = -l - c + k34;
        z[n++] = -l - c - k34;
    }
    return n;
}
float eval(float a4, float a3, float a2, float a1, float a0, float x) {
    // Evaluates a4 * x^4 + a3 * x^3 + a2 * x^2 + a1 * x + a0 with Horner's scheme,
    // folding the coefficients from the highest degree down.
    const float lower[4] = { a3, a2, a1, a0 };
    float acc = a4;
    for (int i = 0; i < 4; ++i)
        acc = lower[i] + x * acc;
    return acc;
}
float eval_diff(float a4, float a3, float a2, float a1, float x) {
    // Derivative of a4 * x^4 + a3 * x^3 + a2 * x^2 + a1 * x + a0 at x, i.e.
    // 4 a4 x^3 + 3 a3 x^2 + 2 a2 x + a1, evaluated with Horner's scheme (a0 drops out).
    float acc = 4.f * a4;
    acc = 3.f * a3 + x * acc;
    acc = 2.f * a2 + x * acc;
    return a1 + x * acc;
}
float local_search(float a4, float a3, float a2, float a1, float a0, float x, size_t max_iters) {
    // Finds a zero of q(x) = a4 * x^4 + a3 * x^3 + a2 * x^2 + a1 * x^1 + a0 using a local search
    // starting at the given point. This can be used to improve the quality of an initial estimate.
    //
    // Returns the point with the smallest |q| seen during the search. If no sign change of q
    // can be bracketed around x (e.g. an even-multiplicity root, or no nearby real root), the
    // best point sampled during the bracket expansion is returned instead of looping forever.
    float qx = eval(a4, a3, a2, a1, a0, x);
    if (qx == 0.f)
        return x; // Already an exact zero; nothing to improve.

    // Find initial bracket around x such that the signs of q(a) and q(b) are different. This
    // interval is guaranteed to contain a zero for q(x). The radius doubles every round; the
    // loop ends once it overflows to infinity or a sample becomes NaN, so it always terminates.
    float a = x, b = x, qa = qx, qb = qx;
    bool bracketed = false;
    for (float radius = SEARCH_RADIUS; isfinite(radius); radius *= 2.f) {
        a = x - radius;
        b = x + radius;
        qa = eval(a4, a3, a2, a1, a0, a);
        qb = eval(a4, a3, a2, a1, a0, b);
        // Move x to the point closest to 0 (NaN samples never compare smaller).
        if (fabs(qa) < fabs(qx))
            x = a, qx = qa;
        if (fabs(qb) < fabs(qx))
            x = b, qx = qb;
        if (isnan(qa) || isnan(qb))
            break; // Sign of NaN is meaningless; give up on bracketing.
        if (signbit(qa) != signbit(qb)) {
            bracketed = true;
            break;
        }
    }
    if (!bracketed)
        return x;

    // Tighten the bracket around 0: Use [x, b] or [a, x] instead of [a, b].
    if (signbit(qa) == signbit(qx))
        a = x, qa = qx;
    else
        b = x, qb = qx;

    // Apply several rounds of bisection or Newton-Raphson, whichever is best
    for (size_t iters = 0; iters < max_iters; ++iters) {
        float m = fabs(qa) < fabs(qb)
            ? a - qa / eval_diff(a4, a3, a2, a1, a)
            : b - qb / eval_diff(a4, a3, a2, a1, b);
        // Use bisection if Newton-Raphson takes us outside the interval. Written in negated
        // form so that a NaN step (0 / 0 when the derivative vanishes at a zero endpoint) also
        // falls back to bisection instead of poisoning the bracket.
        if (!(m >= a && m <= b))
            m = 0.5f * (a + b);
        const float qm = eval(a4, a3, a2, a1, a0, m);
        if (signbit(qm) == signbit(qa))
            a = m, qa = qm;
        else
            b = m, qb = qm;
        // Pick the value closest to 0
        if (fabs(qm) < fabs(qx))
            x = m, qx = qm;
        if (qx == 0.f)
            break; // Exact zero: no later iterate can do better.
    }
    return x;
}
int solve(float a4, float a3, float a2, float a1, float a0, float* z) {
    // Finds the real roots of the polynomial a4 * x^4 + a3 * x^3 + a2 * x^2 + a1 * x + a0.
    // Any coefficient may be zero: the highest nonzero coefficient among a4..a1 fixes the
    // degree, the polynomial is divided by it (made monic), and the matching solver is called.
    // Roots are written to z (which needs room for 4 floats) and their count is returned.
    // If a4..a1 are all zero, no roots are reported (even when a0 == 0).
    if (a4 != 0.f) {
        const float scale = 1.f / a4;
        return solve_quartic(a3 * scale, a2 * scale, a1 * scale, a0 * scale, z);
    }
    if (a3 != 0.f) {
        const float scale = 1.f / a3;
        return solve_cubic(a2 * scale, a1 * scale, a0 * scale, z);
    }
    if (a2 != 0.f) {
        const float scale = 1.f / a2;
        return solve_quadratic(a1 * scale, a0 * scale, z);
    }
    if (a1 != 0.f)
        return solve_linear(a0 / a1, z);
    return 0;
}
static inline float next_ulp(float x) {
uint32_t y;
memcpy(&y, &x, sizeof(y));
y++;
memcpy(&x, &y, sizeof(x));
return x;
}
static inline float prev_ulp(float x) {
uint32_t y;
memcpy(&y, &x, sizeof(y));
y--;
memcpy(&x, &y, sizeof(x));
return x;
}
static inline complex float eval_complex(complex float a3, complex float a2, complex float a1, complex float a0, complex float x) {
    // Evaluates the monic quartic x^4 + a3 x^3 + a2 x^2 + a1 x + a0 at complex x using
    // Horner's scheme (the implicit leading coefficient is 1).
    complex float acc = a3 + x;
    acc = a2 + x * acc;
    acc = a1 + x * acc;
    return a0 + x * acc;
}
void solve_complex(complex float p[], complex float z[], size_t n, size_t iters) {
    // Approximates all n complex roots of sum(p[i] * x^i, i = 0..n) with the Aberth-Ehrlich
    // method. p holds n + 1 coefficients in ascending degree order (p[n] is the leading one and
    // should be nonzero for n roots to exist); z receives the n estimates after `iters` sweeps.
    // The initial guesses are b^0, b^1, ..., b^(n-1); since |b| < 1 their moduli strictly
    // decrease, so they are pairwise distinct, as the repulsion term below requires.
    static const complex float b = 0.4f + 0.9f * I;
    assert(n > 0);
    z[0] = 1.f;
    for (size_t i = 1; i < n; ++i)
        z[i] = z[i - 1] * b;
    for (size_t it = 0; it < iters; ++it) {
        for (size_t j = 0; j < n; ++j) {
            const complex float x = z[j];
            // Horner evaluation of p (into y) and of p' (into d) at x.
            complex float y = p[n];
            complex float d = p[n] * (float)n;
            for (size_t k = n - 1; k > 0; --k) {
                y = y * x + p[k];
                d = d * x + p[k] * (float)k;
            }
            y = y * x + p[0];
            // x is already an exact root. p'(x) / p(x) would divide by zero and, depending
            // on how complex division by zero is implemented, turn z[j] into inf/NaN, which
            // would then spread to every other estimate through the repulsion term. Keep it.
            if (y == 0.f)
                continue;
            complex float denom = d / y;
            for (size_t k = 0; k < n; ++k) {
                if (k != j)
                    denom -= 1.f / (x - z[k]);
            }
            if (denom == 0.f)
                continue; // The Aberth correction 1 / denom would be infinite.
            z[j] -= 1.f / denom;
        }
    }
}
int main(void) {
    // Demo: solves x^4 + x^3 - 1000 x^2 + 4 x - 40 = 0 two ways and prints the residuals.
    //   1. All four complex roots via solve_complex (8 sweeps).
    //   2. The real roots via solve(), each refined with 4 rounds of local_search, followed by
    //      a check that neither neighbouring float (one ulp away) has a smaller |p(x)|.
    // COEFFS lists a3, a2, a1, a0 of the monic quartic (a4 = 1).
#define COEFFS 1.f, -1000.f, 4.f, -40.f
    // `= {}` is only valid from C23 on (or as a GNU extension); `{0}` zero-fills in C99/C11.
    complex float zc[4] = {0};
    // Coefficients from the highest degree down; reversed below because solve_complex
    // expects pc[i] to be the coefficient of x^i.
    complex float pc[5] = { 1.f, COEFFS };
    const size_t pc_len = sizeof(pc) / sizeof(pc[0]);
    for (size_t i = 0; i < pc_len / 2; ++i) {
        const complex float tmp = pc[i];
        pc[i] = pc[pc_len - i - 1];
        pc[pc_len - i - 1] = tmp;
    }
    solve_complex(pc, zc, 4, 8);
    for (size_t i = 0; i < 4; ++i) {
        const complex float y = eval_complex(COEFFS, zc[i]);
        printf("p(%f + %fi) = %f + %fi\n", creal(zc[i]), cimag(zc[i]), creal(y), cimag(y));
    }
    printf("----\n");
    float z[4] = {0};
    const int n = solve(1.f, COEFFS, z);
    for (int i = 0; i < n; ++i) {
        z[i] = local_search(1.f, COEFFS, z[i], 4);
        printf("p(%f) = %f\n", z[i], eval(1.f, COEFFS, z[i]));
    }
    printf("----\n");
    for (int i = 0; i < n; ++i) {
        const float ref = fabs(eval(1.f, COEFFS, z[i]));
        const float next = fabs(eval(1.f, COEFFS, next_ulp(z[i])));
        const float prev = fabs(eval(1.f, COEFFS, prev_ulp(z[i])));
        printf("|p(%f)| = %f", z[i], ref);
        if (next >= ref && prev >= ref) {
            printf(" IS OPTIMAL\n");
        } else {
            printf(" IS SUB-OPTIMAL:\n");
            if (prev < ref)
                printf("|p(%f)| = %f\n", prev_ulp(z[i]), prev);
            if (next < ref)
                printf("|p(%f)| = %f\n", next_ulp(z[i]), next);
        }
    }
#undef COEFFS
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment