Skip to content

Instantly share code, notes, and snippets.

@bjodah
Forked from dhermes/.gitignore
Created December 12, 2022 06:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bjodah/2134e072e66916c74955268e245a0b8e to your computer and use it in GitHub Desktop.
Save bjodah/2134e072e66916c74955268e245a0b8e to your computer and use it in GitHub Desktop.
Remez Algorithm for log(x)
*.pyc
*.aux
*.log
*.out
from __future__ import print_function
import matplotlib.pyplot as plt
import mpmath
import numpy as np
import seaborn
# NOTE: Created via: seaborn.husl_palette(6)[4]
HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
# log(2) split into a high part and a low correction, log(2) ~= LOG2_HI + LOG2_LOW
# (the H and L of the accompanying write-up). LOG2_HI has trailing zero bits so
# that k * LOG2_HI is exact for the exponents k that occur.
LOG2_HI = float.fromhex('0x1.62e42fee00000p-1')
LOG2_LOW = float.fromhex('0x1.a39ef35793c76p-33')
# Coefficients of the polynomial approximation
#   R(s) ~= L1*s^2 + L2*s^4 + ... + L7*s^14
# as exact hex doubles, matching Go's math/log.go.
L1 = float.fromhex('0x1.5555555555593p-1')
L2 = float.fromhex('0x1.999999997fa04p-2')
L3 = float.fromhex('0x1.2492494229359p-2')
L4 = float.fromhex('0x1.c71c51d8e78afp-3')
L5 = float.fromhex('0x1.7466496cb03dep-3')
L6 = float.fromhex('0x1.39a09d078c69fp-3')
L7 = float.fromhex('0x1.2f112df3e5244p-3')
# sqrt(2) / 2 rounded to a double; frexp mantissas below it are doubled
# during argument reduction in log_ieee754().
SQRT2_HALF = float.fromhex('0x1.6a09e667f3bcdp-1')
# High-precision mpmath context used for the reference log values.
CTX = mpmath.MPContext()
CTX.prec = 200 # Bits vs. default of 53
def log_mpf(x):
    """Return ``log(x)`` computed in the high-precision context ``CTX``."""
    high_precision_x = CTX.mpf(x)
    return CTX.log(high_precision_x)
def log_ieee754(x):
    """Natural log of ``x`` computed with the same steps as Go's ``math/log.go``.

    1. Argument reduction: ``x = 2^k (1 + f)`` with ``1 + f`` roughly in
       ``[sqrt(2)/2, sqrt(2))``.
    2. ``s = f / (2 + f)``; ``R(s)`` is approximated by the polynomial with
       coefficients ``L1 .. L7`` (odd and even powers of ``s^2`` split).
    3. Recombination using ``log(2) = LOG2_HI + LOG2_LOW``.

    Special values follow ``math/log.go``: ``log(NaN) = NaN``,
    ``log(+inf) = +inf``, ``log(x) = NaN`` for ``x < 0`` and
    ``log(+-0) = -inf``. Without these checks the reduction below would
    return finite garbage (or divide by zero) for such inputs.

    :type x: float
    :param x: The value to take the logarithm of.
    :rtype: float
    :returns: The double-precision approximation of ``log(x)``.
    """
    if np.isnan(x) or x == np.inf:
        return x
    if x < 0:
        return np.nan
    if x == 0:
        return -np.inf
    f1, ki = np.frexp(x)
    if f1 < SQRT2_HALF:
        # Move the mantissa from [1/2, sqrt(2)/2) up to [1, sqrt(2)).
        f1 *= 2
        ki -= 1
    f = f1 - 1
    k = float(ki)
    # 1 + f = (1 + s) / (1 - s), so log(1 + f) = 2s + s * R(s).
    s = f / (2 + f)
    s2 = s * s
    s4 = s2 * s2
    # Terms with odd powers of s^2.
    t1 = s2 * (L1 + s4 * (L3 + s4 * (L5 + s4 * L7)))
    # Terms with even powers of s^2.
    t2 = s4 * (L2 + s4 * (L4 + s4 * L6))
    R = t1 + t2
    hfsq = 0.5 * f * f
    return k * LOG2_HI - ((hfsq - (s * (hfsq + R) + k * LOG2_LOW)) - f)
def main():
    """Plot the relative error of :func:`log_ieee754` on ``(0, e^8]``.

    The reference value comes from :func:`log_mpf` (``CTX.prec`` bits). The
    figure is saved to ``log_high_precision_relative_error.png``.
    """
    num_points = 2**14
    x_vals = np.linspace(0, np.exp(8.0), num_points)
    x_vals = x_vals[1:]  # Eliminate 0.0 since log(0) = -inf.
    log_rel_errors = []
    for x_val in x_vals:
        # Promote the double result explicitly (``float`` -> ``mpf`` is
        # exact) so the difference is formed in CTX precision, rather than
        # relying on how a NumPy scalar and an mpmath number happen to mix.
        log_val = CTX.mpf(float(log_ieee754(x_val)))
        log_hp_val = log_mpf(x_val)
        abs_error = abs(log_val - log_hp_val)
        if log_hp_val == 0:
            # x == 1: a relative error is undefined (division by zero), so
            # record the absolute error instead.
            log_rel_errors.append(abs_error)
        else:
            log_rel_errors.append(abs_error / abs(log_hp_val))
    plt.plot(x_vals, log_rel_errors, color=HUSL_BLUE)
    filename = 'log_high_precision_relative_error.png'
    plt.savefig(filename, bbox_inches='tight')
    print('Saved ' + filename)


if __name__ == '__main__':
    main()
from __future__ import print_function
import itertools
import matplotlib.pyplot as plt
import mpmath
import seaborn
from remez_w_high_precision import get_peaks
# NOTE: Created via: seaborn.husl_palette(6)[4]
HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
# High-precision mpmath context; all quantities below are CTX numbers.
CTX = mpmath.MPContext()
CTX.prec = 200 # Bits vs. default of 53
# The math/log.go coefficients, promoted from exact doubles to CTX numbers so
# that f(s) below is evaluated in high precision.
L1 = CTX.mpf(float.fromhex('0x1.5555555555593p-1'))
L2 = CTX.mpf(float.fromhex('0x1.999999997fa04p-2'))
L3 = CTX.mpf(float.fromhex('0x1.2492494229359p-2'))
L4 = CTX.mpf(float.fromhex('0x1.c71c51d8e78afp-3'))
L5 = CTX.mpf(float.fromhex('0x1.7466496cb03dep-3'))
L6 = CTX.mpf(float.fromhex('0x1.39a09d078c69fp-3'))
L7 = CTX.mpf(float.fromhex('0x1.2f112df3e5244p-3'))
# Largest |s| after argument reduction: s = f / (2 + f) with
# 1/sqrt(2) < 1 + f < sqrt(2) gives |s| < 3 - 2*sqrt(2) ~= 0.171573.
MAX_S = 3 - 2 * CTX.sqrt(2)
# Slightly past MAX_S; main() samples up to here to capture an 8th
# equi-oscillation point.
EXTRA_S = CTX.mpf('0.1717')
def f(s):
    """Polynomial approximation of ``R(s)`` using the math/log.go coefficients.

    The arithmetic is whatever ``s`` and the module-level ``L1`` .. ``L7``
    provide; in this script those are ``CTX`` numbers.
    """
    # (L1 * s**2 + L2 * s**4 + L3 * s**6 + L4 * s**8 +
    # L5 * s**10 + L6 * s**12 + L7 * s**14)
    # Computes f(s) the exact same way is done in math/log.go.
    # NOTE: keep the operation order below unchanged; splitting into odd (t1)
    # and even (t2) powers of s^2 determines the rounding behaviour.
    s2 = s * s
    s4 = s2 * s2
    t1 = s2 * (L1 + s4 * (L3 + s4 * (L5 + s4 * L7)))
    t2 = s4 * (L2 + s4 * (L4 + s4 * L6))
    return t1 + t2
def R(s):
    """Exact remainder ``R(s) = (log(1 + s) - log(1 - s)) / s - 2`` in ``CTX``.

    At ``s == 0`` the limiting value 0 is returned, because the series
    ``2 s^2 / 3 + 2 s^4 / 5 + ...`` vanishes there.
    """
    if s != CTX.mpf(0.0):
        log_ratio = CTX.log(1 + s) - CTX.log(1 - s)
        return log_ratio / s - CTX.mpf(2.0)
    return CTX.mpf(0.0)
def err_func(s):
    """Signed approximation error ``R(s) - f(s)`` of the log.go polynomial."""
    exact = R(s)
    approx = f(s)
    return exact - approx
def err_func_diff(s):
    """Numerical derivative of :func:`err_func` at ``s`` (via ``CTX.diff``)."""
    derivative = CTX.diff(err_func, s)
    return derivative
def main():
    """Plot the error ``R(s) - f(s)`` of the log.go polynomial and its peaks.

    The function:

    - samples the error on ``[0, MAX_S]``, plus a short extension up to
      ``EXTRA_S``;
    - prints the base-2 exponent of the largest sampled ``|R(s) - f(s)|``;
    - locates 8 equi-oscillation peaks and refines all but the last to
      nearby critical points;
    - saves the plot and prints the peak locations as hex doubles.
    """
    num_points = 2**10
    extra_points = 32
    # Extend past the true max to find an 8th point where
    # equi-oscillation occurs.
    s_vals = (CTX.linspace(0, MAX_S, num_points) +
              CTX.linspace(MAX_S, EXTRA_S, extra_points)[1:])
    delta_vals = [err_func(s_val) for s_val in s_vals]
    # NOTE: ``itertools.imap`` only exists on Python 2; a generator
    # expression is lazy on both Python 2 and Python 3.
    exponent = CTX.log(max(abs(delta) for delta in delta_vals), 2)
    print(' Number of points: % d' % (num_points,))
    print(' CTX.prec: % d' % (CTX.prec,))
    print('Exponent of max error: % 2.3f' % (exponent,))
    peaks = sorted(get_peaks(s_vals, delta_vals, 8))
    # Move from the fixed grid (given by ``s_vals``) onto the entire
    # interval by finding critical points nearby. We **don't** do this
    # for the biggest peak x-value since the function has no critical
    # point on the outside.
    new_peaks = [CTX.findroot(err_func_diff, peak) for peak in peaks[:-1]]
    new_peaks.append(peaks[-1])
    # Over-write the peaks.
    peaks = new_peaks
    delta_peaks = [err_func(peak) for peak in peaks]
    plt.plot(s_vals, delta_vals, color=HUSL_BLUE)
    plt.plot(peaks, delta_peaks, color='black', linestyle='None',
             marker='o', markersize=4)
    avg_peak_val = sum(abs(delta) for delta in delta_peaks) / len(delta_peaks)
    plt.plot(s_vals, [avg_peak_val] * len(s_vals),
             linestyle='dashed', color='black')
    plt.plot(s_vals, [-avg_peak_val] * len(s_vals),
             linestyle='dashed', color='black')
    plt.xlabel('$s$')
    plt.title('$R(s) - f(s)$')
    filename = 'remez_equioscillating_error.png'
    plt.gcf().patch.set_alpha(0.0)
    plt.savefig(filename, bbox_inches='tight')
    print('Saved ' + filename)
    print('-' * 40)
    print('Peaks occur at x-values:')
    for peak in peaks:
        # Round to a double through the decimal string representation.
        peak_as_double = float(str(peak))
        print('- ' + peak_as_double.hex())


if __name__ == '__main__':
    main()
Number of points: 1024
CTX.prec: 200
Exponent of max error: -58.472
Saved remez_equioscillating_error.png
----------------------------------------
Peaks occur at x-values:
- 0x1.dda01e9bda24ap-6
- 0x1.05cc9437a79e0p-4
- 0x1.82e6cfe518d80p-4
- 0x1.ef48b9760dac4p-4
- 0x1.23e3b6364b181p-3
- 0x1.44ac2edbbe0cfp-3
- 0x1.58d6936b415e4p-3
- 0x1.5fa43fe5c91d1p-3
from __future__ import print_function
import matplotlib.pyplot as plt
import mpmath
import numpy as np
import seaborn
# NOTE: Created via: seaborn.husl_palette(6)[4]
HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
DESIRED_PRECISION = 200  # Bits vs. default of 53
# Largest |s| after argument reduction: s = f / (2 + f) with
# 1/sqrt(2) < 1 + f < sqrt(2) gives |s| < 3 - 2*sqrt(2).
MAX_S = 3 - 2 * np.sqrt(2)
# The math/log.go coefficients as plain doubles, so f() can be evaluated
# both in double precision and (via f_mpf) in CTX precision.
L1 = float.fromhex('0x1.5555555555593p-1')
L2 = float.fromhex('0x1.999999997fa04p-2')
L3 = float.fromhex('0x1.2492494229359p-2')
L4 = float.fromhex('0x1.c71c51d8e78afp-3')
L5 = float.fromhex('0x1.7466496cb03dep-3')
L6 = float.fromhex('0x1.39a09d078c69fp-3')
L7 = float.fromhex('0x1.2f112df3e5244p-3')
CTX = mpmath.MPContext()
# Use the named constant instead of repeating the literal, so the two can
# no longer drift apart.
CTX.prec = DESIRED_PRECISION
def f(s):
    """Polynomial approximation of ``R(s)`` using the math/log.go coefficients.

    ``L1`` .. ``L7`` are doubles here. Passing a NumPy/Python float evaluates
    in double precision; passing a ``CTX`` number (see ``f_mpf``) evaluates
    in high precision.
    """
    # (L1 * s**2 + L2 * s**4 + L3 * s**6 + L4 * s**8 +
    # L5 * s**10 + L6 * s**12 + L7 * s**14)
    # Computes f(s) the exact same way is done in math/log.go.
    # NOTE: keep the operation order below unchanged; splitting into odd (t1)
    # and even (t2) powers of s^2 determines the double-precision rounding.
    s2 = s * s
    s4 = s2 * s2
    t1 = s2 * (L1 + s4 * (L3 + s4 * (L5 + s4 * L7)))
    t2 = s4 * (L2 + s4 * (L4 + s4 * L6))
    return t1 + t2
def f_mpf(s):
    """High-precision variant of :func:`f`: promote ``s`` to ``CTX`` first."""
    s_high_precision = CTX.mpf(s)
    return f(s_high_precision)
def R_mpf(s):
    """Exact ``R(s) = (log(1 + s) - log(1 - s)) / s - 2`` in ``CTX`` precision.

    ``s`` is promoted to a ``CTX`` number first. ``R(0)`` is the limiting
    value 0.
    """
    s = CTX.mpf(s)
    if s != CTX.mpf(0.0):
        log_ratio = CTX.log(1 + s) - CTX.log(1 - s)
        return log_ratio / s - CTX.mpf(2.0)
    return CTX.mpf(0.0)
def R(s):
    """``R(s)`` rounded to a double by parsing its decimal string form."""
    decimal_repr = str(R_mpf(s))
    return float(decimal_repr)
def main():
    """Compare double-precision and high-precision evaluation of ``R`` and ``f``.

    For ``num_points`` samples of ``s`` in ``[0, MAX_S]``, this collects:

    * ``f^{hp}(s) - f(s)``: rounding error of evaluating ``f`` in doubles,
    * ``R^{hp}(s) - R(s)``: error of rounding ``R`` to a double,
    * ``R(s) - f(s)`` in doubles and ``R^{hp}(s) - f^{hp}(s)`` in high
      precision.

    All three panels share one y-range. The figure is saved to
    ``remez_equioscillating_error_ieee754.png``.
    """
    num_points = 2**10
    s_vals = np.linspace(0, MAX_S, num_points)
    f_error_vals = []
    R_error_vals = []
    approx_error_vals = []
    approx_error_hp_vals = []
    for s_val in s_vals:
        f_val = f(s_val)
        f_hp_val = f_mpf(s_val)  # f_hp == f in high-precision
        f_error_vals.append(f_hp_val - f_val)
        # Evaluate R in high precision only once, then round that value to a
        # double exactly as ``R()`` does (through ``str``). Calling
        # ``R(s_val)`` would recompute ``R_mpf(s_val)`` from scratch.
        R_hp_val = R_mpf(s_val)  # R_hp == R in high-precision
        R_val = float(str(R_hp_val))
        R_error_vals.append(R_hp_val - R_val)
        # Now save the **actual errors**.
        approx_error_vals.append(R_val - f_val)
        approx_error_hp_vals.append(R_hp_val - f_hp_val)
    all_error_vals = (f_error_vals, R_error_vals,
                      approx_error_vals, approx_error_hp_vals)
    max_err = np.max([np.max(vals) for vals in all_error_vals])
    min_err = np.min([np.min(vals) for vals in all_error_vals])
    err_height = max_err - min_err
    err_mid = 0.5 * (max_err + min_err)
    # Shared y-range with 5% padding above and below the extreme errors.
    y_lower = err_mid - 0.55 * err_height
    y_upper = err_mid + 0.55 * err_height
    # H/T: http://stackoverflow.com/a/16542886/1068170
    plt.subplot(221)
    plt.plot(s_vals, f_error_vals, color=HUSL_BLUE)
    plt.ylim(y_lower, y_upper)
    plt.title(r'$f^{hp}(s) - f(s)$')
    plt.subplot(222)
    plt.plot(s_vals, R_error_vals, color=HUSL_BLUE)
    plt.ylim(y_lower, y_upper)
    plt.title(r'$R^{hp}(s) - R(s)$')
    plt.subplot(212)
    plt.plot(s_vals, approx_error_vals, color=HUSL_BLUE)
    plt.plot(s_vals, approx_error_hp_vals,
             linestyle='dashed', linewidth=0.5, color='black')
    plt.ylim(y_lower, y_upper)
    plt.title('$R(s) - f(s)$')
    filename = 'remez_equioscillating_error_ieee754.png'
    curr_fig = plt.gcf()
    width, height = curr_fig.get_size_inches()
    curr_fig.set_size_inches(1.5 * width, height)
    plt.savefig(filename, bbox_inches='tight')
    print('Saved ' + filename)


if __name__ == '__main__':
    main()
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
\documentclass[letterpaper,10pt]{article}
\usepackage[margin=1in]{geometry}
\usepackage{amsthm,amssymb,amsmath,hyperref}
\hypersetup{colorlinks=true,urlcolor=blue,citecolor=blue}
\usepackage{embedfile}
\embedfile{\jobname.tex}
\usepackage{fancyhdr}
\pagestyle{fancy}
\lhead{}
\rhead{Remez Approximation}
\renewcommand{\headrulewidth}{0pt}
\begin{document}
\section{Background}
The Golang \href{https://golang.org/src/math/log.go}{source} for
the \verb+log()+
function\footnote{The Golang source actually mentions that this method and some of the comments were borrowed from original C code, from FreeBSD's \texttt{/usr/src/lib/msun/src/e\_log.c}}
%% /usr/src/lib/msun/src/e_log.c
states that the natural log is computed in three parts:
\subsection{Argument Reduction}
Find \(k\) and \(f\) such that
\[x = 2^k \left(1 + f\right)\]
where \(\frac{\sqrt{2}}{2} < 1 + f < \sqrt{2}\).
\subsection{Approximation of \texorpdfstring{$\log(1+f)$}{log(1+f)}}
Let \(s = \frac{f}{2 + f}\), based on
\begin{align*}
\log(1 + f) &= \log(1 + s) - \log(1 - s) \\
&= 2s + \frac{2}{3} s^3 + \frac{2}{5} s^5 + \cdots \\
&= 2s + s R\left(s\right)
\end{align*}
Use a special
Reme\footnote{misspelling, should be \href{https://en.wikipedia.org/wiki/Remez_algorithm}{Remez}}
algorithm on \(\left[0, 0.1716\right]\) to generate a polynomial of
degree \(14\) to approximate \(R\).
This approximation is given
by\footnote{the \(z\) is likely \(s^2\) which likely means the polynomial is intentionally degree \(7\) in \(z\)}
\[R(z) \approx L_1 s^2 + L_2 s^4 + L_3 s^6 + L_4 s^8 + L_5 s^{10} +
L_6 s^{12} + L_7 s^{14}\]
with maximum error \(2^{-58.45}\).
\subsection{Combining}
Finally
\begin{align*}
\log(x) &= k \log(2) + \log(1 + f) \\
&= k H + \left(f - \left(\frac{f^2}{2} - \left(s \left(
\frac{f^2}{2} + R\right) + k L\right)\right)\right)
\end{align*}
where
\[\log(2) = H + L\]
splits as high and low parts
\begin{align*}
H &= 2^{-1} \cdot \mathtt{1.62e42fee00000}_{16} \\
L &= 2^{-1} \cdot \mathtt{0.00000001a39ef35793c76}_{16} =
2^{-33} \cdot \mathtt{1.a39ef35793c76}_{16} \\
\log(2) &= 2^{-1} \cdot \mathtt{1.62e42fefa39ef}_{16}
\end{align*}
where the high part \(H\) has enough trailing zero bits that the product
\(nH\) is always exact for \(\left|n\right| < 2000\).
\section{Reducing Value}
First note that
\[2^{-1} < 2^{-\frac{1}{2}} < 2^0 < 2^{\frac{1}{2}} < 2^1
\quad \text{i.e.} \quad
\frac{1}{2} < \frac{\sqrt{2}}{2} < 1 < \sqrt{2} < 2.\]
In Golang the \href{https://golang.org/src/math/frexp.go}{function}
\verb+math.Frexp+ takes a float \(x\) and returns an exponent
\(e\) and fractional value \(m\) such that
\[x = 2^e m, \qquad 2\left|m\right| \in \left[1, 2\right).\]
Since \(\log(x)\) is only defined for \(x > 0\), we know \(m > 0\).
If \(\frac{1}{\sqrt{2}} < m < 1\) then we set \(k = e\) and
\(1 + f = m\). If \(\frac{1}{2} \leq m \leq \frac{1}{\sqrt{2}}\)
then we set \(1 + f = 2m \in \left[1, \sqrt{2}\right]\) and
\(k = e - 1\) since \(2^e m = 2^{e - 1}(2m)\). This would appear
to only give
\[\frac{1}{\sqrt{2}} < 1 + f \leq \sqrt{2}\]
but we also know \(1 + f \neq \sqrt{2}\) since an irrational number
cannot be represented exactly in floating point.
\section{From \texorpdfstring{$f$}{f} to \texorpdfstring{$s$}{s}}
We write \(1 + f = \displaystyle\frac{1 + s}{1 - s}\) so that
\begin{align*}
\log(1 + f) &= \log(1 + s) - \log(1 - s) \\
&= 2s + \frac{2}{3} s^3 + \frac{2}{5} s^5 + \cdots
\end{align*}
Solving for \(s\) in the above yields
\(s = \displaystyle\frac{f}{2 + f} = 1 - \frac{2}{2 + f}\) and
\[\frac{1}{\sqrt{2}} < 1 + f < \sqrt{2} \Longrightarrow
-\left(3 - 2\sqrt{2}\right) < s < 3 - 2 \sqrt{2} \approx
0.171573 < 0.1716.\]
\subsection{Why the bounds?}
The bound \(\frac{1}{\sqrt{2}} < 1 + f < \sqrt{2}\) was just given but
never justified. In fact, this interval is required to be able to
use an approximation of \(R(s)\). First, in
order to factor \(x = 2^k (1 + f)\) uniquely, if \(1 + f \in
\left[\alpha, \beta\right]\) we must have
\[\log_2 \beta = \log_2 \alpha + 1 \Longleftrightarrow
\beta = 2 \alpha.\]
Second, since defining \(R(s)\) requires both \(\log(1 \pm s)\)
to be defined, we need \(-1 < s < 1\).
Third, the symmetry \(R(s) = R(-s)\) means we limit to some
\(-A < s < A\) (for \(0 < A \leq 1\)). This in turn means that
\[\frac{1}{B} < 1 + f < B, \qquad B =
\frac{1 + A}{1 - A} \Longleftrightarrow A = \frac{B - 1}{B + 1}.\]
Hence we put \(1 + f \in \left[B^{-1}, B\right]\) which
forces
\[\log_2 B = \log_2\left(B^{-1}\right) + 1
\Longrightarrow 2 \log_2 B = 1 \Longrightarrow B =
2^{\frac{1}{2}} = \sqrt{2} \Longrightarrow A = 3 - 2 \sqrt{2}.\]
\section{Remez Algorithm}
We seek to approximate
\[R(s) = \frac{2s^2}{3} + \frac{2s^4}{5} + \frac{2s^6}{7} + \cdots =
\frac{\log(1 + s) - \log(1 - s)}{s} - 2, \quad
R(0) = 0\]
with a degree \(14\) polynomial that gives equi-oscillating errors.
By construction \(R(s) = R(-s)\), hence we need to be able to approximate
\(R(s)\) for \(s \in \left[0, 3 - 2 \sqrt{2}\right] \subset
\left[0, 0.1716\right]\).
We manually force an equi-oscillating error at node points
\(s_0, s_1, \ldots, s_7\) by setting
\[R(s_j) = L_1 s_j^2 + L_2 s_j^4 + \cdots + L_7 s_j^{14} + (-1)^j E
= P(s_j) \pm E\]
and solving for the \(7\) unknown coefficients and the error
\(E\). This gives a system
\[\left[\begin{array}{c c c c}
s_0^2 & \cdots & s_0^{14} & 1 \\
s_1^2 & \cdots & s_1^{14} & -1 \\
\vdots & \multicolumn{2}{c}{\ddots} & \vdots \\
s_7^2 & \cdots & s_7^{14} & -1
\end{array}\right]
\left[\begin{array}{c}
L_1 \\ \vdots \\ L_7 \\ E
\end{array}\right] =
\left[\begin{array}{c}
R(s_0) \\ R(s_1) \\ \vdots \\ R(s_7)
\end{array}\right].\]
After solving, we exchange
\(s_0, s_1, \ldots, s_7\) for new points
which maximize \(\left|R(s) - P(s)\right|\) locally.
The method terminates once the exchange process results
in no change at all (or a minimal change).
An alternative approach only swaps a single \(s_j\)
at a time. It finds the absolute extremum
\[s^{\ast} = \underset{{s \in \left[0, 3 - 2 \sqrt{2}\right]}}{
\operatorname{argmax}}
\left|R(s) - P(s)\right|\]
then swaps \(s^{\ast}\) with the nearest value among the
\(\left\{s_j\right\}\), and only terminates once the absolute
extremum occurs among the \(\left\{s_j\right\}\).
In either situation, once there is no exchange left to do,
\[\max_{s \in \left[0, 3 - 2 \sqrt{2}\right]}
\left|R(s) - P(s)\right| = \left|R\left(s^{\ast}\right) -
P\left(s^{\ast}\right)\right| = \left|\pm E\right| = \left|E\right|\]
and this equi-oscillating error will occur a maximal number of
times in our interval.
\end{document}
Difference between successive x-vectors: 2.78724e-09
Completed in 7 steps
Coefficients:
- 0x1.5555555555593p-1
0x1.5555555555593p-1
- 0x1.999999997f9f8p-2
0x1.999999997f...p-2
- 0x1.249249422a440p-2
0x1.249249422....p-2
- 0x1.c71c51d7cf382p-3
0x1.c71c51d......p-3
- 0x1.746649afb0e69p-3
0x1.746649.......p-3
- 0x1.39a095848f9a5p-3
0x1.39a09........p-3
- 0x1.2f117fc8e24c3p-3
0x1.2f11.........p-3
Saved remez_approx.png
from __future__ import print_function
import matplotlib.pyplot as plt
import mpmath
import seaborn
# NOTE: Created via: seaborn.husl_palette(6)[4]
HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
# Coefficients L1..L7 as written in Go's math/log.go; main() prints how many
# leading hex digits of the computed coefficients agree with these.
EXPECTED_COEFFS = [
    '0x1.5555555555593p-1',
    '0x1.999999997fa04p-2',
    '0x1.2492494229359p-2',
    '0x1.c71c51d8e78afp-3',
    '0x1.7466496cb03dep-3',
    '0x1.39a09d078c69fp-3',
    '0x1.2f112df3e5244p-3',
]
# Powers of s for the polynomial columns of the Remez system (only even
# powers, since R(s) = R(-s)).
EXPONENTS = (2, 4, 6, 8, 10, 12, 14)
# Number of grid points used to locate extrema of the error.
SIZE_INTERVAL = 512
# Number of equi-oscillation nodes: 7 coefficients plus the error E.
NUM_POINTS = 8
# The Remez exchange gives up after this many steps.
MAX_STEPS = 20
CTX = mpmath.MPContext()
CTX.prec = 200 # Bits vs. default of 53
# NOTE: Slightly larger than 3 - 2 * mpmath.sqrt(2)
# so that we can capture an 8th equi-oscillating
# point.
MAX_X = CTX.mpf('0.1717')
# Default fixed sampling grid on [0, MAX_X] from which extrema are picked.
INTERVAL_SAMPLE = CTX.linspace(0, MAX_X, SIZE_INTERVAL)
def exp_by_squaring(val, n):
    """Exponentiate by squaring.

    Computes ``val ** n`` with ``O(log n)`` multiplications by scanning the
    binary digits of ``n`` from least to most significant.

    :type val: float
    :param val: A value to exponentiate.
    :type n: int
    :param n: The non-negative exponent.
    :rtype: float
    :returns: The exponentiated value (the integer ``1`` when ``n == 0``).
    :raises ValueError: If ``n`` is negative.
    """
    if n < 0:
        # divmod(-1, 2) == (-1, 1), so the loop below would never terminate.
        raise ValueError('Exponent must be non-negative', n)
    result = 1
    pow_val = val
    while n != 0:
        n, remainder = divmod(n, 2)
        if remainder == 1:
            result *= pow_val
        pow_val = pow_val * pow_val
    return result
def R_scalar(s):
    """Compute the value of R(s).

    R(s) is the function satisfying
        log(1 + s) - log(1 - s) = s(2 + R(s))
    and is evaluated in high precision with ``mpmath`` (the ``CTX`` context).

    :type s: float
    :param s: A scalar to evaluate ``R(s)`` at.
    :rtype: float
    :returns: The value of ``R(s)`` at our input value.
    """
    if s != 0.0:
        log_diff = CTX.log(1 + s) - CTX.log(1 - s)
        return log_diff / s - CTX.mpf(2.0)
    # Removable singularity: 2 * 0^2 / 3 + 2 * 0^4 / 5 + ... = 0, and we
    # cannot divide by zero.
    return CTX.mpf(0.0)
def update_remez_poly(x_vals):
    """Solve the Remez linear system for the given equi-oscillation nodes.

    Row ``j`` of the system is
    ``sum_i c_i * x_j ** EXPONENTS[i] + (-1)**j * E = R(x_j)``.

    :type x_vals: list
    :param x_vals: The ``x``-values where equi-oscillating occurs.
                   We assume there are ``NUM_POINTS`` values.
    :rtype: tuple
    :returns: A pair, the first is the coefficients (as a list) and the second
              is a scalar (the equi-oscillating error).
    """
    size = NUM_POINTS
    sign_col = size - 1
    coeff_system = CTX.matrix(size, size)
    rhs = CTX.matrix(size, 1)
    for row in range(size):
        node = x_vals[row]
        # One column per power of the node...
        for col in range(sign_col):
            coeff_system[row, col] = exp_by_squaring(node, EXPONENTS[col])
        # ...and a final column holding the alternating sign multiplying E.
        coeff_system[row, sign_col] = (-1)**row
        rhs[row, 0] = R_scalar(node)
    solution = CTX.lu_solve(coeff_system, rhs, real=True)
    # Transpose the column vector into one row, then unpack it as a list.
    (flat,) = solution.T.tolist()
    return flat[:-1], flat[-1]
def get_chebyshev_points(num_points):
    """Get Chebyshev points for [0, MAX_X].

    These are the roots ``cos(pi * (2k - 1) / (2n))`` of the degree-``n``
    Chebyshev polynomial, mapped affinely onto ``[0, MAX_X]`` and returned
    in increasing order.

    :type num_points: int
    :param num_points: The number of points to use.
    :rtype: list
    :returns: The Chebyshev points for our interval.
    """
    denominator = CTX.mpf(2.0 * num_points)
    half_width = 0.5 * MAX_X
    # Odd indices from 2n - 1 down to 1: cos(theta) increases as theta shrinks.
    return [half_width * (1 + CTX.cos(CTX.pi * index / denominator))
            for index in range(2 * num_points - 1, 0, -2)]
class ErrorFunc(object):
    """Error function for representing R(s) and P(s).

    Construction solves the Remez system for the nodes ``x_vals``. The
    resulting coefficients define ``P(s)``, and the instance then evaluates
    the signed error ``R(s) - P(s)``.

    :type x_vals: list
    :param x_vals: A 1D array. The ``x``-values where
                   equi-oscillating occurs.
    """

    def __init__(self, x_vals):
        self.x_vals = x_vals
        # Solve once up front: the coefficients of P(s) and the levelled error E.
        self.poly_coeffs, self.E = update_remez_poly(x_vals)

    def poly_approx_scalar(self, value):
        """Evaluate the polynomial :math:`f(x)`.

        Uses the same grouping as ``math/log.go`` to compute

        .. math::

           L_1 x^2 + L_2 x^4 + \\cdots + L_7 x^{14}

        :type value: float
        :param value: The value to compute ``f(x)`` at.
        :rtype: float
        :returns: The value of ``f(x)`` at our input value.
        """
        L1, L2, L3, L4, L5, L6, L7 = (self.poly_coeffs[i] for i in range(7))
        s2 = value * value
        s4 = s2 * s2
        # Odd powers of s^2 go in t1 and even powers in t2, exactly as in
        # log.go; the evaluation order is kept so the rounding matches.
        t1 = s2 * (L1 + s4 * (L3 + s4 * (L5 + s4 * L7)))
        t2 = s4 * (L2 + s4 * (L4 + s4 * L6))
        return t1 + t2

    def signed_error_scalar(self, value):
        """Error function R(s) - P(s).

        :type value: float
        :param value: An ``x``-value.
        :rtype: float
        :returns: The value of the error R(s) - P(s).
        """
        exact = R_scalar(value)
        return exact - self.poly_approx_scalar(value)

    def signed_error_diff(self, value):
        """Derivative of error function R'(s) - P'(s).

        Computed numerically with ``CTX.diff``.

        :type value: float
        :param value: An ``x``-value.
        :rtype: float
        :returns: The value of the error R'(s) - P'(s).
        """
        return CTX.diff(self.signed_error_scalar, value)

    def signed_error(self, values):
        """Error function R(s) - P(s).

        :type values: list
        :param values: A list of ``x``-values.
        :rtype: list
        :returns: The value of the error R(s) - P(s) at each point.
        """
        return [self.signed_error_scalar(val) for val in values]
def locate_abs_max(values):
    """Locate the absolute maximum of a list of values.

    Ties resolve to the first index with the largest magnitude. NaN entries
    never win, because ``abs(nan) > m`` is always false.

    :type values: list
    :param values: A list of scalar values containing at least one non-NaN.
    :rtype: int
    :returns: The index where the maximum occurs.
    :raises ValueError: If ``values`` is empty or contains only NaNs. A
                        sentinel ``-1`` would be silently used by callers as
                        the index of the last element.
    """
    curr_max = float('-inf')
    curr_max_index = -1
    for index, value in enumerate(values):
        abs_val = abs(value)
        if abs_val > curr_max:
            curr_max = abs_val
            curr_max_index = index
    if curr_max_index == -1:
        raise ValueError('No maximum found', values)
    return curr_max_index
def get_peaks(x_data, y_data, num_peaks):
    """Get the peaks from oscillating output data.

    The loop repeats until ``num_peaks`` peaks are found:

    1. Pick the sample with the largest ``|y|`` and record its ``x``-value.
    2. Zero out that sample and the contiguous run of neighbours sharing its
       sign, so the same lobe is not picked twice.

    :type x_data: list
    :param x_data: The ``x``-values where the outputs occur.
    :type y_data: list
    :param y_data: The oscillating output data. It is not modified.
    :type num_peaks: int
    :param num_peaks: The number of peaks to locate.
    :rtype: list
    :returns: The ``x``-locations of all the peaks that were found,
              in the order that they were found.
    :raises ValueError: If the same ``x``-value is selected twice, e.g. when
                        the data has too few sign-alternating lobes.
    """
    # Work on a private list copy. Slicing (``y_data[::]``) would return a
    # *view* for NumPy arrays, so zeroing would corrupt the caller's data,
    # and it returns the same immutable object for tuples.
    local_data = list(y_data)
    size_interval = len(local_data)
    peak_locations = []
    while len(peak_locations) < num_peaks:
        curr_biggest = locate_abs_max(local_data)
        x_now = x_data[curr_biggest]
        if x_now in peak_locations:
            raise ValueError('Repeat value found.')
        peak_locations.append(x_now)
        # Find the sign so we can identify all nearby points on the
        # peak (they will have the same sign).
        sign_x_now = CTX.sign(local_data[curr_biggest])
        local_data[curr_biggest] = 0.0
        # Zero out the rest of this lobe: first to the right, then to the left.
        for step in (1, -1):
            index = curr_biggest + step
            while (0 <= index < size_interval and
                   CTX.sign(local_data[index]) == sign_x_now):
                local_data[index] = 0.0
                index += step
    return peak_locations
def get_new_x_vals(x_vals, sample_points=INTERVAL_SAMPLE):
    """Perform single pass of Remez algorithm.

    The pass:

    1. Computes the coefficients based on ``x_vals``.
    2. Locates the ``NUM_POINTS`` extrema of ``|P(s) - R(s)|`` on
       ``sample_points``.
    3. Refines every extremum except the right-most to a nearby critical
       point.

    :type x_vals: list
    :param x_vals: The ``x``-values where equi-oscillating occurs.
    :type sample_points: list
    :param sample_points: (Optional) The points we choose extrema from.
    :rtype: list
    :returns: The new ``x``-values.
    """
    err_func = ErrorFunc(x_vals)
    approx_outputs = err_func.signed_error(sample_points)
    max_vals = sorted(get_peaks(sample_points, approx_outputs, NUM_POINTS))
    # Move from the fixed grid (given by ``sample_points``) onto the entire
    # interval by finding critical points nearby. We **don't** do this
    # for the biggest max value since the function has no critical point on
    # the outside.
    new_vals = [CTX.findroot(err_func.signed_error_diff, val)
                for val in max_vals[:-1]]
    new_vals.append(max_vals[-1])
    # NOTE: We could / should check that new_vals == sorted(new_vals)
    # and that it has no repeats.
    return new_vals
def plot_x_vals(x_vals, sample_points, filename=None):
    """Plot the error R(s) - P(s) for P(s) given by ``x``-values.

    The error curve is drawn over ``sample_points``, and each
    equi-oscillating node in ``x_vals`` is marked with a black dot.

    :type x_vals: list
    :param x_vals: The ``x``-values where equi-oscillating occurs.
    :type sample_points: list
    :param sample_points: The points we choose extrema from.
    :type filename: str
    :param filename: (Optional) The filename to save the plot in. If
                     not specified, just shows the plot.
    """
    err_func = ErrorFunc(x_vals)
    curve = err_func.signed_error(sample_points)
    node_errors = err_func.signed_error(x_vals)
    plt.plot(sample_points, curve, color=HUSL_BLUE)
    plt.plot(x_vals, node_errors, marker='o',
             color='black', linestyle='None')
    if filename is not None:
        plt.savefig(filename, bbox_inches='tight')
        print('Saved ' + filename)
        return
    plt.show()
def _list_delta_norm(list1, list2):
    """Euclidean (p=2) norm of the element-wise difference of two lists."""
    difference = CTX.matrix(list1) - CTX.matrix(list2)
    return CTX.norm(difference, p=2)
def _lead_substr_match(val1, val2):
min_len = min(len(val1), len(val2))
for index in range(1, min_len + 1):
if val1[:index] != val2[:index]:
return index - 1
return min_len
def _print_double_hex_compare(actual_mpf, expected):
    """Print ``actual_mpf`` as a hex double above a masked copy of ``expected``.

    The second line keeps the leading characters of ``expected``'s mantissa
    that agree with the actual value and replaces the rest with dots.
    """
    # Round to a double via the decimal string, then format as hex.
    actual_as_hex = float(str(actual_mpf)).hex()
    print('- ' + actual_as_hex)
    mantissa, exponent = expected.split('p', 1)
    num_matching = _lead_substr_match(actual_as_hex, mantissa)
    masked = mantissa[:num_matching] + '.' * (len(mantissa) - num_matching)
    print(' ' + masked + 'p' + exponent)
def main(sample_points=INTERVAL_SAMPLE,
         threshold=CTX.mpf(2)**(-26), plot_all=False):
    """Run the Remez algorithm until termination.

    The run proceeds as follows:

    1. Start from Chebyshev points.
    2. Repeatedly exchange the nodes for the extrema of the current error
       until they stop moving, or ``MAX_STEPS`` is hit.
    3. Print the coefficients next to ``EXPECTED_COEFFS``.
    4. Save a plot of the final error.

    :type sample_points: list
    :param sample_points: (Optional) The points we choose extrema from.
    :type threshold: float
    :param threshold: (Optional) The minimum value of the norm of the
                      difference between current ``x``-values and the
                      next set of ``x``-values found via the Remez
                      algorithm. Defaults to ``2^{-26}``.
    :type plot_all: bool
    :param plot_all: (Optional) Flag indicating all updates should be
                     plotted. Defaults to :data:`False`.
    """
    previous = [CTX.inf] * NUM_POINTS
    current = get_chebyshev_points(NUM_POINTS)
    steps_taken = 0
    while _list_delta_norm(current, previous) > threshold:
        if steps_taken >= MAX_STEPS:
            print('Max. steps encountered. Does not converge.')
            break
        if plot_all:
            plot_x_vals(current, sample_points)
        # The right-hand side is evaluated before either name is rebound.
        previous, current = current, get_new_x_vals(
            current, sample_points=sample_points)
        steps_taken += 1

    print('Difference between successive x-vectors: %g' %
          (_list_delta_norm(current, previous),))
    print('Completed in %d steps' % (steps_taken,))

    final_error = ErrorFunc(current)
    print('Coefficients:')
    for index, coeff in enumerate(final_error.poly_coeffs):
        _print_double_hex_compare(coeff, EXPECTED_COEFFS[index])
    plot_x_vals(current, sample_points, filename='remez_approx.png')


if __name__ == '__main__':
    main()
"""Simple sanity check that mpmath.mpf() -> float works.
Checks the relative error in extended precision of the simple
process
mpmath.mpf('a.bcdef...') --> 'a.bcdef...' --> float('a.bcdef...')
and gives an idea of how correct the output is.
This check was inspired by a bug that I wrote which used an mpf in
the numerator and a standard Python float in the denominator,
causing unexpected results.
"""
from __future__ import print_function
import matplotlib.pyplot as plt
import mpmath
import numpy as np
import seaborn
# NOTE: Created via: seaborn.husl_palette(6)[4]
HUSL_BLUE = (0.23299120924703914, 0.639586552066035, 0.9260706093977744)
# 2**52: the ratio between a power of two and its spacing (ULP) for doubles,
# which store 52 explicit fraction bits.
ULP_SIZE = 2.0**52 # NOTE: This assumes a 52-bit mantissa.
# High-precision mpmath context for the values being converted to floats.
CTX = mpmath.MPContext()
CTX.prec = 200 # Bits vs. default of 53
def prev_float(val):
    """Gets the previous float before val in IEEE 754 floating point.

    Delegates to :func:`numpy.nextafter` towards ``-inf``, which is exact for
    every input. A hand-rolled "subtract (half) the spacing" approach fails
    in two places:

    * at the smallest normal number, ``2**-1022 - 2**-1075`` rounds back to
      ``2**-1022``;
    * at ``-0.0``, it steps *up* to the smallest positive subnormal.

    >>> prev_float(float.fromhex('-0x1.fffffffffffffp+0')).hex()
    '-0x1.0000000000000p+1'
    >>> prev_float(float.fromhex('-0x1.0000000000000p+0')).hex()
    '-0x1.0000000000001p+0'
    >>> prev_float(float.fromhex('-0x1.626262626262fp-1')).hex()
    '-0x1.6262626262630p-1'
    >>> prev_float(float.fromhex('-0x0.0000000000001p-1022')).hex()
    '-0x0.0000000000002p-1022'
    >>> prev_float(float.fromhex('0x0.0p+0')).hex()
    '-0x0.0000000000001p-1022'
    >>> prev_float(-0.0).hex()
    '-0x0.0000000000001p-1022'
    >>> prev_float(float.fromhex('0x0.0000000000001p-1022')).hex()
    '0x0.0p+0'
    >>> prev_float(float.fromhex('0x1.0000000000000p-1022')).hex()
    '0x0.fffffffffffffp-1022'
    >>> prev_float(float.fromhex('0x1.0000000000000p+0')).hex()
    '0x1.fffffffffffffp-1'
    >>> prev_float(float.fromhex('0x1.fffffffffffffp+0')).hex()
    '0x1.ffffffffffffep+0'
    >>> prev_float(float.fromhex('0x1.bbbbbbbbbbbbap+102')).hex()
    '0x1.bbbbbbbbbbbb9p+102'
    """
    return np.nextafter(val, -np.inf)
def next_float(val):
    """Gets the next float after val in IEEE 754 floating point.

    Delegates to :func:`numpy.nextafter` towards ``+inf``, which is exact for
    every input. Adding ``np.spacing(val)`` goes wrong for ``-0.0`` (the
    spacing is negative there), and negating :func:`prev_float` inherits its
    edge cases.

    >>> next_float(float.fromhex('-0x1.fffffffffffffp+0')).hex()
    '-0x1.ffffffffffffep+0'
    >>> next_float(float.fromhex('-0x1.0000000000000p+0')).hex()
    '-0x1.fffffffffffffp-1'
    >>> next_float(float.fromhex('-0x1.626262626262fp-1')).hex()
    '-0x1.626262626262ep-1'
    >>> next_float(float.fromhex('-0x0.0000000000001p-1022')).hex()
    '-0x0.0p+0'
    >>> next_float(float.fromhex('0x0.0p+0')).hex()
    '0x0.0000000000001p-1022'
    >>> next_float(-0.0).hex()
    '0x0.0000000000001p-1022'
    >>> next_float(float.fromhex('-0x1.0000000000000p-1022')).hex()
    '-0x0.fffffffffffffp-1022'
    >>> next_float(float.fromhex('0x0.0000000000001p-1022')).hex()
    '0x0.0000000000002p-1022'
    >>> next_float(float.fromhex('0x1.0000000000000p+0')).hex()
    '0x1.0000000000001p+0'
    >>> next_float(float.fromhex('0x1.fffffffffffffp+0')).hex()
    '0x1.0000000000000p+1'
    >>> next_float(float.fromhex('0x1.bbbbbbbbbbbbap+102')).hex()
    '0x1.bbbbbbbbbbbbbp+102'
    """
    return np.nextafter(val, np.inf)
def main():
    """Check that ``float(str(mpf))`` rounds to the nearest double.

    The check runs over ``num_points`` evenly spaced ``CTX`` values in
    ``[0, 1000]``:

    * the relative error of the conversion is recorded, then plotted and
      saved to ``sanity_check.png``;
    * :class:`ValueError` is raised if the resulting double is not the
      closer of the two doubles bracketing the exact value.
    """
    num_points = 2**14
    mpmath_points = CTX.linspace(0, 1000, num_points)
    relative_errs = []
    for mp_value in mpmath_points:
        float_val = float(str(mp_value))
        if mp_value == float_val:
            # Exactly representable (e.g. the endpoints 0 and 1000), so there
            # is nothing to round. Without this skip, the bracketing checks
            # below would fall through to "Rounded outside of 2-ULP range".
            relative_errs.append(0.0)
            continue
        relative_errs.append((float_val - mp_value) / mp_value)
        prev_val = prev_float(float_val)
        next_val = next_float(float_val)
        if prev_val < mp_value < float_val:
            # float_val should be closer than prev_val
            # <==> 0 < float_val - mp_value < mp_value - prev_val
            if (CTX.mpf(float_val) - mp_value >=
                    mp_value - CTX.mpf(prev_val)):
                raise ValueError('Rounded above instead of below',
                                 mp_value, prev_val, float_val, next_val)
        elif float_val < mp_value < next_val:
            # float_val should be closer than next_val
            # <==> 0 < mp_value - float_val < next_val - mp_value
            if (mp_value - CTX.mpf(float_val) >=
                    CTX.mpf(next_val) - mp_value):
                raise ValueError('Rounded below instead of above',
                                 mp_value, prev_val, float_val, next_val)
        else:
            raise ValueError('Rounded outside of 2-ULP range',
                             mp_value, prev_val, float_val, next_val)
    plt.plot(mpmath_points, relative_errs, color=HUSL_BLUE)
    filename = 'sanity_check.png'
    plt.savefig(filename, bbox_inches='tight')
    print('Saved ' + filename)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment