Skip to content

Instantly share code, notes, and snippets.

@Erkaman
Last active December 24, 2018 10:50
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save Erkaman/0ecc19026b9bca59e425074c1497f5da to your computer and use it in GitHub Desktop.
Save Erkaman/0ecc19026b9bca59e425074c1497f5da to your computer and use it in GitHub Desktop.
// Accompanying source code for my article: https://erkaman.github.io/posts/gauss_newton.html
/*
The MIT License (MIT)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
// NOTE: You need to add Eigen to your include paths to compile this.
#include <float.h>
#include <stdio.h>
#include <cmath>
#include <Eigen/Dense>
// One tabulated point: an input position and the value stored for it.
struct Sample {
    double x;  // input position
    double y;  // value recorded at x
};
// Number of free parameters in the model.
const int PARAMS = 2;
// Current guess for the parameters. Zero-initialized as a namespace-scope
// array and modified in place while fitting.
double P[PARAMS];

// Closed interval of x values over which samples are taken.
// Written with double literals: a float literal such as 0.001f is first
// rounded to float precision and only then widened, so it would not hold the
// intended decimal value.
const double DOMAIN_X_MIN = 0.0;
const double DOMAIN_X_MAX = 1.0;

// Number of sample points. The original author notes this is assumed to be
// an odd number, to simplify the code.
const int N_SAMPLES = 10001;

// Keep improving the guess until the root-mean-square error drops below this
// threshold ...
const double RMSE_THRESHOLD = 0.001;
// ... or until this many iterations have been performed.
const int MAX_ITERATIONS = 1000;

// ln(2), computed once because it appears in every derivative of 2^u.
const double clog2 = std::log(2.0);
// Storage for the tabulated target samples, one entry per sample index.
Sample samples[N_SAMPLES];
// Per-sample residuals (stored y minus model value) as an N_SAMPLES x 1 matrix.
Eigen::MatrixXd residual(N_SAMPLES, 1);
// Jacobian of the model: one row per sample, one column per parameter.
Eigen::MatrixXd J(N_SAMPLES, PARAMS);
// The model being fitted: 2 raised to the quadratic P[0]*x^2 + P[1]*x.
// The fit searches for parameters P[] that make this a good approximation
// of the target function.
double eval_model(double x) {
    const double exponent = P[0] * x * x + P[1] * x;
    return pow(2.0, exponent);
}
// Target function to approximate: (1 - x)^5, computed with three
// multiplications by squaring (1 - x) twice.
double eval_target(double x) {
    const double base = 1.0 - x;
    const double squared = base * base;
    const double fourth = squared * squared;
    return fourth * base;
}
// Partial derivatives of the model with respect to each parameter, at x.
//
// With u = P[0]*x^2 + P[1]*x and model(x) = 2^u:
//   d(model)/dP[0] = ln(2) * x^2 * 2^u
//   d(model)/dP[1] = ln(2) * x   * 2^u
//
// Returns a vector with PARAMS entries, in parameter order.
Eigen::VectorXd calc_jacobian(double x) {
    // Both derivatives share the factor 2^u, so evaluate the pow only once.
    const double model_value = std::pow(2.0, P[0] * x * x + P[1] * x);

    Eigen::VectorXd j(PARAMS);
    j(0) = clog2 * x * x * model_value;
    j(1) = clog2 * x * model_value;
    return j;
}
int main() {
// our initial guess shall be (0,0)
P[0] = 0.0; P[1] = 0.0f;
const double DOMAIN_X_SPACING = (DOMAIN_X_MAX - DOMAIN_X_MIN) / (N_SAMPLES-1);
// evaluate our target at all sample points in the domain.
{
int index = 0;
for (double x = DOMAIN_X_MIN;
x <= DOMAIN_X_MAX;
x += DOMAIN_X_SPACING) {
Sample s;
s.x = x;
s.y = eval_target(s.x);
samples[index] = s;
index++;
}
}
for (int iteration = 0; iteration < MAX_ITERATIONS; ++iteration) {
// compute the residual and least squares error, and break if below threshold.
{
double rmse = 0.0;
for (int isample = 0; isample < N_SAMPLES; ++isample) {
Sample s = samples[isample];
residual(isample, 0) = (s.y - eval_model(s.x));
rmse += residual(isample, 0) * residual(isample, 0);
Eigen::VectorXd temp = calc_jacobian(s.x);
for (int iparam = 0; iparam < PARAMS; ++iparam) {
J(isample, iparam) = temp(iparam);
}
}
rmse = sqrt(rmse / float(N_SAMPLES));
printf("iteration: %d\n", iteration);
printf("RMSE: %f\n", rmse);
printf("param: %f %f\n", P[0], P[1]);
printf("\n");
if (rmse < RMSE_THRESHOLD) {
break;
}
}
/*
Here is the main part of the Gauss-Newton algorithm.
*/
Eigen::MatrixXd JtJ = J.transpose() * J;
Eigen::VectorXd h = (JtJ).colPivHouseholderQr().solve(J.transpose() * residual);
for (int iparam = 0; iparam < PARAMS; ++iparam) {
P[iparam] += h(iparam);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment