Last active
May 26, 2022 19:36
-
-
Save marcusmueller/c2a0c412a983632e326d29c11df01ebe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Copyright 2022 Marcus Müller | |
# SPDX-License-Identifier: LGPL-3.0-or-later | |
import numpy | |
from scipy import optimize | |
from numpy import typing as np_t | |
def poly_approx(coeff_vec: np_t.ArrayLike,
                variable: np_t.ArrayLike) -> np_t.ArrayLike:
    """
    Evaluate an odd-only polynomial in float32 arithmetic.

    Computes ``sum_k coeff_vec[k] * variable**(2*k + 1)``, i.e. the
    coefficients belong to x¹, x³, x⁵, … in that order. Coefficients and
    variable are cast to float32 first and every power is rounded to float32,
    so the result reflects a single-precision evaluation.

    :param coeff_vec: coefficients of x¹, x³, x⁵, … (in that order)
    :param variable: point(s) at which to evaluate the polynomial
    :return: float32 value(s) shaped like ``variable``; zeros if
             ``coeff_vec`` is empty
    """
    coeffs = numpy.asarray(coeff_vec, dtype=numpy.float32)
    variable = numpy.asarray(variable, dtype=numpy.float32)
    # Start from a float32 zero array (rather than sum()'s implicit integer 0)
    # so an empty coefficient vector still yields an array shaped like
    # `variable`.
    result = numpy.zeros_like(variable)
    for idx, coeff in enumerate(coeffs):
        result = result + (variable**(idx * 2 + 1)).astype(numpy.float32) * coeff
    return result
def odd_polynomial_residual_to_sin(coeff_vec: np_t.ArrayLike) -> float:
    """
    Return the sum of squared errors between an odd polynomial and sin(x).

    The error is sampled at 2**12 = 4096 equally spaced points on
    [-pi/2; pi/2] (endpoints included). The polynomial is
    ``sum_k coeff_vec[k] * x**(2*k + 1)``, i.e. the coefficients belong to
    x¹, x³, x⁵, … in that order.

    Note: unlike ``poly_approx``, no float32 rounding is applied here; the
    evaluation is done in float64 on the float64 sample grid. (Rounding to
    float32 would likely swamp the tiny finite-difference steps
    ``scipy.optimize.minimize`` uses to estimate gradients.)

    :param coeff_vec: coefficients of x¹, x³, x⁵, …
    :return: sum over the grid of (P(x) - sin(x))**2
    """
    x = numpy.linspace(-numpy.pi / 2, numpy.pi / 2, 2**12)
    sine = numpy.sin(x)
    # Named `approximation` so it does not shadow the module-level poly_approx().
    approximation = sum(x**(idx * 2 + 1) * coeff
                        for idx, coeff in enumerate(coeff_vec))
    return numpy.sum((approximation - sine)**2)
def odd_polynomial_Linfty_residual_to_sin(coeff_vec: np_t.ArrayLike) -> float:
    """
    Return the maximum squared error between an odd polynomial and sin(x).

    The error is sampled at 2**12 = 4096 equally spaced points on
    [-pi/2; pi/2] (endpoints included). The polynomial is
    ``sum_k coeff_vec[k] * x**(2*k + 1)``, i.e. the coefficients belong to
    x¹, x³, x⁵, … in that order. Minimizing this is equivalent to minimizing
    the L-infinity (worst-case absolute) error on the grid.

    Note: unlike ``poly_approx``, no float32 rounding is applied here; the
    evaluation is done in float64 on the float64 sample grid.

    :param coeff_vec: coefficients of x¹, x³, x⁵, …
    :return: max over the grid of (P(x) - sin(x))**2
    """
    x = numpy.linspace(-numpy.pi / 2, numpy.pi / 2, 2**12)
    sine = numpy.sin(x)
    # Named `approximation` so it does not shadow the module-level poly_approx().
    approximation = sum(x**(idx * 2 + 1) * coeff
                        for idx, coeff in enumerate(coeff_vec))
    # numpy.max reduces in C instead of iterating the array in Python.
    return numpy.max((approximation - sine)**2)
def _format_odd_poly(label: str, coeffs) -> str:
    """Render odd-polynomial coefficients as a LaTeX sum, prefixed by `label`."""
    terms = (f"{coeff:.8g} \\cdot x^{{{2*k+1}}}"
             for k, coeff in enumerate(coeffs))
    return label + " + ".join(terms)


def runme(num_coeffs: int = 2):
    """
    Fit odd polynomials to sin(x) on [-pi/2; pi/2] and print them as LaTeX.

    Two fits are run with ``scipy.optimize.minimize`` (default method), both
    starting from every coefficient equal to 0.5: one minimizing the sum of
    squared errors (L2), one minimizing the maximum squared error (L-infinity).

    :param num_coeffs: number of odd-power coefficients (x¹, x³, …) to fit;
                       the default 2 fits c₁·x + c₃·x³.
    :return: tuple ``(result_L2, result_Linfty)`` of ``minimize`` results; the
             fitted coefficients are in their ``.x`` attributes.
    :raises ValueError: if ``num_coeffs`` is less than 1.
    """
    if num_coeffs < 1:
        raise ValueError(f"num_coeffs must be at least 1, got {num_coeffs}")
    # e.g. num_coeffs=4 starts from 0.5x¹ + 0.5x³ + 0.5x⁵ + 0.5x⁷
    initial_guess = [0.5] * num_coeffs

    result_L2 = optimize.minimize(odd_polynomial_residual_to_sin, initial_guess)
    print(_format_odd_poly(r"c_{\text{L}_2} = ", result_L2.x))

    # NOTE(review): the L-infinity objective is a max() and therefore
    # non-smooth; the gradient-based default method may stop early. A
    # derivative-free method (e.g. method="Nelder-Mead") may be worth
    # comparing — confirm before relying on these coefficients.
    result_Linfty = optimize.minimize(odd_polynomial_Linfty_residual_to_sin,
                                      initial_guess)
    print(_format_odd_poly(r"c_{\text{L}_\infty} = ", result_Linfty.x))
    return result_L2, result_Linfty
def _build_arg_parser():
    """Return the command-line parser for this script."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Fit odd polynomials to sin(x) and optionally plot their errors.")
    parser.add_argument("-p", "--plot-error", action="store_true",
                        help="plot sin(x) - P(x) on [0, pi/2] for the fitted "
                             "polynomials")
    parser.add_argument("-d", "--plot-error_relation", action="store_true",
                        help="additionally plot the ratio of Raj's error to "
                             "the L_infty error in dB (implies -p)")
    parser.add_argument("-o", "--output",
                        help="write the plot to this file instead of showing "
                             "it (only used when plotting)")
    return parser


def _plot_errors(result_L2, result_Linfty, plot_relation, output):
    """
    Plot sin(x) - P(x) on [0; pi/2] for both fits and a fixed reference cubic.

    :param result_L2: result whose ``.x`` holds the L2-fit coefficients
    :param result_Linfty: result whose ``.x`` holds the L-infinity-fit coefficients
    :param plot_relation: if true, add a twin axis with 20*log10(|Raj / L_infty|)
    :param output: file name to save the figure to; shown interactively if falsy
    """
    from matplotlib import pyplot
    try:
        # Optional and never used by name; presumably imported only for its
        # effect on plot styling — TODO confirm it is still needed.
        import seaborn  # noqa: F401
    except ImportError:
        pass

    x = numpy.linspace(0, numpy.pi / 2, 2000)
    sine = numpy.sin(x)
    L2_error = sine - poly_approx(result_L2.x, x)
    Linfty_error = sine - poly_approx(result_Linfty.x, x)
    # Fixed reference cubic P(x) = 0.98553 x - 0.14257 x^3 ("Raj's").
    Raj_error = sine - 0.98553 * x + 0.14257 * x**3

    fig, ax1 = pyplot.subplots(figsize=(10, 10), dpi=100)
    # Mathtext needs $...$ delimiters; without them the LaTeX renders literally.
    ax1.set_xlabel(r"$x/\pi$")
    ax1.set_xlim((x[0] / numpy.pi, x[-1] / numpy.pi))
    ax1.set_ylabel(r"$\sin(x)-P(x)$")
    ax1.plot(x / numpy.pi, L2_error, label=r"$L_2$")
    ax1.plot(x / numpy.pi, Linfty_error, label=r"$L_\infty$")
    ax1.plot(x / numpy.pi, Raj_error, "--", label=r"Raj's")
    ax1.legend(loc="upper left")
    ax1.grid()

    if plot_relation:
        ax2 = ax1.twinx()
        # All errors are 0 at x = 0 (0/0 -> nan), and the ratio may be 0 or
        # inf elsewhere; those points just become gaps in the plot, so only
        # silence numpy's RuntimeWarnings.
        with numpy.errstate(divide="ignore", invalid="ignore"):
            ratio_db = 20 * numpy.log10(numpy.abs(Raj_error / Linfty_error))
        ax2.plot(x / numpy.pi, ratio_db, label=r"Raj's vs $L_\infty$")
        ax2.set_ylabel("first error / second error [dB]")
        ax2.legend(loc="lower right")
        ax2.grid()

    fig.tight_layout()
    if output:
        fig.savefig(output)
    else:
        pyplot.show()


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    result_L2, result_Linfty = runme()
    # The ratio plot is drawn on the error plot's axes, so -d implies -p.
    if args.plot_error or args.plot_error_relation:
        _plot_errors(result_L2, result_Linfty,
                     plot_relation=args.plot_error_relation,
                     output=args.output)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment