Skip to content

Instantly share code, notes, and snippets.

@marcusmueller
Last active May 26, 2022 19:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcusmueller/c2a0c412a983632e326d29c11df01ebe to your computer and use it in GitHub Desktop.
Save marcusmueller/c2a0c412a983632e326d29c11df01ebe to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# Copyright 2022 Marcus Müller
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy
from scipy import optimize
from numpy import typing as np_t
def poly_approx(coeff_vec: np_t.ArrayLike,
                variable: np_t.ArrayLike) -> np_t.ArrayLike:
    """
    Evaluate the odd-only polynomial sum_k coeff_vec[k] * variable**(2k+1) in float32.

    Both the coefficients and the evaluation points are cast to float32, and every
    power of `variable` is rounded to float32 before it is scaled. The result
    therefore reflects what a single-precision implementation would compute.

    :param coeff_vec: coefficients of x¹, x³, x⁵, … (in that order).
    :param variable: point(s) at which to evaluate the polynomial.
    :returns: float32 values with the shape of `variable`; all zeros if
        `coeff_vec` is empty.
    """
    coeffs = numpy.asarray(coeff_vec, dtype=numpy.float32)
    variable = numpy.asarray(variable, dtype=numpy.float32)
    # Start from a float32 zero array rather than sum()'s implicit int 0, so that
    # the result keeps variable's shape and dtype even with no coefficients.
    result = numpy.zeros_like(variable)
    for idx, coeff in enumerate(coeffs):
        result = result + (variable**(idx * 2 + 1)).astype(numpy.float32) * coeff
    return result
def odd_polynomial_residual_to_sin(coeff_vec: np_t.ArrayLike) -> float:
    """
    Return the L2 residual (sum of squared differences) between an odd-only
    polynomial and the sine.

    The polynomial sum_k coeff_vec[k] * x**(2k+1) is compared with numpy.sin at
    2**12 = 4096 equally spaced points on [-pi/2; pi/2], endpoints included.
    The evaluation uses numpy's default float64 precision.

    NOTE(review): an earlier docstring claimed the math was rounded to float32,
    but it never was. Keeping float64 is presumably deliberate: the default
    scipy.optimize.minimize estimates gradients with very small finite-difference
    steps, which a float32 objective would not resolve. Confirm before changing.

    :param coeff_vec: coefficients of x¹, x³, x⁵, … (in that order).
    :returns: sum over the sample points of (polynomial - sine)².
    """
    x = numpy.linspace(-numpy.pi / 2, numpy.pi / 2, 2**12)
    sine = numpy.sin(x)
    # Named `approx` rather than `poly_approx` so it does not shadow the
    # module-level float32 evaluation helper of that name.
    approx = sum(x**(idx * 2 + 1) * coeff
                 for idx, coeff in enumerate(coeff_vec))
    return numpy.sum((approx - sine)**2)
def odd_polynomial_Linfty_residual_to_sin(coeff_vec: np_t.ArrayLike) -> float:
    """
    Return the largest squared difference between an odd-only polynomial and
    the sine. This is the square of the sampled L-infinity error, so minimizing
    either one yields the same coefficients.

    The polynomial sum_k coeff_vec[k] * x**(2k+1) is compared with numpy.sin at
    2**12 = 4096 equally spaced points on [-pi/2; pi/2], endpoints included.
    The evaluation uses numpy's default float64 precision.

    NOTE(review): an earlier docstring claimed the math was rounded to float32,
    but it never was. Keeping float64 is presumably deliberate, for the
    finite-difference gradients of scipy.optimize.minimize. Confirm before changing.

    :param coeff_vec: coefficients of x¹, x³, x⁵, … (in that order).
    :returns: max over the sample points of (polynomial - sine)².
    """
    x = numpy.linspace(-numpy.pi / 2, numpy.pi / 2, 2**12)
    sine = numpy.sin(x)
    # Named `approx` rather than `poly_approx` so it does not shadow the
    # module-level float32 evaluation helper of that name.
    approx = sum(x**(idx * 2 + 1) * coeff
                 for idx, coeff in enumerate(coeff_vec))
    # numpy.max reduces in C and propagates NaN. The builtin max() iterated the
    # array element by element in Python, and its NaN result depended on order.
    return numpy.max((approx - sine)**2)
def runme(num_coeffs: int = 2):
    """
    Fit odd polynomials to the sine under an L2 and an L-infinity criterion.

    Both fits start from the initial guess 0.5 for every coefficient, i.e. from
    0.5x¹ + 0.5x³ + … with `num_coeffs` odd-power terms. Each guess is refined
    with scipy.optimize.minimize using its default method. Each resulting
    polynomial is printed as a LaTeX expression.

    :param num_coeffs: number of odd-power coefficients to fit. The default of 2
        fits c₁·x + c₃·x³.
    :returns: tuple (result_L2, result_Linfty) of the two optimizer results.
    :raises ValueError: if num_coeffs is smaller than 1.
    """
    if num_coeffs < 1:
        raise ValueError("num_coeffs must be at least 1")
    initial_guess = [0.5] * num_coeffs

    def latex_polynomial(coeffs) -> str:
        # Render coefficient k as "c \cdot x^{2k+1}" and join the terms with " + ".
        return " + ".join(f"{coeff:.8g} \\cdot x^{{{2*k+1}}}"
                          for k, coeff in enumerate(coeffs))

    result_L2 = optimize.minimize(odd_polynomial_residual_to_sin, initial_guess)
    print(r"c_{\text{L}_2} = " + latex_polynomial(result_L2.x))
    result_Linfty = optimize.minimize(odd_polynomial_Linfty_residual_to_sin,
                                      initial_guess)
    print(r"c_{\text{L}_\infty} = " + latex_polynomial(result_Linfty.x))
    return result_L2, result_Linfty
def _plot_errors(result_L2, result_Linfty, *, plot_relation=False, output=None):
    """
    Plot sin(x) - P(x) on [0; pi/2] for three polynomials: the L2 fit, the
    L-infinity fit and Raj's polynomial 0.98553x - 0.14257x³.

    :param result_L2: optimizer result of the L2 fit (only `.x` is used).
    :param result_Linfty: optimizer result of the L-infinity fit (only `.x` is used).
    :param plot_relation: if true, also draw the ratio of Raj's error to the
        L-infinity error in dB on a twin y axis.
    :param output: if given, save the figure to this path instead of showing it.
    """
    from matplotlib import pyplot
    try:
        # Optional and otherwise unused; presumably imported for its effect on
        # plot styling. A missing seaborn is fine.
        import seaborn  # noqa: F401
    except ImportError:
        pass

    x = numpy.linspace(0, numpy.pi / 2, 2000)
    sine = numpy.sin(x)
    L2_error = sine - poly_approx(result_L2.x, x)
    Linfty_error = sine - poly_approx(result_Linfty.x, x)
    Raj_error = sine - 0.98553 * x + 0.14257 * x**3

    fig, ax1 = pyplot.subplots(figsize=(10, 10), dpi=100)
    # Matplotlib only renders TeX-like labels as math when wrapped in $...$.
    ax1.set_xlabel(r"$x/\pi$")
    ax1.set_xlim((x[0] / numpy.pi, x[-1] / numpy.pi))
    ax1.set_ylabel(r"$\sin(x)-P(x)$")
    ax1.plot(x / numpy.pi, L2_error, label=r"$L_2$")
    ax1.plot(x / numpy.pi, Linfty_error, label=r"$L_\infty$")
    ax1.plot(x / numpy.pi, Raj_error, "--", label=r"Raj's")
    ax1.legend(loc="upper left")
    ax1.grid()

    if plot_relation:
        ax2 = ax1.twinx()
        # Both errors are exactly 0 at x = 0, and either error may cross zero
        # elsewhere. The ratio can therefore contain 0/0 and x/0 (NaN/inf
        # entries), so silence numpy's warnings for this expression only.
        # To compare the two fits with each other, use L2_error in place of Raj_error.
        with numpy.errstate(divide="ignore", invalid="ignore"):
            relation_dB = 20 * numpy.log10(abs(Raj_error / Linfty_error))
        ax2.plot(x / numpy.pi, relation_dB, label=r"Raj's vs $L_\infty$")
        ax2.set_ylabel("first error / second error [dB]")
        ax2.legend(loc="lower right")
        ax2.grid()

    fig.tight_layout()
    if output:
        fig.savefig(output)
    else:
        pyplot.show()


def main():
    """Command-line entry point: run both fits, then optionally plot their errors."""
    import argparse
    parser = argparse.ArgumentParser(
        description="Fit odd polynomials to sin(x) (L2 and L-infinity criteria).")
    parser.add_argument("-p", "--plot-error", action="store_true",
                        help="plot sin(x) - P(x) for the fits and Raj's polynomial")
    parser.add_argument("-d", "--plot-error_relation", action="store_true",
                        help="with -p: also plot Raj's error relative to the "
                             "L-infinity error in dB")
    parser.add_argument("-o", "--output",
                        help="with -p: save the plot to this file instead of showing it")
    args = parser.parse_args()
    result_L2, result_Linfty = runme()
    if args.plot_error:
        _plot_errors(result_L2, result_Linfty,
                     plot_relation=args.plot_error_relation,
                     output=args.output)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment