Instantly share code, notes, and snippets.

# dhermes/.gitignore Last active Feb 7, 2017

Remez Algorithm for log(x)
*.pyc
*.aux
*.log
*.out
 from __future__ import print_function import matplotlib.pyplot as plt import mpmath import numpy as np import seaborn # NOTE: Created via: seaborn.husl_palette(6)[4] HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744) LOG2_HI = float.fromhex('0x1.62e42fee00000p-1') LOG2_LOW = float.fromhex('0x1.a39ef35793c76p-33') L1 = float.fromhex('0x1.5555555555593p-1') L2 = float.fromhex('0x1.999999997fa04p-2') L3 = float.fromhex('0x1.2492494229359p-2') L4 = float.fromhex('0x1.c71c51d8e78afp-3') L5 = float.fromhex('0x1.7466496cb03dep-3') L6 = float.fromhex('0x1.39a09d078c69fp-3') L7 = float.fromhex('0x1.2f112df3e5244p-3') SQRT2_HALF = float.fromhex('0x1.6a09e667f3bcdp-1') CTX = mpmath.MPContext() CTX.prec = 200 # Bits vs. default of 53 def log_mpf(x): return CTX.log(CTX.mpf(x)) def log_ieee754(x): f1, ki = np.frexp(x) if f1 < SQRT2_HALF: f1 *= 2 ki -= 1 f = f1 - 1 k = float(ki) s = f / (2 + f) s2 = s * s s4 = s2 * s2 # Terms with odd powers of s^2. t1 = s2 * (L1 + s4 * (L3 + s4 * (L5 + s4 * L7))) # Terms with even powers of s^2. t2 = s4 * (L2 + s4 * (L4 + s4 * L6)) R = t1 + t2 hfsq = 0.5 * f * f return k * LOG2_HI - ((hfsq - (s * (hfsq + R) + k * LOG2_LOW)) - f) def main(): num_points = 2**14 x_vals = np.linspace(0, np.exp(8.0), num_points) x_vals = x_vals[1:] # Eliminate 0.0 since log(0) = -inf. log_rel_errors = [] for x_val in x_vals: log_val = log_ieee754(x_val) log_hp_val = log_mpf(x_val) log_rel_errors.append(abs(log_val - log_hp_val) / abs(log_hp_val)) plt.plot(x_vals, log_rel_errors, color=HUSL_BLUE) filename = 'log_high_precision_relative_error.png' plt.savefig(filename, bbox_inches='tight') print('Saved ' + filename) if __name__ == '__main__': main()
 from __future__ import print_function import itertools import matplotlib.pyplot as plt import mpmath import seaborn from remez_w_high_precision import get_peaks # NOTE: Created via: seaborn.husl_palette(6)[4] HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744) CTX = mpmath.MPContext() CTX.prec = 200 # Bits vs. default of 53 L1 = CTX.mpf(float.fromhex('0x1.5555555555593p-1')) L2 = CTX.mpf(float.fromhex('0x1.999999997fa04p-2')) L3 = CTX.mpf(float.fromhex('0x1.2492494229359p-2')) L4 = CTX.mpf(float.fromhex('0x1.c71c51d8e78afp-3')) L5 = CTX.mpf(float.fromhex('0x1.7466496cb03dep-3')) L6 = CTX.mpf(float.fromhex('0x1.39a09d078c69fp-3')) L7 = CTX.mpf(float.fromhex('0x1.2f112df3e5244p-3')) MAX_S = 3 - 2 * CTX.sqrt(2) EXTRA_S = CTX.mpf('0.1717') def f(s): # (L1 * s**2 + L2 * s**4 + L3 * s**6 + L4 * s**8 + # L5 * s**10 + L6 * s**12 + L7 * s**14) # Computes f(s) the exact same way is done in math/log.go. s2 = s * s s4 = s2 * s2 t1 = s2 * (L1 + s4 * (L3 + s4 * (L5 + s4 * L7))) t2 = s4 * (L2 + s4 * (L4 + s4 * L6)) return t1 + t2 def R(s): if s == CTX.mpf(0.0): return CTX.mpf(0.0) else: numer = CTX.log(1 + s) - CTX.log(1 - s) return numer / s - CTX.mpf(2.0) def err_func(s): return R(s) - f(s) def err_func_diff(s): return CTX.diff(err_func, s) def main(): num_points = 2**10 extra_points = 32 # Extend past the true max to find an 8th point where # equi-oscillation occurs. s_vals = (CTX.linspace(0, MAX_S, num_points) + CTX.linspace(MAX_S, EXTRA_S, extra_points)[1:]) delta_vals = [err_func(s_val) for s_val in s_vals] abs_vals_iter = itertools.imap(abs, delta_vals) exponent = CTX.log(max(abs_vals_iter), 2) print(' Number of points: % d' % (num_points,)) print(' CTX.prec: % d' % (CTX.prec,)) print('Exponent of max error: % 2.3f' % (exponent,)) peaks = sorted(get_peaks(s_vals, delta_vals, 8)) # Move from the fixed grid (given by s_vals) onto the entire # interval by finding critical points nearby. 
We **don't** do this # for the biggest peak x-value since the function has no critical # point on the outside. new_peaks = [] for peak in peaks[:-1]: new_peaks.append(CTX.findroot(err_func_diff, peak)) new_peaks.append(peaks[-1]) # Over-write the peaks. peaks = new_peaks delta_peaks = [err_func(peak) for peak in peaks] plt.plot(s_vals, delta_vals, color=HUSL_BLUE) plt.plot(peaks, delta_peaks, color='black', linestyle='None', marker='o', markersize=4) avg_peak_val = sum(itertools.imap(abs, delta_peaks)) / len(delta_peaks) plt.plot(s_vals, [avg_peak_val] * len(s_vals), linestyle='dashed', color='black') plt.plot(s_vals, [-avg_peak_val] * len(s_vals), linestyle='dashed', color='black') plt.xlabel('$s$') plt.title('$R(s) - f(s)$') filename = 'remez_equioscillating_error.png' plt.gcf().patch.set_alpha(0.0) plt.savefig(filename, bbox_inches='tight') print('Saved ' + filename) print('-' * 40) print('Peaks occur at x-values:') for peak in peaks: peak_as_double = float(str(peak)) print('- ' + peak_as_double.hex()) if __name__ == '__main__': main()
 Number of points: 1024 CTX.prec: 200 Exponent of max error: -58.472 Saved remez_equioscillating_error.png ---------------------------------------- Peaks occur at x-values: - 0x1.dda01e9bda24ap-6 - 0x1.05cc9437a79e0p-4 - 0x1.82e6cfe518d80p-4 - 0x1.ef48b9760dac4p-4 - 0x1.23e3b6364b181p-3 - 0x1.44ac2edbbe0cfp-3 - 0x1.58d6936b415e4p-3 - 0x1.5fa43fe5c91d1p-3
 from __future__ import print_function import matplotlib.pyplot as plt import mpmath import numpy as np import seaborn # NOTE: Created via: seaborn.husl_palette(6)[4] HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744) DESIRED_PRECISION = 200 # Bits vs. default of 53 MAX_S = 3 - 2 * np.sqrt(2) L1 = float.fromhex('0x1.5555555555593p-1') L2 = float.fromhex('0x1.999999997fa04p-2') L3 = float.fromhex('0x1.2492494229359p-2') L4 = float.fromhex('0x1.c71c51d8e78afp-3') L5 = float.fromhex('0x1.7466496cb03dep-3') L6 = float.fromhex('0x1.39a09d078c69fp-3') L7 = float.fromhex('0x1.2f112df3e5244p-3') CTX = mpmath.MPContext() CTX.prec = 200 # Bits vs. default of 53 def f(s): # (L1 * s**2 + L2 * s**4 + L3 * s**6 + L4 * s**8 + # L5 * s**10 + L6 * s**12 + L7 * s**14) # Computes f(s) the exact same way is done in math/log.go. s2 = s * s s4 = s2 * s2 t1 = s2 * (L1 + s4 * (L3 + s4 * (L5 + s4 * L7))) t2 = s4 * (L2 + s4 * (L4 + s4 * L6)) return t1 + t2 def f_mpf(s): # Same as f() but in high-precision. return f(CTX.mpf(s)) def R_mpf(s): s = CTX.mpf(s) if s == CTX.mpf(0.0): return CTX.mpf(0.0) else: numer = CTX.log(1 + s) - CTX.log(1 - s) return numer / s - CTX.mpf(2.0) def R(s): return float(str(R_mpf(s))) def main(): num_points = 2**10 s_vals = np.linspace(0, MAX_S, num_points) f_error_vals = [] R_error_vals = [] approx_error_vals = [] approx_error_hp_vals = [] for s_val in s_vals: # NOTE: Many computations are wasted / performed multiple times. f_val = f(s_val) f_hp_val = f_mpf(s_val) # f_hp == f in high-precision f_error_vals.append(f_hp_val - f_val) R_val = R(s_val) R_hp_val = R_mpf(s_val) # R_hp == R in high-precision # NOTE: Could check that float(str(R_mpf(s_val))) == R_val. R_error_vals.append(R_hp_val - R_val) # Now save the **actual errors**. 
approx_error_vals.append(R_val - f_val) approx_error_hp_vals.append(R_hp_val - f_hp_val) max_err = np.max([np.max(f_error_vals), np.max(R_error_vals), np.max(approx_error_vals), np.max(approx_error_hp_vals)]) min_err = np.min([np.min(f_error_vals), np.min(R_error_vals), np.min(approx_error_vals), np.min(approx_error_hp_vals)]) err_height = max_err - min_err err_mid = 0.5 * (max_err + min_err) # H/T: http://stackoverflow.com/a/16542886/1068170 plt.subplot(221) plt.plot(s_vals, f_error_vals, color=HUSL_BLUE) plt.ylim(err_mid - 0.55 * err_height, err_mid + 0.55 * err_height) plt.title(r'$f^{hp}(s) - f(s)$') plt.subplot(222) plt.plot(s_vals, R_error_vals, color=HUSL_BLUE) plt.ylim(err_mid - 0.55 * err_height, err_mid + 0.55 * err_height) plt.title(r'$R^{hp}(s) - R(s)$') plt.subplot(212) plt.plot(s_vals, approx_error_vals, color=HUSL_BLUE) plt.plot(s_vals, approx_error_hp_vals, linestyle='dashed', linewidth=0.5, color='black') plt.ylim(err_mid - 0.55 * err_height, err_mid + 0.55 * err_height) plt.title('$R(s) - f(s)$') filename = 'remez_equioscillating_error_ieee754.png' curr_fig = plt.gcf() width, height = curr_fig.get_size_inches() curr_fig.set_size_inches(1.5 * width, height) plt.savefig(filename, bbox_inches='tight') print('Saved ' + filename) if __name__ == '__main__': main()
\documentclass[letterpaper,10pt]{article}
\usepackage[margin=1in]{geometry}
\usepackage{amsthm,amssymb,amsmath,hyperref}
\hypersetup{colorlinks=true,urlcolor=blue,citecolor=blue}
\usepackage{embedfile}
\embedfile{\jobname.tex}
\usepackage{fancyhdr}
\pagestyle{fancy}
\lhead{}
\rhead{Remez Approximation}
\renewcommand{\headrulewidth}{0pt}

\begin{document}

\section{Background}

The Golang \href{https://golang.org/src/math/log.go}{source} for the
\verb+log()+ function\footnote{The Golang source actually mentions that
this method and some of the comments were borrowed from the original C
code in FreeBSD's \texttt{/usr/src/lib/msun/src/e\_log.c}.}
%% /usr/src/lib/msun/src/e_log.c
states that the natural log is computed in three parts.

\subsection{Argument Reduction}

Find $k$ and $f$ such that
\[x = 2^k \left(1 + f\right)\]
where $\frac{\sqrt{2}}{2} < 1 + f < \sqrt{2}$.

\subsection{Approximation of \texorpdfstring{$\log(1+f)$}{log(1+f)}}

Let $s = \frac{f}{2 + f}$, based on
\begin{align*}
\log(1 + f) &= \log(1 + s) - \log(1 - s) \\
            &= 2s + \frac{2}{3} s^3 + \frac{2}{5} s^5 + \cdots \\
            &= 2s + s R\left(s\right)
\end{align*}
Use a special Reme\footnote{Misspelling, should be
\href{https://en.wikipedia.org/wiki/Remez_algorithm}{Remez}.} algorithm
on $\left[0, 0.1716\right]$ to generate a polynomial of degree $14$ to
approximate $R$. This approximation is given by\footnote{The $z$ is
likely $s^2$, which would mean the polynomial is intentionally of degree
$7$ in $z$.}
\[R(z) \approx L_1 s^2 + L_2 s^4 + L_3 s^6 + L_4 s^8 + L_5 s^{10} + L_6 s^{12} + L_7 s^{14}\]
with maximum error $2^{-58.45}$.
\subsection{Combining}

Finally
\begin{align*}
\log(x) &= k \log(2) + \log(1 + f) \\
        &= k H + \left(f - \left(\frac{f^2}{2} - \left(s \left( \frac{f^2}{2} + R\right) + k L\right)\right)\right)
\end{align*}
where $\log(2) = H + L$ is split into high and low parts
\begin{align*}
H &= 2^{-1} \cdot \mathtt{1.62e42fee00000}_{16} \\
L &= 2^{-1} \cdot \mathtt{0.00000001a39ef35793c76}_{16} = 2^{-33} \cdot \mathtt{1.a39ef35793c76}_{16} \\
\log(2) &\approx 2^{-1} \cdot \mathtt{1.62e42fefa39ef}_{16}
\end{align*}
where the high part is chosen so that $nH$ is always exact for
$\left|n\right| < 2000$.

\section{Reducing Value}

First note that
\[2^{-1} < 2^{-\frac{1}{2}} < 2^0 < 2^{\frac{1}{2}} < 2^1 \quad \text{i.e.} \quad \frac{1}{2} < \frac{\sqrt{2}}{2} < 1 < \sqrt{2} < 2.\]
In Golang the \href{https://golang.org/src/math/frexp.go}{function}
\verb+math.Frexp+ takes a float $x$ and returns a fraction $m$ and an
exponent $e$ such that
\[x = 2^e m, \qquad 2\left|m\right| \in \left[1, 2\right).\]
Since $\log(x)$ is only defined for $x > 0$, we know $m > 0$.
If $\frac{1}{\sqrt{2}} < m < 1$ then we set $k = e$ and $1 + f = m$.
If $\frac{1}{2} \leq m \leq \frac{1}{\sqrt{2}}$ then we set
$1 + f = 2m \in \left[1, \sqrt{2}\right]$ and $k = e - 1$ since
$2^e m = 2^{e - 1}(2m)$. This would appear to only give
\[\frac{1}{\sqrt{2}} < 1 + f \leq \sqrt{2}\]
but we also know $1 + f \neq \sqrt{2}$ since an irrational number can't
be represented in floating point.
\section{From \texorpdfstring{$f$}{f} to \texorpdfstring{$s$}{s}}

We write $1 + f = \displaystyle\frac{1 + s}{1 - s}$ so that
\begin{align*}
\log(1 + f) &= \log(1 + s) - \log(1 - s) \\
            &= 2s + \frac{2}{3} s^3 + \frac{2}{5} s^5 + \cdots
\end{align*}
Solving $1 + f = \frac{1 + s}{1 - s}$ for $s$ yields
$s = \displaystyle\frac{f}{2 + f} = 1 - \frac{2}{2 + f}$, which is
increasing in $f$, so
\[\frac{1}{\sqrt{2}} < 1 + f < \sqrt{2} \Longrightarrow -\left(3 - 2\sqrt{2}\right) < s < 3 - 2 \sqrt{2} \approx 0.171573 < 0.1716.\]

\subsection{Why the bounds?}

The bound $\frac{1}{\sqrt{2}} < 1 + f < \sqrt{2}$ was just given but
never justified. In fact, this interval is required to be able to use an
approximation of $R(s)$.

First, in order to factor $x = 2^k (1 + f)$ uniquely, if
$1 + f \in \left[\alpha, \beta\right]$ we must have
\[\log_2 \beta = \log_2 \alpha + 1 \Longleftrightarrow \beta = 2 \alpha.\]
Second, since defining $R(s)$ requires both $\log(1 \pm s)$ to be
defined, we need $-1 < s < 1$. Third, the symmetry $R(s) = R(-s)$ means
we can limit to some symmetric interval $-A < s < A$ (for $0 < A < 1$).
This in turn means that
\[\frac{1}{B} < 1 + f < B, \qquad B = \frac{1 + A}{1 - A} \Longleftrightarrow A = \frac{B - 1}{B + 1}.\]
Hence we put $1 + f \in \left[B^{-1}, B\right]$, which forces
\[\log_2 B = \log_2\left(B^{-1}\right) + 1 \Longrightarrow 2 \log_2 B = 1 \Longrightarrow B = 2^{\frac{1}{2}} = \sqrt{2} \Longrightarrow A = 3 - 2 \sqrt{2}.\]

\section{Remez Algorithm}

We seek to approximate
\[R(s) = \frac{2s^2}{3} + \frac{2s^4}{5} + \frac{2s^6}{7} + \cdots = \frac{\log(1 + s) - \log(1 - s)}{s} - 2, \quad R(0) = 0\]
with a degree $14$ polynomial that gives equi-oscillating errors. By
construction $R(s) = R(-s)$, hence we only need to approximate $R(s)$
for $s \in \left[0, 3 - 2 \sqrt{2}\right] \subset \left[0, 0.1716\right]$.
We manually force an equi-oscillating error at node points
$s_0, s_1, \ldots, s_7$ by setting
\[R(s_j) = L_1 s_j^2 + L_2 s_j^4 + \cdots + L_7 s_j^{14} + (-1)^j E = P(s_j) \pm E\]
and solving for the $7$ unknown coefficients and the error $E$. This
gives a system of $8$ linear equations in $8$ unknowns
\[\left[\begin{array}{c c c c} s_0^2 & \cdots & s_0^{14} & 1 \\ s_1^2 & \cdots & s_1^{14} & -1 \\ \vdots & \multicolumn{2}{c}{\ddots} & \vdots \\ s_7^2 & \cdots & s_7^{14} & -1 \end{array}\right] \left[\begin{array}{c} L_1 \\ \vdots \\ L_7 \\ E \end{array}\right] = \left[\begin{array}{c} R(s_0) \\ R(s_1) \\ \vdots \\ R(s_7) \end{array}\right].\]
After solving, we exchange $s_0, s_1, \ldots, s_7$ for new points which
locally maximize $\left|R(s) - P(s)\right|$. The method terminates once
the exchange process results in no change at all (or only a minimal
change).

An alternative approach only swaps a single $s_j$ at a time. It finds
the absolute extremum
\[s^{\ast} = \underset{s \in \left[0, 3 - 2 \sqrt{2}\right]}{\operatorname{argmax}} \left|R(s) - P(s)\right|\]
and then swaps $s^{\ast}$ with the nearest value among the
$\left\{s_j\right\}$, only terminating once the absolute extremum occurs
among the $\left\{s_j\right\}$. In either situation, once there is no
exchange left to do,
\[\max_{s \in \left[0, 3 - 2 \sqrt{2}\right]} \left|R(s) - P(s)\right| = \left|R\left(s^{\ast}\right) - P\left(s^{\ast}\right)\right| = \left|\pm E\right| = \left|E\right|\]
and this equi-oscillating error will occur a maximal number of times in
our interval.

\end{document}
 from __future__ import print_function import matplotlib.pyplot as plt import mpmath import seaborn # NOTE: Created via: seaborn.husl_palette(6)[4] HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744) EXPECTED_COEFFS = [ '0x1.5555555555593p-1', '0x1.999999997fa04p-2', '0x1.2492494229359p-2', '0x1.c71c51d8e78afp-3', '0x1.7466496cb03dep-3', '0x1.39a09d078c69fp-3', '0x1.2f112df3e5244p-3', ] EXPONENTS = (2, 4, 6, 8, 10, 12, 14) SIZE_INTERVAL = 512 NUM_POINTS = 8 MAX_STEPS = 20 CTX = mpmath.MPContext() CTX.prec = 200 # Bits vs. default of 53 # NOTE: Slightly larger than 3 - 2 * mpmath.sqrt(2) # so that we can capture an 8th equi-oscillating # point. MAX_X = CTX.mpf('0.1717') INTERVAL_SAMPLE = CTX.linspace(0, MAX_X, SIZE_INTERVAL) def exp_by_squaring(val, n): """Exponentiate by squaring. :type val: float :param val: A value to exponentiate. :type n: int :param n: The (positive) exponent. :rtype: float :returns: The exponentiated value. """ result = 1 pow_val = val while n != 0: n, remainder = divmod(n, 2) if remainder == 1: result *= pow_val pow_val = pow_val * pow_val return result def R_scalar(s): """Compute the value of R(s). R(s) is function such that log(1 + s) - log(1 - s) = s(2 + R(s)) Uses mpmath to compute the answer in high precision. :type s: float :param s: A scalar to evaluate R(s) at. :rtype: float :returns: The value of R(s) at our input value. """ if s == 0.0: # We can't divide by 0, but we know that # R(0) = 2 * 0^2/3 + 2 * 0^4 / 5 + ... = 0. return CTX.mpf(0.0) numer = CTX.log(1 + s) - CTX.log(1 - s) return numer / s - CTX.mpf(2.0) def update_remez_poly(x_vals): """Updates the Remez polynomial coefficients. :type x_vals: list :param x_vals: The x-values where equi-oscillating occurs. We assume there are NUM_POINTS values. :rtype: tuple :returns: A pair, the first is the coefficients (as a list) and the second is a scalar (the equi-oscillating error). """ # Columns correspond to an exponent while rows correspond to # an x-value. 
coeff_system = CTX.matrix(NUM_POINTS, NUM_POINTS) rhs = CTX.matrix(NUM_POINTS, 1) for row in range(NUM_POINTS): # Handle the final column first (just a sign). coeff_system[row, NUM_POINTS - 1] = (-1)**row x_val = x_vals[row] rhs[row, 0] = R_scalar(x_val) # Handle all columns left (final column already done). for col in range(NUM_POINTS - 1): pow_ = EXPONENTS[col] coeff_system[row, col] = exp_by_squaring(x_val, pow_) soln = CTX.lu_solve(coeff_system, rhs, real=True) # Turn into a row vector and then turn it into a list. soln, = soln.T.tolist() return soln[:-1], soln[-1] def get_chebyshev_points(num_points): """Get Chebyshev points for [0, MAX_X]. :type num_points: int :param num_points: The number of points to use. :rtype: list :returns: The Chebyshev points for our interval. """ result = [] for index in range(2 * num_points - 1, 0, -2): theta = CTX.pi * index / CTX.mpf(2.0 * num_points) result.append(0.5 * MAX_X * (1 + CTX.cos(theta))) return result class ErrorFunc(object): """Error function for representing R(s) and P(s). Takes a given set of x-values, computes the necessary coefficients of P(s) with them and uses them to define the function R(s) - P(s). :type x_vals: list :param x_vals: A 1D array. The x-values where equi-oscillating occurs. """ def __init__(self, x_vals): self.x_vals = x_vals # Computed values. self.poly_coeffs, self.E = update_remez_poly(x_vals) def poly_approx_scalar(self, value): """Evaluate the polynomial :math:f(x). Uses the same method as in math/log.go to compute .. math:: L_1 x^2 + L_2 x^4 + \\cdots + L_7 x^{14} :type value: float :param value: The value to compute f(x) at. :rtype: float :returns: The value of f(x) at our input value. 
""" L1 = self.poly_coeffs[0] L2 = self.poly_coeffs[1] L3 = self.poly_coeffs[2] L4 = self.poly_coeffs[3] L5 = self.poly_coeffs[4] L6 = self.poly_coeffs[5] L7 = self.poly_coeffs[6] s2 = value * value s4 = s2 * s2 t1 = s2 * (L1 + s4 * (L3 + s4 * (L5 + s4 * L7))) t2 = s4 * (L2 + s4 * (L4 + s4 * L6)) return t1 + t2 def signed_error_scalar(self, value): """Error function R(s) - P(s). :type value: float :param value: An x-value. :rtype: float :returns: The value of the error R(s) - P(s). """ return R_scalar(value) - self.poly_approx_scalar(value) def signed_error_diff(self, value): """Derivative of error function R'(s) - P'(s). :type value: float :param value: An x-value. :rtype: float :returns: The value of the error R'(s) - P'(s). """ return CTX.diff(self.signed_error_scalar, value) def signed_error(self, values): """Error function R(s) - P(s). :type values: list :param values: A list of x-values. :rtype: list :returns: The value of the error R(s) - P(s) at each point. """ result = [] for val in values: result.append(self.signed_error_scalar(val)) return result def locate_abs_max(values): """Locate the absolute maximum of a list of values. :type values: list :param values: A list of scalar values. :rtype: int :returns: The index where the maximum occurs. """ curr_max = -CTX.inf curr_max_index = -1 for index, value in enumerate(values): abs_val = abs(value) if abs_val > curr_max: curr_max = abs_val curr_max_index = index return curr_max_index def get_peaks(x_data, y_data, num_peaks): """Get the peaks from oscillating output data. :type x_data: list :param x_data: The x-values where the outputs occur. :type y_data: list :param y_data: The oscillating output data. :type num_peaks: int :param num_peaks: The number of peaks to locate. :rtype: list :returns: The x-locations of all the peaks that were found, in the order that they were found. 
""" local_data = y_data[::] size_interval = len(y_data) peak_locations = [] while len(peak_locations) < num_peaks: curr_biggest = locate_abs_max(local_data) x_now = x_data[curr_biggest] if x_now in peak_locations: raise ValueError('Repeat value found.') peak_locations.append(x_now) # Find the sign so we can identify all nearby points on the # peak (they will have the same sign). sign_x_now = CTX.sign(local_data[curr_biggest]) local_data[curr_biggest] = 0.0 # Zero out all the values on the same peak to the # right of x_now. index = curr_biggest + 1 while (index < size_interval and CTX.sign(local_data[index]) == sign_x_now): local_data[index] = 0.0 index += 1 # Zero out all the values on the same peak to the # left of x_now. index = curr_biggest - 1 while (index >= 0 and CTX.sign(local_data[index]) == sign_x_now): local_data[index] = 0.0 index -= 1 return peak_locations def get_new_x_vals(x_vals, sample_points=INTERVAL_SAMPLE): """Perform single pass of Remez algorithm. First computes the coefficients based on x_vals, then locates the extrema of |P(s) - R(s)|. :type x_vals: list :param x_vals: The x-values where equi-oscillating occurs. :type sample_points: list :param sample_points: (Optional) The points we choose extrema from. :rtype: list :returns: The new x-values. """ err_func = ErrorFunc(x_vals) size_interval = len(sample_points) approx_outputs = err_func.signed_error(sample_points) max_vals = get_peaks(sample_points, approx_outputs, NUM_POINTS) max_vals = sorted(max_vals) # Move from the fixed grid (given by sample_points) onto the entire # interval by finding critical points nearby. We **don't** do this # for the biggest max value since the function has no critical point on # the outside. new_vals = [] for val in max_vals[:-1]: new_vals.append(CTX.findroot(err_func.signed_error_diff, val)) new_vals.append(max_vals[-1]) # NOTE: We could / should check that new_vals == sorted(new_vals) # and that it has no repeats. 
return new_vals def plot_x_vals(x_vals, sample_points, filename=None): """Plot the error R(s) - P(s) for P(s) given by x-values. Also plots the locations of each equi-oscillating x-value on the curve. :type x_vals: list :param x_vals: The x-values where equi-oscillating occurs. :type sample_points: list :param sample_points: The points we choose extrema from. :type filename: str :param filename: (Optional) The filename to save the plot in. If not specified, just shows the plot. """ err_func = ErrorFunc(x_vals) approx_outputs = err_func.signed_error(sample_points) plt.plot(sample_points, approx_outputs, color=HUSL_BLUE) plt.plot(x_vals, err_func.signed_error(x_vals), marker='o', color='black', linestyle='None') if filename is None: plt.show() else: plt.savefig(filename, bbox_inches='tight') print('Saved ' + filename) def _list_delta_norm(list1, list2): return CTX.norm(CTX.matrix(list1) - CTX.matrix(list2), p=2) def _lead_substr_match(val1, val2): min_len = min(len(val1), len(val2)) for index in range(1, min_len + 1): if val1[:index] != val2[:index]: return index - 1 return min_len def _print_double_hex_compare(actual_mpf, expected): # Convert back to double, then to hex. actual_as_hex = float(str(actual_mpf)).hex() print('- ' + actual_as_hex) expected_begin, expected_expon = expected.split('p', 1) match_index = _lead_substr_match(actual_as_hex, expected_begin) buffer = '.' * (len(expected_begin) - match_index) print(' ' + expected_begin[:match_index] + buffer + 'p' + expected_expon) def main(sample_points=INTERVAL_SAMPLE, threshold=CTX.mpf(2)**(-26), plot_all=False): """Run the Remez algorithm until termination. :type sample_points: list :param sample_points: (Optional) The points we choose extrema from. :type threshold: float :param threshold: (Optional) The minimum value of the norm of the difference between current x-values and the next set of x-values found via the Remez algorithm. Defaults to 2^{-26}. 
:type plot_all: bool :param plot_all: (Optional) Flag indicating all updates should be plotted. Defaults to :data:False. """ prev_x_vals = [CTX.inf] * NUM_POINTS x_vals = get_chebyshev_points(NUM_POINTS) num_steps = 0 while _list_delta_norm(x_vals, prev_x_vals) > threshold: if num_steps >= MAX_STEPS: print('Max. steps encountered. Does not converge.') break if plot_all: plot_x_vals(x_vals, sample_points) prev_x_vals = x_vals x_vals = get_new_x_vals(x_vals, sample_points=sample_points) num_steps += 1 norm_update = _list_delta_norm(x_vals, prev_x_vals) msg = 'Difference between successive x-vectors: %g' % (norm_update,) print(msg) msg = 'Completed in %d steps' % (num_steps,) print(msg) err_func = ErrorFunc(x_vals) print('Coefficients:') for index, coeff in enumerate(err_func.poly_coeffs): _print_double_hex_compare(coeff, EXPECTED_COEFFS[index]) plot_x_vals(x_vals, sample_points, filename='remez_approx.png') if __name__ == '__main__': main()