Skip to content

Instantly share code, notes, and snippets.

@wardev
Created February 25, 2014 13:31
Show Gist options
  • Save wardev/3f183dfaaf297dc2462c to your computer and use it in GitHub Desktop.
Save wardev/3f183dfaaf297dc2462c to your computer and use it in GitHub Desktop.
GaussNewtonOptimizerPerformanceTest
package org.apache.commons.math3.fitting.leastsquares;
import org.apache.commons.math3.fitting.leastsquares.GaussNewtonOptimizer.Decomposition;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.optim.SimpleVectorValueChecker;
import org.junit.Ignore;
import org.junit.Test;
import java.io.IOException;
import java.util.Arrays;
/**
 * Manual benchmark comparing {@link GaussNewtonOptimizer} configured with the LU, QR and
 * Cholesky decompositions on the same circle-fitting least-squares problem.
 *
 * <p>The benchmark is {@link Ignore ignored} by default because it runs in an endless loop
 * (presumably to allow JIT warm-up and profiler attachment); run it by hand and stop it
 * manually.
 */
public class GaussNewtonOptimizerPerformanceTest {

    /** Output line: decomposition label, elapsed seconds, accumulated evaluation count. */
    private static final String FORMAT = "%5s: %10g s %20g%n";

    /**
     * Single-threaded comparison of the three Gauss-Newton decompositions (LU, QR, Cholesky).
     *
     * <p>Builds one problem with {@code 1000000} points and then repeats forever: each round
     * times {@code n} optimizations per decomposition and prints one line for each.
     *
     * @throws IOException kept from the original harness signature
     */
    @Test
    @Ignore
    public void testPerformanceST() throws IOException {
        // domain setup: every point gets the same weight (2)
        final int numPoints = 1000000;
        CircleVectorial circle = new CircleVectorial();
        double[] weights = new double[numPoints];
        Arrays.fill(weights, 2);
        // NOTE(review): the last constructor argument looks like a fixed RNG seed,
        // which would make runs reproducible -- confirm against RandomCirclePointGenerator
        RandomCirclePointGenerator generator =
                new RandomCirclePointGenerator(0.1, 0.1, 1.0, 0.1, 0.1, 1337);
        Vector2D[] points = generator.generate(numPoints);
        for (int i = 0; i < numPoints; ++i) {
            circle.addPoint(points[i].getX(), points[i].getY());
        }
        final double[] start = {0, 0};

        // problem: all residual targets are zero
        LeastSquaresProblem lsp = new LeastSquaresBuilder()
                .checkerPair(new SimpleVectorValueChecker(1e-6, 1e-6))
                .maxEvaluations(100)
                .maxIterations(100)
                .model(circle.getModeFunctionAndJacobian())
                .target(new double[circle.getN()])
                .weight(new DiagonalMatrix(weights))
                .start(start)
                .build();

        // optimizers under comparison
        GaussNewtonOptimizer lu = new GaussNewtonOptimizer(Decomposition.LU);
        GaussNewtonOptimizer qr = new GaussNewtonOptimizer(Decomposition.QR);
        GaussNewtonOptimizer chol = new GaussNewtonOptimizer(Decomposition.CHOLESKY);

        // number of optimizations timed per decomposition in each round
        final int n = 1;

        // intentionally endless: stop the benchmark manually
        while (true) {
            timeOptimizer("lu", lu, lsp, n);
            timeOptimizer("qr", qr, lsp, n);
            timeOptimizer("chol", chol, lsp, n);
        }
    }

    /**
     * Runs {@code optimizer} on {@code problem} {@code n} times in a row and prints the
     * elapsed time together with the summed evaluation counts.
     *
     * @param label     name printed in the first output column
     * @param optimizer optimizer under test
     * @param problem   problem to solve on every run
     * @param n         number of consecutive runs included in the measurement
     */
    private static void timeOptimizer(String label, GaussNewtonOptimizer optimizer,
                                      LeastSquaresProblem problem, int n) {
        // accumulate and print a result so the optimize() calls have an observable use
        double ac = 0.0;
        // nanoTime is monotonic and intended for interval timing (currentTimeMillis is not)
        final long startNanos = System.nanoTime();
        for (int i = 0; i < n; i++) {
            ac += optimizer.optimize(problem).getEvaluations();
        }
        final long elapsedNanos = System.nanoTime() - startNanos;
        System.out.format(FORMAT, label, elapsedNanos / 1e9, ac);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment