Instantly share code, notes, and snippets.

Embed
What would you like to do?
Remez Algorithm for log(x)
*.pyc
*.aux
*.log
*.out
from __future__ import print_function
import matplotlib.pyplot as plt
import mpmath
import numpy as np
import seaborn
# NOTE: Created via: seaborn.husl_palette(6)[4]
HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
# log(2) split into a high part and a low part (LOG2_HI + LOG2_LOW ~= log(2)),
# the same H / L split described in the accompanying write-up.
LOG2_HI = float.fromhex('0x1.62e42fee00000p-1')
LOG2_LOW = float.fromhex('0x1.a39ef35793c76p-33')
# Polynomial coefficients L1..L7 multiplying s^2, s^4, ..., s^14 in the
# approximation of R(s) used by log_ieee754().
L1 = float.fromhex('0x1.5555555555593p-1')
L2 = float.fromhex('0x1.999999997fa04p-2')
L3 = float.fromhex('0x1.2492494229359p-2')
L4 = float.fromhex('0x1.c71c51d8e78afp-3')
L5 = float.fromhex('0x1.7466496cb03dep-3')
L6 = float.fromhex('0x1.39a09d078c69fp-3')
L7 = float.fromhex('0x1.2f112df3e5244p-3')
# Double closest to sqrt(2) / 2; threshold for the argument reduction.
SQRT2_HALF = float.fromhex('0x1.6a09e667f3bcdp-1')
# High-precision mpmath context used as the reference ("exact") log.
CTX = mpmath.MPContext()
CTX.prec = 200 # Bits vs. default of 53
def log_mpf(x):
    """Return the natural log of ``x`` computed in the high-precision ``CTX``."""
    x_hp = CTX.mpf(x)
    return CTX.log(x_hp)
def log_ieee754(x):
    """Compute ``log(x)`` in double precision the way Go's ``math.Log`` does.

    Uses the same argument reduction ``x = 2^k (1 + f)``, the same
    degree-14 polynomial in ``s = f / (2 + f)`` and the same split of
    ``log(2)`` into ``LOG2_HI + LOG2_LOW`` as ``math/log.go``. This lets the
    rounding behavior be compared against the high-precision ``log_mpf``.

    Special cases match ``math.Log``: ``log(NaN) = NaN``,
    ``log(+inf) = +inf``, ``log(x < 0) = NaN`` and ``log(+/-0) = -inf``.

    :type x: float
    :param x: The value to take the logarithm of.
    :rtype: float
    :returns: The double-precision approximation of ``log(x)``.
    """
    # Without these checks frexp() would happily "reduce" 0, inf or negative
    # inputs and the polynomial would return a finite but meaningless value.
    if np.isnan(x) or x == np.inf:
        return x
    if x < 0:
        return np.nan
    if x == 0:
        return -np.inf

    # Argument reduction: keep the mantissa in [sqrt(2)/2, sqrt(2)).
    f1, ki = np.frexp(x)
    if f1 < SQRT2_HALF:
        f1 *= 2
        ki -= 1
    f = f1 - 1
    k = float(ki)

    # log(1 + f) = log(1 + s) - log(1 - s) = 2s + s * R(s).
    s = f / (2 + f)
    s2 = s * s
    s4 = s2 * s2
    # Terms with odd powers of s^2.
    t1 = s2 * (L1 + s4 * (L3 + s4 * (L5 + s4 * L7)))
    # Terms with even powers of s^2.
    t2 = s4 * (L2 + s4 * (L4 + s4 * L6))
    R = t1 + t2
    hfsq = 0.5 * f * f
    return k * LOG2_HI - ((hfsq - (s * (hfsq + R) + k * LOG2_LOW)) - f)
def main():
    """Plot the relative error of ``log_ieee754`` against ``log_mpf``.

    Samples ``2**14`` evenly spaced points on ``[0, e^8]``, drops the first
    one (0.0, where the log is -inf) and saves the plot as a PNG.
    """
    num_points = 2**14
    # Eliminate 0.0 since log(0) = -inf.
    x_vals = np.linspace(0, np.exp(8.0), num_points)[1:]

    def _relative_error(x_val):
        approx = log_ieee754(x_val)
        exact = log_mpf(x_val)
        return abs(approx - exact) / abs(exact)

    log_rel_errors = [_relative_error(x_val) for x_val in x_vals]
    plt.plot(x_vals, log_rel_errors, color=HUSL_BLUE)
    filename = 'log_high_precision_relative_error.png'
    plt.savefig(filename, bbox_inches='tight')
    print('Saved ' + filename)


if __name__ == '__main__':
    main()
from __future__ import print_function
import itertools
import matplotlib.pyplot as plt
import mpmath
import seaborn
from remez_w_high_precision import get_peaks
# NOTE: Created via: seaborn.husl_palette(6)[4]
HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
# High-precision mpmath context; all error computations below happen in it.
CTX = mpmath.MPContext()
CTX.prec = 200 # Bits vs. default of 53
# Coefficients L1..L7 (for s^2, ..., s^14), converted from doubles into
# 200-bit mpf values (the conversion is exact at this precision).
L1 = CTX.mpf(float.fromhex('0x1.5555555555593p-1'))
L2 = CTX.mpf(float.fromhex('0x1.999999997fa04p-2'))
L3 = CTX.mpf(float.fromhex('0x1.2492494229359p-2'))
L4 = CTX.mpf(float.fromhex('0x1.c71c51d8e78afp-3'))
L5 = CTX.mpf(float.fromhex('0x1.7466496cb03dep-3'))
L6 = CTX.mpf(float.fromhex('0x1.39a09d078c69fp-3'))
L7 = CTX.mpf(float.fromhex('0x1.2f112df3e5244p-3'))
# Largest |s| reached: s = f / (2 + f) with 1 + f = sqrt(2) gives 3 - 2 sqrt(2).
MAX_S = 3 - 2 * CTX.sqrt(2)
# Slightly past MAX_S; main() also samples (MAX_S, EXTRA_S] to find an 8th
# equi-oscillation point.
EXTRA_S = CTX.mpf('0.1717')
def f(s):
    """Evaluate the polynomial approximation of R(s).

    Mathematically this is
    L1 * s**2 + L2 * s**4 + L3 * s**6 + L4 * s**8 +
    L5 * s**10 + L6 * s**12 + L7 * s**14,
    but it is evaluated with exactly the same operation order as
    math/log.go: odd and even powers of s^2 in two separate Horner chains.
    """
    s_sq = s * s
    s_quad = s_sq * s_sq
    # Odd powers of s^2: s^2, s^6, s^10, s^14.
    odd_chain = s_sq * (L1 + s_quad * (L3 + s_quad * (L5 + s_quad * L7)))
    # Even powers of s^2: s^4, s^8, s^12.
    even_chain = s_quad * (L2 + s_quad * (L4 + s_quad * L6))
    return odd_chain + even_chain
def R(s):
    """High-precision R(s) = (log(1 + s) - log(1 - s)) / s - 2.

    The expression has a removable singularity at s = 0. There the limit
    (0) is returned directly to avoid dividing by zero.
    """
    if s != CTX.mpf(0.0):
        log_ratio = CTX.log(1 + s) - CTX.log(1 - s)
        return log_ratio / s - CTX.mpf(2.0)
    return CTX.mpf(0.0)
def err_func(s):
    """Signed approximation error R(s) - f(s)."""
    exact = R(s)
    approx = f(s)
    return exact - approx
def err_func_diff(s):
    """Numerical derivative of ``err_func`` at ``s`` via ``CTX.diff``."""
    derivative = CTX.diff(err_func, s)
    return derivative
def main():
    """Plot R(s) - f(s) for the log.go coefficients and mark its peaks.

    Samples ``[0, MAX_S]`` on a uniform grid, plus a short extension up to
    ``EXTRA_S`` so that an 8th extremum is visible. It then:

    - prints the base-2 exponent of the largest error;
    - locates 8 peaks with ``get_peaks``;
    - refines all but the last peak to nearby critical points with
      ``CTX.findroot``;
    - saves the plot to ``remez_equioscillating_error.png``;
    - prints the peak locations as hex doubles.
    """
    num_points = 2**10
    extra_points = 32
    # Extend past the true max to find an 8th point where
    # equi-oscillation occurs.
    s_vals = (CTX.linspace(0, MAX_S, num_points) +
              CTX.linspace(MAX_S, EXTRA_S, extra_points)[1:])
    delta_vals = [err_func(s_val) for s_val in s_vals]
    # NOTE: Generator expressions replace ``itertools.imap``, which only
    #       exists in Python 2 (AttributeError on Python 3).
    exponent = CTX.log(max(abs(delta) for delta in delta_vals), 2)
    print(' Number of points: % d' % (num_points,))
    print(' CTX.prec: % d' % (CTX.prec,))
    print('Exponent of max error: % 2.3f' % (exponent,))
    peaks = sorted(get_peaks(s_vals, delta_vals, 8))
    # Move from the fixed grid (given by ``s_vals``) onto the entire
    # interval by finding critical points nearby. We **don't** do this
    # for the biggest peak x-value since the function has no critical
    # point on the outside.
    new_peaks = [CTX.findroot(err_func_diff, peak) for peak in peaks[:-1]]
    new_peaks.append(peaks[-1])
    # Over-write the peaks.
    peaks = new_peaks
    delta_peaks = [err_func(peak) for peak in peaks]
    plt.plot(s_vals, delta_vals, color=HUSL_BLUE)
    plt.plot(peaks, delta_peaks, color='black', linestyle='None',
             marker='o', markersize=4)
    avg_peak_val = sum(abs(delta) for delta in delta_peaks) / len(delta_peaks)
    # Dashed lines at +/- the average peak height show how close the
    # error is to equi-oscillation.
    plt.plot(s_vals, [avg_peak_val] * len(s_vals),
             linestyle='dashed', color='black')
    plt.plot(s_vals, [-avg_peak_val] * len(s_vals),
             linestyle='dashed', color='black')
    plt.xlabel('$s$')
    plt.title('$R(s) - f(s)$')
    filename = 'remez_equioscillating_error.png'
    plt.gcf().patch.set_alpha(0.0)
    plt.savefig(filename, bbox_inches='tight')
    print('Saved ' + filename)
    print('-' * 40)
    print('Peaks occur at x-values:')
    for peak in peaks:
        peak_as_double = float(str(peak))
        print('- ' + peak_as_double.hex())


if __name__ == '__main__':
    main()
Number of points: 1024
CTX.prec: 200
Exponent of max error: -58.472
Saved remez_equioscillating_error.png
----------------------------------------
Peaks occur at x-values:
- 0x1.dda01e9bda24ap-6
- 0x1.05cc9437a79e0p-4
- 0x1.82e6cfe518d80p-4
- 0x1.ef48b9760dac4p-4
- 0x1.23e3b6364b181p-3
- 0x1.44ac2edbbe0cfp-3
- 0x1.58d6936b415e4p-3
- 0x1.5fa43fe5c91d1p-3
from __future__ import print_function
import matplotlib.pyplot as plt
import mpmath
import numpy as np
import seaborn
# NOTE: Created via: seaborn.husl_palette(6)[4]
HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
DESIRED_PRECISION = 200  # Bits vs. default of 53
# Right endpoint of the reduced interval for s (only bounds the sample grid,
# so double precision is enough here).
MAX_S = 3 - 2 * np.sqrt(2)
# Coefficients L1..L7 (for s^2, ..., s^14) kept as doubles so that f() can be
# evaluated in plain IEEE-754 arithmetic; f_mpf() promotes the input instead.
L1 = float.fromhex('0x1.5555555555593p-1')
L2 = float.fromhex('0x1.999999997fa04p-2')
L3 = float.fromhex('0x1.2492494229359p-2')
L4 = float.fromhex('0x1.c71c51d8e78afp-3')
L5 = float.fromhex('0x1.7466496cb03dep-3')
L6 = float.fromhex('0x1.39a09d078c69fp-3')
L7 = float.fromhex('0x1.2f112df3e5244p-3')
CTX = mpmath.MPContext()
# Use the declared constant rather than a second hard-coded 200.
CTX.prec = DESIRED_PRECISION
def f(s):
    """Evaluate the log.go polynomial approximation of R(s).

    Mathematically
        L1*s^2 + L2*s^4 + L3*s^6 + L4*s^8 + L5*s^10 + L6*s^12 + L7*s^14,
    evaluated with the operation order of math/log.go (two interleaved
    Horner chains in s^4). Float inputs therefore reproduce Go's
    rounding, while mpf inputs give the high-precision value.
    """
    z = s * s  # s^2
    w = z * z  # s^4
    # Even powers of s^2 (s^4, s^8, s^12).
    even_chain = w * (L2 + w * (L4 + w * L6))
    # Odd powers of s^2 (s^2, s^6, s^10, s^14).
    odd_chain = z * (L1 + w * (L3 + w * (L5 + w * L7)))
    return odd_chain + even_chain
def f_mpf(s):
    """High-precision version of ``f``: the input is promoted to an mpf first."""
    s_hp = CTX.mpf(s)
    return f(s_hp)
def R_mpf(s):
    """High-precision R(s) = (log(1 + s) - log(1 - s)) / s - 2.

    The input is promoted to an mpf. At s = 0 the removable singularity is
    filled in with its limit, 0.
    """
    s_hp = CTX.mpf(s)
    if s_hp != CTX.mpf(0.0):
        log_ratio = CTX.log(1 + s_hp) - CTX.log(1 - s_hp)
        return log_ratio / s_hp - CTX.mpf(2.0)
    return CTX.mpf(0.0)
def R(s):
    """R(s) rounded to a double via mpf -> decimal string -> float."""
    high_precision = R_mpf(s)
    return float(str(high_precision))
def main():
    """Compare rounding errors with the approximation error R(s) - f(s).

    - Top-left: rounding error of f evaluated in doubles.
    - Top-right: rounding error of R evaluated in doubles.
    - Bottom: R(s) - f(s) in doubles (solid) and in high precision (dashed).

    All panels share one y-range so their magnitudes can be compared.
    """
    num_points = 2**10
    s_vals = np.linspace(0, MAX_S, num_points)
    f_error_vals, R_error_vals = [], []
    approx_error_vals, approx_error_hp_vals = [], []
    for s_val in s_vals:
        # NOTE: Many computations are wasted / performed multiple times.
        f_val = f(s_val)
        f_hp_val = f_mpf(s_val)  # f_hp == f in high-precision
        R_val = R(s_val)
        R_hp_val = R_mpf(s_val)  # R_hp == R in high-precision
        # NOTE: Could check that float(str(R_mpf(s_val))) == R_val.
        f_error_vals.append(f_hp_val - f_val)
        R_error_vals.append(R_hp_val - R_val)
        # Now save the **actual errors**.
        approx_error_vals.append(R_val - f_val)
        approx_error_hp_vals.append(R_hp_val - f_hp_val)

    all_series = (f_error_vals, R_error_vals,
                  approx_error_vals, approx_error_hp_vals)
    max_err = np.max([np.max(series) for series in all_series])
    min_err = np.min([np.min(series) for series in all_series])
    err_height = max_err - min_err
    err_mid = 0.5 * (max_err + min_err)
    # Shared y-range: the full error span plus a 5% margin on each side.
    y_limits = (err_mid - 0.55 * err_height, err_mid + 0.55 * err_height)

    # H/T: http://stackoverflow.com/a/16542886/1068170
    plt.subplot(221)
    plt.plot(s_vals, f_error_vals, color=HUSL_BLUE)
    plt.ylim(*y_limits)
    plt.title(r'$f^{hp}(s) - f(s)$')

    plt.subplot(222)
    plt.plot(s_vals, R_error_vals, color=HUSL_BLUE)
    plt.ylim(*y_limits)
    plt.title(r'$R^{hp}(s) - R(s)$')

    plt.subplot(212)
    plt.plot(s_vals, approx_error_vals, color=HUSL_BLUE)
    plt.plot(s_vals, approx_error_hp_vals,
             linestyle='dashed', linewidth=0.5, color='black')
    plt.ylim(*y_limits)
    plt.title('$R(s) - f(s)$')

    filename = 'remez_equioscillating_error_ieee754.png'
    figure = plt.gcf()
    width, height = figure.get_size_inches()
    figure.set_size_inches(1.5 * width, height)
    plt.savefig(filename, bbox_inches='tight')
    print('Saved ' + filename)


if __name__ == '__main__':
    main()
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
\documentclass[letterpaper,10pt]{article}
\usepackage[margin=1in]{geometry}
\usepackage{amsthm,amssymb,amsmath,hyperref}
\hypersetup{colorlinks=true,urlcolor=blue,citecolor=blue}
\usepackage{embedfile}
\embedfile{\jobname.tex}
\usepackage{fancyhdr}
\pagestyle{fancy}
\lhead{}
\rhead{Remez Approximation}
\renewcommand{\headrulewidth}{0pt}
\begin{document}
\section{Background}
The Golang \href{https://golang.org/src/math/log.go}{source} for
the \verb+log()+
function\footnote{The Golang source actually mentions that this method and some of the comments were borrowed from original C code, from FreeBSD's \texttt{/usr/src/lib/msun/src/e\_log.c}}
%% /usr/src/lib/msun/src/e_log.c
states that the natural log is computed in three parts:
\subsection{Argument Reduction}
Find \(k\) and \(f\) such that
\[x = 2^k \left(1 + f\right)\]
where \(\frac{\sqrt{2}}{2} < 1 + f < \sqrt{2}\).
\subsection{Approximation of \texorpdfstring{$\log(1+f)$}{log(1+f)}}
Let \(s = \frac{f}{2 + f}\), based on
\begin{align*}
\log(1 + f) &= \log(1 + s) - \log(1 - s) \\
&= 2s + \frac{2}{3} s^3 + \frac{2}{5} s^5 + \cdots \\
&= 2s + s R\left(s\right)
\end{align*}
Use a special
Reme\footnote{misspelling, should be \href{https://en.wikipedia.org/wiki/Remez_algorithm}{Remez}}
algorithm on \(\left[0, 0.1716\right]\) to generate a polynomial of
degree \(14\) to approximate \(R\).
This approximation is given
by\footnote{The \(z\) is presumably \(s^2\), which would mean the polynomial is intentionally of degree \(7\) in \(z\).}
\[R(z) \approx L_1 s^2 + L_2 s^4 + L_3 s^6 + L_4 s^8 + L_5 s^{10} +
L_6 s^{12} + L_7 s^{14}\]
with maximum error \(2^{-58.45}\).
\subsection{Combining}
Finally
\begin{align*}
\log(x) &= k \log(2) + \log(1 + f) \\
&= k H + \left(f - \left(\frac{f^2}{2} - \left(s \left(
\frac{f^2}{2} + R\right) + k L\right)\right)\right)
\end{align*}
where
\[\log(2) = H + L\]
splits as high and low parts
\begin{align*}
H &= 2^{-1} \cdot \mathtt{1.62e42fee00000}_{16} \\
L &= 2^{-1} \cdot \mathtt{0.00000001a39ef35793c76}_{16} =
2^{-33} \cdot \mathtt{1.a39ef35793c76}_{16} \\
\log(2) &= 2^{-1} \cdot \mathtt{1.62e42fefa39ef}_{16}
\end{align*}
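such that the product \(nH\) is always exact for
\(\left|n\right| < 2000\): the last five hexadecimal digits of \(H\)
are zero, so \(H\) has at most \(33\) significant bits, and multiplying
by an integer with \(\left|n\right| < 2000 < 2^{11}\) needs at most
\(44\) bits, which fits in the \(53\)-bit significand of a double.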
\section{Reducing Value}
First note that
\[2^{-1} < 2^{-\frac{1}{2}} < 2^0 < 2^{\frac{1}{2}} < 2^1
\quad \text{i.e.} \quad
\frac{1}{2} < \frac{\sqrt{2}}{2} < 1 < \sqrt{2} < 2.\]
In Golang the \href{https://golang.org/src/math/frexp.go}{function}
\verb+math.Frexp+ takes a float \(x\) and returns an exponent
\(e\) and fractional value \(m\) such that
\[x = 2^e m, \qquad 2\left|m\right| \in \left[1, 2\right).\]
Since \(\log(x)\) is only defined for \(x > 0\), we know \(m > 0\).
If \(\frac{1}{\sqrt{2}} < m < 1\) then we set \(k = e\) and
\(1 + f = m\). If \(\frac{1}{2} \leq m \leq \frac{1}{\sqrt{2}}\)
then we set \(1 + f = 2m \in \left[1, \sqrt{2}\right]\) and
\(k = e - 1\) since \(2^e m = 2^{e - 1}(2m)\). This would appear
to only give
\[\frac{1}{\sqrt{2}} < 1 + f \leq \sqrt{2}\]
but we also know \(1 + f \neq \sqrt{2}\) since an irrational
can't be represented in floating point.
\section{From \texorpdfstring{$f$}{f} to \texorpdfstring{$s$}{s}}
We write \(1 + f = \displaystyle\frac{1 + s}{1 - s}\) so that
\begin{align*}
\log(1 + f) &= \log(1 + s) - \log(1 - s) \\
&= 2s + \frac{2}{3} s^3 + \frac{2}{5} s^5 + \cdots
\end{align*}
Solving for \(s\) in the above yields
\(s = \displaystyle\frac{f}{2 + f} = 1 - \frac{2}{2 + f}\) and
\[\frac{1}{\sqrt{2}} < 1 + f < \sqrt{2} \Longrightarrow
-\left(3 - 2\sqrt{2}\right) < s < 3 - 2 \sqrt{2} \approx
0.171573 < 0.1716.\]
\subsection{Why the bounds?}
The bound \(\frac{1}{\sqrt{2}} < 1 + f < \sqrt{2}\) was just given but
never justified. In fact, this interval is required to be able to
use an approximation of \(R(s)\). First, in
order to factor \(x = 2^k (1 + f)\) uniquely, if \(1 + f \in
\left[\alpha, \beta\right]\) we must have
\[\log_2 \beta = \log_2 \alpha + 1 \Longleftrightarrow
\beta = 2 \alpha.\]
Second, since defining \(R(s)\) requires both \(\log(1 \pm s)\)
to be defined, we need \(-1 < s < 1\).
Third, the symmetry \(R(s) = R(-s)\) means we limit to some
\(-A < s < A\) (for \(0 < A \leq 1\)). This in turn means that
\[\frac{1}{B} < 1 + f < B, \qquad B =
\frac{1 + A}{1 - A} \Longleftrightarrow A = \frac{B - 1}{B + 1}.\]
Hence we put \(1 + f \in \left[B^{-1}, B\right]\) which
forces
\[\log_2 B = \log_2\left(B^{-1}\right) + 1
\Longrightarrow 2 \log_2 B = 1 \Longrightarrow B =
2^{\frac{1}{2}} = \sqrt{2} \Longrightarrow A = 3 - 2 \sqrt{2}.\]
\section{Remez Algorithm}
We seek to approximate
\[R(s) = \frac{2s^2}{3} + \frac{2s^4}{5} + \frac{2s^6}{7} + \cdots =
\frac{\log(1 + s) - \log(1 - s)}{s} - 2, \quad
R(0) = 0\]
with a degree \(14\) polynomial that gives equi-oscillating errors.
By construction \(R(s) = R(-s)\), hence we need to be able to approximate
\(R(s)\) for \(s \in \left[0, 3 - 2 \sqrt{2}\right] \subset
\left[0, 0.1716\right]\).
We manually force an equi-oscillating error at node points
\(s_0, s_1, \ldots, s_7\) by setting
\[R(s_j) = L_1 s_j^2 + L_2 s_j^4 + \cdots + L_7 s_j^{14} + (-1)^j E
= P(s_j) \pm E\]
and solving for the \(7\) unknown coefficients and the error
\(E\). This gives a system
\[\left[\begin{array}{c c c c}
s_0^2 & \cdots & s_0^{14} & 1 \\
s_1^2 & \cdots & s_1^{14} & -1 \\
\vdots & \multicolumn{2}{c}{\ddots} & \vdots \\
s_7^2 & \cdots & s_7^{14} & -1
\end{array}\right]
\left[\begin{array}{c}
L_1 \\ \vdots \\ L_7 \\ E
\end{array}\right] =
\left[\begin{array}{c}
R(s_0) \\ R(s_1) \\ \vdots \\ R(s_7)
\end{array}\right].\]
After solving, we exchange
\(s_0, s_1, \ldots, s_7\) for new points
which maximize \(\left|R(s) - P(s)\right|\) locally.
The method terminates once the exchange process results
in no change at all (or a minimal change).
An alternative approach only swaps a single \(s_j\)
at a time. It finds the absolute extreme
\[s^{\ast} = \underset{{s \in \left[0, 3 - 2 \sqrt{2}\right]}}{
\operatorname{argmax}}
\left|R(s) - P(s)\right|\]
and then swaps \(s^{\ast}\) with the nearest value among the
\(\left\{s_j\right\}\) and only terminates once the absolute
extreme occurs among the \(\left\{s_j\right\}\).
In either situation, if there is no exchange left to do
\[\max_{s \in \left[0, 3 - 2 \sqrt{2}\right]}
\left|R(s) - P(s)\right| = \left|R\left(s^{\ast}\right) -
P\left(s^{\ast}\right)\right| = \left|\pm E\right| = \left|E\right|\]
and this equi-oscillating error will occur a maximal number of
times in our interval.
\end{document}
Difference between successive x-vectors: 2.78724e-09
Completed in 7 steps
Coefficients:
- 0x1.5555555555593p-1
0x1.5555555555593p-1
- 0x1.999999997f9f8p-2
0x1.999999997f...p-2
- 0x1.249249422a440p-2
0x1.249249422....p-2
- 0x1.c71c51d7cf382p-3
0x1.c71c51d......p-3
- 0x1.746649afb0e69p-3
0x1.746649.......p-3
- 0x1.39a095848f9a5p-3
0x1.39a09........p-3
- 0x1.2f117fc8e24c3p-3
0x1.2f11.........p-3
Saved remez_approx.png
from __future__ import print_function
import matplotlib.pyplot as plt
import mpmath
import seaborn
# NOTE: Created via: seaborn.husl_palette(6)[4]
HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
# Reference coefficients L1..L7 (as double hex strings); main() prints the
# computed Remez coefficients next to these for comparison.
EXPECTED_COEFFS = [
    '0x1.5555555555593p-1',
    '0x1.999999997fa04p-2',
    '0x1.2492494229359p-2',
    '0x1.c71c51d8e78afp-3',
    '0x1.7466496cb03dep-3',
    '0x1.39a09d078c69fp-3',
    '0x1.2f112df3e5244p-3',
]
# Powers of s multiplying L1..L7 in the polynomial P(s).
EXPONENTS = (2, 4, 6, 8, 10, 12, 14)
# Number of grid points used when searching for extrema of R(s) - P(s).
SIZE_INTERVAL = 512
# Number of equi-oscillation nodes: 7 unknown coefficients plus the error E.
NUM_POINTS = 8
# Upper bound on Remez iterations before giving up.
MAX_STEPS = 20
CTX = mpmath.MPContext()
CTX.prec = 200 # Bits vs. default of 53
# NOTE: Slightly larger than 3 - 2 * mpmath.sqrt(2)
# so that we can capture an 8th equi-oscillating
# point.
MAX_X = CTX.mpf('0.1717')
# Default sample grid on [0, MAX_X] used to locate extrema.
INTERVAL_SAMPLE = CTX.linspace(0, MAX_X, SIZE_INTERVAL)
def exp_by_squaring(val, n):
    """Exponentiate by squaring.

    Computes ``val ** n`` with O(log n) multiplications. Only ``*`` is
    used, so this works for ints, floats and ``mpmath.mpf`` values alike.

    :type val: float
    :param val: A value to exponentiate.
    :type n: int
    :param n: The (non-negative) exponent.
    :rtype: float
    :returns: The exponentiated value (``1`` when ``n == 0``).
    :raises ValueError: If ``n`` is negative.
    """
    # A negative n would never terminate: divmod(-1, 2) == (-1, 1).
    if n < 0:
        raise ValueError('Exponent must be non-negative.', n)
    result = 1
    pow_val = val
    while n != 0:
        n, remainder = divmod(n, 2)
        if remainder == 1:
            result *= pow_val
        if n != 0:
            # Square only while more bits remain; the last square is unused.
            pow_val = pow_val * pow_val
    return result
def R_scalar(s):
    """Compute R(s) in high precision.

    R(s) is the function satisfying
        log(1 + s) - log(1 - s) = s(2 + R(s)).
    All arithmetic is done with ``mpmath`` in the ``CTX`` context.

    :type s: float
    :param s: The point at which to evaluate ``R(s)``.
    :rtype: float
    :returns: The value of ``R(s)``.
    """
    if s != 0.0:
        log_ratio = CTX.log(1 + s) - CTX.log(1 - s)
        return log_ratio / s - CTX.mpf(2.0)
    # Division by 0 is impossible here, but the series
    # R(s) = 2 s^2 / 3 + 2 s^4 / 5 + ... shows that R(0) = 0.
    return CTX.mpf(0.0)
def update_remez_poly(x_vals):
    """Solve for the Remez polynomial coefficients and the levelled error.

    Row ``j`` of the ``NUM_POINTS x NUM_POINTS`` system is
        sum_c x_j ** EXPONENTS[c] * coeff_c + (-1) ** j * E = R(x_j),
    so the solution forces the error to alternate in sign at the nodes.

    :type x_vals: list
    :param x_vals: The ``NUM_POINTS`` node ``x``-values where the error
                   should equi-oscillate.
    :rtype: tuple
    :returns: ``(coefficients, E)``: the polynomial coefficients as a list,
              and the scalar equi-oscillating error.
    """
    sign_col = NUM_POINTS - 1
    coeff_system = CTX.matrix(NUM_POINTS, NUM_POINTS)
    rhs = CTX.matrix(NUM_POINTS, 1)
    for row in range(NUM_POINTS):
        x_val = x_vals[row]
        # Leading columns: the powers x^2, x^4, ..., x^14.
        for col in range(sign_col):
            coeff_system[row, col] = exp_by_squaring(x_val, EXPONENTS[col])
        # Last column: the alternating sign multiplying E.
        coeff_system[row, sign_col] = (-1)**row
        rhs[row, 0] = R_scalar(x_val)
    solution = CTX.lu_solve(coeff_system, rhs, real=True)
    # lu_solve yields a column vector; its transpose has a single row.
    values = solution.T.tolist()[0]
    return values[:-1], values[-1]
def get_chebyshev_points(num_points):
    """Return ``num_points`` Chebyshev nodes mapped onto ``[0, MAX_X]``.

    :type num_points: int
    :param num_points: The number of nodes to produce.
    :rtype: list
    :returns: The nodes, in increasing order.
    """
    denominator = CTX.mpf(2.0 * num_points)
    half_width = 0.5 * MAX_X
    # Angles pi * (2n-1) / (2n), ..., pi / (2n) decrease, so their cosines
    # (and hence the mapped nodes) increase.
    return [half_width * (1 + CTX.cos(CTX.pi * index / denominator))
            for index in range(2 * num_points - 1, 0, -2)]
class ErrorFunc(object):
    """The signed error R(s) - P(s) for a Remez polynomial P.

    The coefficients of ``P(s)`` are computed from the given node
    ``x``-values when the object is constructed.

    :type x_vals: list
    :param x_vals: A 1D array. The ``x``-values where
                   equi-oscillating occurs.
    """

    def __init__(self, x_vals):
        self.x_vals = x_vals
        # Derived from the nodes: the 7 coefficients and the levelled error.
        coeffs, level_error = update_remez_poly(x_vals)
        self.poly_coeffs = coeffs
        self.E = level_error

    def poly_approx_scalar(self, value):
        """Evaluate the polynomial :math:`f(x)` at ``value``.

        The operation order matches ``math/log.go`` for
        .. math::
            L_1 x^2 + L_2 x^4 + \\cdots + L_7 x^{14}

        :type value: float
        :param value: The point at which to evaluate ``f(x)``.
        :rtype: float
        :returns: The value of ``f(x)``.
        """
        L1, L2, L3, L4, L5, L6, L7 = [self.poly_coeffs[i] for i in range(7)]
        x_sq = value * value
        x_quad = x_sq * x_sq
        odd_chain = x_sq * (L1 + x_quad * (L3 + x_quad * (L5 + x_quad * L7)))
        even_chain = x_quad * (L2 + x_quad * (L4 + x_quad * L6))
        return odd_chain + even_chain

    def signed_error_scalar(self, value):
        """Return R(s) - P(s) at a single ``x``-value.

        :type value: float
        :param value: An ``x``-value.
        :rtype: float
        :returns: The signed error at ``value``.
        """
        exact = R_scalar(value)
        return exact - self.poly_approx_scalar(value)

    def signed_error_diff(self, value):
        """Return the numerical derivative R'(s) - P'(s) via ``CTX.diff``.

        :type value: float
        :param value: An ``x``-value.
        :rtype: float
        :returns: The derivative of the signed error at ``value``.
        """
        return CTX.diff(self.signed_error_scalar, value)

    def signed_error(self, values):
        """Return R(s) - P(s) for each point in ``values``.

        :type values: list
        :param values: A list of ``x``-values.
        :rtype: list
        :returns: The signed error at each point, in order.
        """
        return [self.signed_error_scalar(val) for val in values]
def locate_abs_max(values):
    """Return the index of the entry with the largest absolute value.

    Ties resolve to the earliest index. An empty input yields ``-1``.

    :type values: list
    :param values: A list of scalar values.
    :rtype: int
    :returns: The index where the maximum occurs.
    """
    best_index = -1
    best_abs = -CTX.inf
    for position, candidate in enumerate(values):
        magnitude = abs(candidate)
        if not magnitude > best_abs:
            continue
        best_abs, best_index = magnitude, position
    return best_index
def get_peaks(x_data, y_data, num_peaks):
    """Find the ``x``-locations of the largest humps in oscillating data.

    The largest remaining ``|y|`` is picked each time. Then the whole
    same-sign run around it (its hump) is zeroed so that it cannot be
    picked again.

    :type x_data: list
    :param x_data: The ``x``-values where the outputs occur.
    :type y_data: list
    :param y_data: The oscillating output data.
    :type num_peaks: int
    :param num_peaks: The number of peaks to locate.
    :rtype: list
    :returns: The ``x``-locations of the peaks, in the order found.
    :raises ValueError: If the same ``x``-location is selected twice.
    """
    remaining = y_data[:]
    num_samples = len(y_data)

    def _clear_run(start, step, sign):
        # Zero consecutive entries from ``start`` (moving by ``step``) that
        # have ``sign``, i.e. the rest of the same hump.
        position = start
        while (0 <= position < num_samples and
               CTX.sign(remaining[position]) == sign):
            remaining[position] = 0.0
            position += step

    peak_locations = []
    while len(peak_locations) < num_peaks:
        peak_index = locate_abs_max(remaining)
        x_peak = x_data[peak_index]
        if x_peak in peak_locations:
            raise ValueError('Repeat value found.')
        peak_locations.append(x_peak)
        # Points on the same hump share the sign of the peak value.
        peak_sign = CTX.sign(remaining[peak_index])
        remaining[peak_index] = 0.0
        _clear_run(peak_index + 1, 1, peak_sign)   # right of the peak
        _clear_run(peak_index - 1, -1, peak_sign)  # left of the peak
    return peak_locations
def get_new_x_vals(x_vals, sample_points=INTERVAL_SAMPLE):
    """Perform a single pass of the Remez exchange algorithm.

    First computes the coefficients based on ``x_vals``, then locates
    the extrema of ``|P(s) - R(s)|`` on ``sample_points``. Each extremum
    except the last is refined to a nearby critical point.

    :type x_vals: list
    :param x_vals: The ``x``-values where equi-oscillating occurs.
    :type sample_points: list
    :param sample_points: (Optional) The points we choose extrema from.
    :rtype: list
    :returns: The new ``x``-values, ordered by their grid positions.
    """
    err_func = ErrorFunc(x_vals)
    approx_outputs = err_func.signed_error(sample_points)
    max_vals = sorted(get_peaks(sample_points, approx_outputs, NUM_POINTS))
    # Move from the fixed grid (given by ``sample_points``) onto the entire
    # interval by finding critical points nearby. We **don't** do this
    # for the biggest max value since the function has no critical point on
    # the outside.
    new_vals = [CTX.findroot(err_func.signed_error_diff, val)
                for val in max_vals[:-1]]
    new_vals.append(max_vals[-1])
    # NOTE: We could / should check that new_vals == sorted(new_vals)
    # and that it has no repeats.
    return new_vals
def plot_x_vals(x_vals, sample_points, filename=None):
    """Plot the error R(s) - P(s) for the polynomial defined by ``x_vals``.

    The node ``x``-values are drawn as black dots on the error curve.

    :type x_vals: list
    :param x_vals: The ``x``-values where equi-oscillating occurs.
    :type sample_points: list
    :param sample_points: The points at which the error curve is drawn.
    :type filename: str
    :param filename: (Optional) File to save the plot to. If omitted, the
                     plot is shown interactively.
    """
    err_func = ErrorFunc(x_vals)
    plt.plot(sample_points, err_func.signed_error(sample_points),
             color=HUSL_BLUE)
    plt.plot(x_vals, err_func.signed_error(x_vals), marker='o',
             color='black', linestyle='None')
    if filename is not None:
        plt.savefig(filename, bbox_inches='tight')
        print('Saved ' + filename)
    else:
        plt.show()
def _list_delta_norm(list1, list2):
    """Return the Euclidean norm of the elementwise difference of two lists."""
    difference = CTX.matrix(list1) - CTX.matrix(list2)
    return CTX.norm(difference, p=2)
def _lead_substr_match(val1, val2):
min_len = min(len(val1), len(val2))
for index in range(1, min_len + 1):
if val1[:index] != val2[:index]:
return index - 1
return min_len
def _print_double_hex_compare(actual_mpf, expected):
    """Print a coefficient (rounded to a double) above the expected hex value.

    The second line repeats the prefix of ``expected`` that matches the
    computed value and replaces the remaining mantissa digits with dots.
    """
    # Convert back to double, then to hex.
    actual_as_hex = float(str(actual_mpf)).hex()
    print('- ' + actual_as_hex)
    mantissa, exponent = expected.split('p', 1)
    common = _lead_substr_match(actual_as_hex, mantissa)
    dots = '.' * (len(mantissa) - common)
    print(' ' + mantissa[:common] + dots + 'p' + exponent)
def main(sample_points=INTERVAL_SAMPLE,
         threshold=CTX.mpf(2)**(-26), plot_all=False):
    """Iterate the Remez exchange until the nodes stop moving.

    Starts from Chebyshev nodes. It stops when successive node vectors
    differ by at most ``threshold`` or after ``MAX_STEPS`` passes. It then
    prints the resulting coefficients next to ``EXPECTED_COEFFS`` and saves
    a plot to ``remez_approx.png``.

    :type sample_points: list
    :param sample_points: (Optional) The points we choose extrema from.
    :type threshold: float
    :param threshold: (Optional) Convergence tolerance on the 2-norm of the
                      change in ``x``-values between passes. Defaults to
                      ``2^{-26}``.
    :type plot_all: bool
    :param plot_all: (Optional) Flag indicating all updates should be
                     plotted. Defaults to :data:`False`.
    """
    # Infinite "previous" nodes guarantee at least one pass.
    prev_x_vals = [CTX.inf] * NUM_POINTS
    x_vals = get_chebyshev_points(NUM_POINTS)
    num_steps = 0
    while _list_delta_norm(x_vals, prev_x_vals) > threshold:
        if num_steps >= MAX_STEPS:
            print('Max. steps encountered. Does not converge.')
            break
        if plot_all:
            plot_x_vals(x_vals, sample_points)
        prev_x_vals, x_vals = x_vals, get_new_x_vals(
            x_vals, sample_points=sample_points)
        num_steps += 1

    norm_update = _list_delta_norm(x_vals, prev_x_vals)
    print('Difference between successive x-vectors: %g' % (norm_update,))
    print('Completed in %d steps' % (num_steps,))
    err_func = ErrorFunc(x_vals)
    print('Coefficients:')
    for position, coeff in enumerate(err_func.poly_coeffs):
        _print_double_hex_compare(coeff, EXPECTED_COEFFS[position])
    plot_x_vals(x_vals, sample_points, filename='remez_approx.png')


if __name__ == '__main__':
    main()
"""Simple sanity check that mpmath.mpf() -> float works.
Checks the relative error in extended precision of the simple
process
mpmath.mpf('a.bcdef...') --> 'a.bcdef...' --> float('a.bcdef...')
and gives an idea of how correct the output is.
This check was inspired by a bug that I wrote which used an mpf in
the numerator and a standard Python float in the denominator,
causing unexpected results.
"""
from __future__ import print_function
import matplotlib.pyplot as plt
import mpmath
import numpy as np
import seaborn
# NOTE: Created via: seaborn.husl_palette(6)[4]
HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
# For a normal double that is a power of two, val == np.spacing(val) * 2**52.
ULP_SIZE = 2.0**52 # NOTE: This assumes a 52-bit mantissa.
# High-precision reference context for the rounding check.
CTX = mpmath.MPContext()
CTX.prec = 200 # Bits vs. default of 53
def prev_float(val):
    """Gets the previous float before val in IEEE 754 floating point.

    Steps exactly one ULP toward negative infinity, like
    ``np.nextafter(val, -inf)``. Correct across the normal/subnormal
    boundary, for ``-0.0`` and for infinities.

    >>> prev_float(float.fromhex('-0x1.fffffffffffffp+0')).hex()
    '-0x1.0000000000000p+1'
    >>> prev_float(float.fromhex('-0x1.0000000000000p+0')).hex()
    '-0x1.0000000000001p+0'
    >>> prev_float(float.fromhex('-0x1.626262626262fp-1')).hex()
    '-0x1.6262626262630p-1'
    >>> prev_float(float.fromhex('-0x0.0000000000001p-1022')).hex()
    '-0x0.0000000000002p-1022'
    >>> prev_float(float.fromhex('0x0.0p+0')).hex()
    '-0x0.0000000000001p-1022'
    >>> prev_float(float.fromhex('0x0.0000000000001p-1022')).hex()
    '0x0.0p+0'
    >>> prev_float(float.fromhex('0x1.0000000000000p+0')).hex()
    '0x1.fffffffffffffp-1'
    >>> prev_float(float.fromhex('0x1.fffffffffffffp+0')).hex()
    '0x1.ffffffffffffep+0'
    >>> prev_float(float.fromhex('0x1.bbbbbbbbbbbbap+102')).hex()
    '0x1.bbbbbbbbbbbb9p+102'
    >>> prev_float(float.fromhex('0x1.0000000000000p-1022')).hex()
    '0x0.fffffffffffffp-1022'
    >>> prev_float(-0.0).hex()
    '-0x0.0000000000001p-1022'
    """
    # The former spacing arithmetic halved the step at every power of two.
    # That is wrong at 2**-1022, whose predecessor is a subnormal one full
    # spacing below, and it also mishandled -0.0 and inf. nextafter has
    # none of these problems.
    return float(np.nextafter(val, -np.inf))
def next_float(val):
    """Gets the next float after val in IEEE 754 floating point.

    Steps exactly one ULP toward positive infinity, like
    ``np.nextafter(val, inf)``. Correct for ``-0.0`` (whose successor is
    the smallest positive subnormal) and for infinities.

    >>> next_float(float.fromhex('-0x1.fffffffffffffp+0')).hex()
    '-0x1.ffffffffffffep+0'
    >>> next_float(float.fromhex('-0x1.0000000000000p+0')).hex()
    '-0x1.fffffffffffffp-1'
    >>> next_float(float.fromhex('-0x1.626262626262fp-1')).hex()
    '-0x1.626262626262ep-1'
    >>> next_float(float.fromhex('-0x0.0000000000001p-1022')).hex()
    '-0x0.0p+0'
    >>> next_float(float.fromhex('0x0.0p+0')).hex()
    '0x0.0000000000001p-1022'
    >>> next_float(float.fromhex('0x0.0000000000001p-1022')).hex()
    '0x0.0000000000002p-1022'
    >>> next_float(float.fromhex('0x1.0000000000000p+0')).hex()
    '0x1.0000000000001p+0'
    >>> next_float(float.fromhex('0x1.fffffffffffffp+0')).hex()
    '0x1.0000000000000p+1'
    >>> next_float(float.fromhex('0x1.bbbbbbbbbbbbap+102')).hex()
    '0x1.bbbbbbbbbbbbbp+102'
    >>> next_float(-0.0).hex()
    '0x0.0000000000001p-1022'
    """
    # ``val + np.spacing(val)`` gave -5e-324 for -0.0 (np.spacing(-0.0) is
    # negative) and NaN for inf; nextafter handles both.
    return float(np.nextafter(val, np.inf))
def main():
    """Check that ``float(str(mpf))`` rounds to the nearest double.

    For ``2**14`` points in ``[0, 1000]``, records the relative error of the
    conversion. It raises ValueError when the chosen double is not the
    nearer of its two neighbours, or when the exact value is not within one
    ULP of it. Finally it saves the relative errors as a plot.
    """
    num_points = 2**14
    mpmath_points = CTX.linspace(0, 1000, num_points)
    relative_errs = []
    for mp_value in mpmath_points:
        float_val = float(str(mp_value))
        prev_val = prev_float(float_val)
        next_val = next_float(float_val)
        if mp_value == float_val:
            # Exactly representable: nothing to check.
            relative_errs.append(0.0)
            continue
        relative_errs.append((float_val - mp_value) / mp_value)
        details = (mp_value, prev_val, float_val, next_val)
        if prev_val < mp_value < float_val:
            # float_val should be closer than prev_val
            # <==> 0 < float_val - mp_value < mp_value - prev_val
            gap_above = CTX.mpf(float_val) - mp_value
            gap_below = mp_value - CTX.mpf(prev_val)
            if gap_above >= gap_below:
                raise ValueError('Rounded above instead of below', *details)
        elif float_val < mp_value < next_val:
            # float_val should be closer than next_val
            # <==> 0 < mp_value - float_val < next_val - mp_value
            gap_below = mp_value - CTX.mpf(float_val)
            gap_above = CTX.mpf(next_val) - mp_value
            if gap_below >= gap_above:
                raise ValueError('Rounded below instead of above', *details)
        else:
            raise ValueError('Rounded outside of 2-ULP range', *details)
    plt.plot(mpmath_points, relative_errs, color=HUSL_BLUE)
    filename = 'sanity_check.png'
    plt.savefig(filename, bbox_inches='tight')
    print('Saved ' + filename)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment