Created November 1, 2017 21:27
Save christiaanjs/36aa366094d832a18c51004ff2155a2e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.ejml.data.DMatrixRMaj; | |
import org.ejml.dense.row.CommonOps_DDRM; | |
import org.ejml.simple.SimpleMatrix; | |
public class BFGSOptimizer implements Optimizer { | |
private double decreaseTol = 0.001; //(0, 0.5) | |
private double curvatureTol = 0.9; //(decreaseTol, 1.0) | |
private double gradTol = 1e-6; | |
private int maxIter = 1000; | |
private static final int MAX_LINE_SEARCH_ITER = 10000; | |
public BFGSOptimizer(){} | |
public BFGSOptimizer(double gradTol){ | |
this.gradTol = gradTol; | |
} | |
public BFGSOptimizer(double gradTol, double curvatureTol, double decreaseTol, int maxIter){ | |
this.decreaseTol = decreaseTol; | |
this.curvatureTol = curvatureTol; | |
this.gradTol = gradTol; | |
this.maxIter = maxIter; | |
if(decreaseTol > curvatureTol){ | |
throw new IllegalArgumentException("Invalid Wolfe parameters - line tol must be smaller than curvature tol"); | |
} | |
} | |
private SimpleObjectiveFunction f; | |
private DMatrixRMaj x; | |
private DMatrixRMaj prevGradOrChange; | |
private DMatrixRMaj grad; | |
private DMatrixRMaj prevHessInv; | |
private DMatrixRMaj hessInv; | |
private DMatrixRMaj hessUpdateComp; | |
private DMatrixRMaj step; | |
private int n; | |
private boolean restricted; | |
private double[] resCoeffs; | |
private double resConst; | |
@Override | |
public void init(SimpleObjectiveFunction f) { | |
this.f = f; | |
this.n = f.getN(); | |
x = new DMatrixRMaj(n, 1); | |
prevGradOrChange = new DMatrixRMaj(n, 1); | |
grad = new DMatrixRMaj(n, 1); | |
hessInv = CommonOps_DDRM.identity(n); | |
prevHessInv = new DMatrixRMaj(n, n); | |
hessUpdateComp = new DMatrixRMaj(n, n); | |
step = new DMatrixRMaj(n, 1); | |
restricted = f.isRestricted(); | |
if(restricted){ | |
resCoeffs = f.getRestrictionCoeffs(); | |
resConst = f.getRestrictionConst(); | |
} | |
} | |
@Override | |
public boolean optimize(double[] x){ | |
setX(x); | |
computeGradient(); | |
if(f instanceof SimpleHessianObjectiveFunction){ | |
DMatrixRMaj hessInit = new DMatrixRMaj(n, n); | |
((SimpleHessianObjectiveFunction) f).computeHessian(hessInit); | |
CommonOps_DDRM.invert(hessInit, hessInv); | |
} else { | |
CommonOps_DDRM.setIdentity(hessInv); | |
} | |
boolean converged = false; | |
for(int i = 0; !converged && i < maxIter; i++){ | |
computeSearchDirection(); | |
boolean lineSearchSuccess = performLineSearch(); | |
if(!lineSearchSuccess){ | |
return false; | |
} | |
DMatrixRMaj temp = prevGradOrChange; //Re-use previous object | |
prevGradOrChange = grad; | |
grad = temp; | |
computeGradient(); | |
if(checkConvergence()){ | |
converged = true; | |
} else { | |
performUpdate(); | |
} | |
} | |
return converged; | |
} | |
private void setX(double[] x){ | |
this.x.setData(x); | |
f.setInput(x); | |
} | |
private void computeGradient(){ | |
f.computeGradient(grad.data); | |
} | |
private void computeSearchDirection(){ | |
CommonOps_DDRM.mult(-1.0, hessInv, grad, step); | |
} | |
private double initVal; | |
private double initDeriv; | |
private double stepSize; | |
private int lineSearchIters; | |
private double stepUpper; | |
private double stepLower; | |
private boolean decreaseCondition; | |
private boolean curvatureCondition; | |
private void resetLineSearch(){ | |
f.setLine(x.data, step.data); | |
initVal = f.computeFunction(); | |
initDeriv = f.computeDerivative(); | |
lineSearchIters = 0; | |
stepLower = 0.0; | |
if(restricted){ //Enforce constraints | |
double stepSizeToBoundary = (resConst-Vector.dot(resCoeffs, x.data))/Vector.dot(resCoeffs, step.data); //Intersection of constraints hyperplane and step line | |
if(stepSizeToBoundary < 0){ //We won't intersect the boundary | |
stepUpper = Double.POSITIVE_INFINITY; | |
stepSize = 1.0; | |
} else { //We could intersect the boundary | |
stepUpper = stepSizeToBoundary; | |
stepSize = Math.min(1.0, (stepLower + stepUpper)/2.0); | |
} | |
} else { | |
stepUpper = Double.POSITIVE_INFINITY; | |
stepSize = 1.0; | |
} | |
f.setInput(stepSize); | |
decreaseCondition = decreaseCondition(); | |
curvatureCondition = curvatureCondition(); | |
} | |
private void lineSearchIter(){ | |
f.setInput(stepSize); | |
decreaseCondition = decreaseCondition(); | |
if(decreaseCondition) curvatureCondition = curvatureCondition(); //Always checked second | |
} | |
/** | |
* Bisection line search with Wolfe conditions | |
*/ | |
private boolean performLineSearch(){ | |
resetLineSearch(); | |
while(!(decreaseCondition && curvatureCondition)){ | |
if(!decreaseCondition){ | |
stepUpper = stepSize; | |
stepSize = (stepUpper + stepLower)/2; | |
} else { //!curvatureCondition | |
stepLower = stepSize; | |
if(Double.isFinite(stepUpper)){ | |
stepSize = (stepUpper + stepLower)/2; | |
} else { | |
stepSize = 2*stepLower; | |
} | |
} | |
lineSearchIter(); | |
if(lineSearchIters++ > MAX_LINE_SEARCH_ITER){ | |
return false; | |
} | |
} | |
return true; | |
} | |
private boolean decreaseCondition(){ | |
return f.computeFunction() <= initVal + stepSize*decreaseTol*initDeriv; | |
} | |
private boolean curvatureCondition(){ | |
return f.computeDerivative() >= curvatureTol*initDeriv; | |
} | |
private boolean checkConvergence(){ | |
return CommonOps_DDRM.dot(grad, grad) < gradTol; | |
} | |
private void performUpdate(){ | |
bfgsUpdate(n, hessInv, prevHessInv, grad, prevGradOrChange, step, stepSize, hessUpdateComp); | |
} | |
public static void bfgsUpdate(int n, DMatrixRMaj hessInv, | |
DMatrixRMaj prevHessInv, //Temp matrix | |
DMatrixRMaj grad, | |
DMatrixRMaj prevGrad, | |
DMatrixRMaj step, | |
double stepSize, | |
DMatrixRMaj tempMatrix){ | |
System.arraycopy(hessInv.data, 0, prevHessInv.data, 0, n*n); | |
CommonOps_DDRM.subtract(grad, prevGrad, prevGrad); //Set to change | |
double scale = stepSize*CommonOps_DDRM.dot(step, prevGrad); | |
double scale2 = scale*scale; | |
double firstTermScale = (scale + MatrixUtils.quadraticForm(prevHessInv, prevGrad.data))/scale2; | |
CommonOps_DDRM.multAddTransB(firstTermScale*stepSize*stepSize, step, step, hessInv); | |
double secondTermsScale = -1.0/scale; | |
CommonOps_DDRM.multTransB(stepSize, prevGrad, step, tempMatrix); //Could do this in place | |
CommonOps_DDRM.multAdd(secondTermsScale, prevHessInv, tempMatrix, hessInv); | |
CommonOps_DDRM.multAddTransA(secondTermsScale, tempMatrix, prevHessInv, hessInv); | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.