/*
 * GitHub Gist by @christiaanjs (gist id 36aa366094d832a18c51004ff2155a2e).
 * Created November 1, 2017 21:27.
 */
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.simple.SimpleMatrix;
/**
 * BFGS quasi-Newton optimizer that maintains an approximation of the inverse
 * Hessian and chooses step lengths with a bisection line search using the
 * Wolfe (sufficient-decrease and curvature) conditions.
 *
 * <p>Usage: call {@link #init(SimpleObjectiveFunction)} once, then
 * {@link #optimize(double[])}. All working state is held in mutable instance
 * fields that are reused between calls.
 *
 * <p>NOTE(review): this class never writes to its position vector after
 * {@code setX}; it only passes it to {@code f.setLine(...)} at the start of
 * each line search. Presumably the objective function advances the point
 * (or the backing array) itself - confirm against the
 * {@code SimpleObjectiveFunction} implementation.
 */
public class BFGSOptimizer implements Optimizer {

    /** Hard cap on bisection/doubling steps within one line search. */
    private static final int MAX_LINE_SEARCH_ITER = 10000;

    /** Sufficient-decrease constant; intended range (0, 0.5). */
    private double decreaseTol = 0.001;
    /** Curvature constant; intended range (decreaseTol, 1.0). */
    private double curvatureTol = 0.9;
    /** Threshold that the squared Euclidean norm of the gradient is compared against. */
    private double gradTol = 1e-6;
    /** Maximum number of outer (quasi-Newton) iterations. */
    private int maxIter = 1000;

    // ---- Problem definition and work buffers (allocated in init) ----
    private SimpleObjectiveFunction f;
    private DMatrixRMaj x;
    /** Holds the previous gradient; overwritten with the gradient change during the BFGS update. */
    private DMatrixRMaj prevGradOrChange;
    private DMatrixRMaj grad;
    private DMatrixRMaj prevHessInv;
    private DMatrixRMaj hessInv;
    private DMatrixRMaj hessUpdateComp;
    /** Search direction (not yet scaled by the step size). */
    private DMatrixRMaj step;
    private int n;

    // ---- Optional linear restriction resCoeffs . x = resConst used to cap the step ----
    private boolean restricted;
    private double[] resCoeffs;
    private double resConst;

    // ---- Line-search state ----
    private double initVal;
    private double initDeriv;
    private double stepSize;
    private int lineSearchIters;
    private double stepUpper;
    private double stepLower;
    private boolean decreaseCondition;
    private boolean curvatureCondition;

    /** Creates an optimizer with the default tolerances and iteration limit. */
    public BFGSOptimizer() {
    }

    /**
     * Creates an optimizer with a custom gradient tolerance.
     *
     * @param gradTol threshold for the squared gradient norm below which the
     *                optimizer reports convergence
     */
    public BFGSOptimizer(double gradTol) {
        this.gradTol = gradTol;
    }

    /**
     * Creates a fully configured optimizer.
     *
     * @param gradTol      threshold for the squared gradient norm
     * @param curvatureTol curvature-condition constant
     * @param decreaseTol  sufficient-decrease constant
     * @param maxIter      maximum number of outer iterations
     * @throws IllegalArgumentException if {@code decreaseTol > curvatureTol}
     */
    public BFGSOptimizer(double gradTol, double curvatureTol, double decreaseTol, int maxIter) {
        if (decreaseTol > curvatureTol) {
            throw new IllegalArgumentException("Invalid Wolfe parameters - line tol must be smaller than curvature tol");
        }
        this.decreaseTol = decreaseTol;
        this.curvatureTol = curvatureTol;
        this.gradTol = gradTol;
        this.maxIter = maxIter;
    }

    /**
     * Binds the objective function and allocates all work buffers for its
     * dimension. Also caches the restriction coefficients and constant when
     * the function reports that it is restricted.
     *
     * @param f objective function to minimise
     */
    @Override
    public void init(SimpleObjectiveFunction f) {
        this.f = f;
        this.n = f.getN();
        x = new DMatrixRMaj(n, 1);
        prevGradOrChange = new DMatrixRMaj(n, 1);
        grad = new DMatrixRMaj(n, 1);
        hessInv = CommonOps_DDRM.identity(n);
        prevHessInv = new DMatrixRMaj(n, n);
        hessUpdateComp = new DMatrixRMaj(n, n);
        step = new DMatrixRMaj(n, 1);
        restricted = f.isRestricted();
        if (restricted) {
            resCoeffs = f.getRestrictionCoeffs();
            resConst = f.getRestrictionConst();
        }
    }

    /**
     * Runs the BFGS iteration starting from {@code x}.
     *
     * @param x starting point; it is handed to both the internal position
     *          matrix and the objective function
     * @return {@code true} if the squared gradient norm dropped below the
     *         gradient tolerance within {@code maxIter} iterations;
     *         {@code false} if a line search failed or the iteration limit
     *         was reached
     * @throws IllegalStateException if {@link #init(SimpleObjectiveFunction)}
     *                               has not been called
     */
    @Override
    public boolean optimize(double[] x) {
        if (f == null) {
            throw new IllegalStateException("init(SimpleObjectiveFunction) must be called before optimize(double[])");
        }
        setX(x);
        computeGradient();
        initialiseInverseHessian();

        for (int i = 0; i < maxIter; i++) {
            computeSearchDirection();
            if (!performLineSearch()) {
                return false;
            }

            // Swap buffers so the old gradient is kept without allocating.
            DMatrixRMaj recycled = prevGradOrChange;
            prevGradOrChange = grad;
            grad = recycled;
            computeGradient();

            if (checkConvergence()) {
                return true;
            }
            performUpdate();
        }
        return false;
    }

    /**
     * Seeds the inverse-Hessian approximation: the inverse of the
     * user-supplied Hessian when the objective provides one and the inversion
     * succeeds with finite entries, otherwise the identity matrix.
     */
    private void initialiseInverseHessian() {
        if (f instanceof SimpleHessianObjectiveFunction) {
            DMatrixRMaj hessInit = new DMatrixRMaj(n, n);
            ((SimpleHessianObjectiveFunction) f).computeHessian(hessInit);
            // invert() reports failure through its return value; a "successful"
            // inversion can still contain Inf/NaN, so check the entries too.
            boolean inverted = CommonOps_DDRM.invert(hessInit, hessInv);
            if (inverted && allFinite(hessInv.data, n * n)) {
                return;
            }
        }
        CommonOps_DDRM.setIdentity(hessInv);
    }

    /** Returns true when the first {@code count} entries of {@code values} are all finite. */
    private static boolean allFinite(double[] values, int count) {
        for (int i = 0; i < count; i++) {
            if (!Double.isFinite(values[i])) {
                return false;
            }
        }
        return true;
    }

    /** Hands the starting point to the position matrix and to the objective function. */
    private void setX(double[] x) {
        this.x.setData(x);
        f.setInput(x);
    }

    /** Asks the objective function to write its current gradient into {@code grad}. */
    private void computeGradient() {
        f.computeGradient(grad.data);
    }

    /** Sets the search direction to {@code -hessInv * grad}. */
    private void computeSearchDirection() {
        CommonOps_DDRM.mult(-1.0, hessInv, grad, step);
    }

    /**
     * Prepares a new line search along {@code step}: records the value and
     * directional derivative reported by the objective right after
     * {@code setLine}, chooses the initial bracket and trial step, and
     * evaluates both Wolfe conditions at that trial step.
     */
    private void resetLineSearch() {
        f.setLine(x.data, step.data);
        initVal = f.computeFunction();
        initDeriv = f.computeDerivative();
        lineSearchIters = 0;
        stepLower = 0.0;
        stepUpper = Double.POSITIVE_INFINITY;
        stepSize = 1.0;

        if (restricted) { // Enforce constraints
            // Step length at which the search line meets the restriction hyperplane.
            double stepSizeToBoundary =
                    (resConst - Vector.dot(resCoeffs, x.data)) / Vector.dot(resCoeffs, step.data);
            // A negative value means the boundary lies behind us. NaN (0/0) arises when
            // the point is on the hyperplane and the step is parallel to it, so the
            // line never leaves it; treat that as "no intersection" as well.
            boolean mayHitBoundary = !Double.isNaN(stepSizeToBoundary) && !(stepSizeToBoundary < 0);
            if (mayHitBoundary) {
                stepUpper = stepSizeToBoundary;
                stepSize = Math.min(1.0, (stepLower + stepUpper) / 2.0);
            }
        }

        f.setInput(stepSize);
        decreaseCondition = decreaseCondition();
        curvatureCondition = curvatureCondition();
    }

    /**
     * Evaluates the Wolfe conditions at the current trial step. The curvature
     * condition is only re-evaluated when sufficient decrease holds, because
     * the caller acts on a failed decrease condition first.
     */
    private void lineSearchIter() {
        f.setInput(stepSize);
        decreaseCondition = decreaseCondition();
        if (decreaseCondition) {
            curvatureCondition = curvatureCondition(); // Always checked second
        }
    }

    /**
     * Bisection line search with Wolfe conditions.
     *
     * <p>A failed decrease condition shrinks the bracket from above; a failed
     * curvature condition raises the lower bound and either bisects (finite
     * upper bound) or doubles the step (no upper bound yet).
     *
     * @return {@code true} once both conditions hold, {@code false} if the
     *         iteration cap is exceeded first
     */
    private boolean performLineSearch() {
        resetLineSearch();
        while (!(decreaseCondition && curvatureCondition)) {
            if (!decreaseCondition) {
                stepUpper = stepSize;
                stepSize = (stepUpper + stepLower) / 2;
            } else { // !curvatureCondition
                stepLower = stepSize;
                if (Double.isFinite(stepUpper)) {
                    stepSize = (stepUpper + stepLower) / 2;
                } else {
                    stepSize = 2 * stepLower;
                }
            }
            lineSearchIter();
            if (lineSearchIters++ > MAX_LINE_SEARCH_ITER) {
                return false;
            }
        }
        return true;
    }

    /** Sufficient decrease: value at the trial step is at most initVal + stepSize * decreaseTol * initDeriv. */
    private boolean decreaseCondition() {
        return f.computeFunction() <= initVal + stepSize * decreaseTol * initDeriv;
    }

    /** Curvature: derivative at the trial step is at least curvatureTol * initDeriv. */
    private boolean curvatureCondition() {
        return f.computeDerivative() >= curvatureTol * initDeriv;
    }

    /** Converged when the squared Euclidean norm of the gradient is below {@code gradTol}. */
    private boolean checkConvergence() {
        return CommonOps_DDRM.dot(grad, grad) < gradTol;
    }

    /** Applies the BFGS update to {@code hessInv} using the instance's work buffers. */
    private void performUpdate() {
        bfgsUpdate(n, hessInv, prevHessInv, grad, prevGradOrChange, step, stepSize, hessUpdateComp);
    }

    /**
     * In-place BFGS update of an inverse-Hessian approximation.
     *
     * <p>With {@code s = stepSize * step}, {@code y = grad - prevGrad} and
     * {@code H} the incoming {@code hessInv}, the result is
     * {@code H + (s.y + q) * s s^T / (s.y)^2 - (H y s^T + s y^T H) / (s.y)},
     * where {@code q} is the value returned by
     * {@code MatrixUtils.quadraticForm(H, y)} (presumably {@code y^T H y} -
     * confirm against that helper).
     *
     * <p>If {@code s.y} is not a finite, strictly positive number (for
     * example after a zero-length step) the update would divide by zero or
     * break positive definiteness, so {@code hessInv} is left unchanged. The
     * side effects on {@code prevHessInv} and {@code prevGrad} still occur.
     *
     * @param n           problem dimension
     * @param hessInv     inverse-Hessian approximation, updated in place
     * @param prevHessInv scratch matrix; receives a copy of the incoming {@code hessInv}
     * @param grad        gradient at the new point
     * @param prevGrad    gradient at the previous point; overwritten with {@code grad - prevGrad}
     * @param step        search direction that was taken
     * @param stepSize    step length along {@code step}
     * @param tempMatrix  n-by-n scratch matrix
     */
    public static void bfgsUpdate(int n, DMatrixRMaj hessInv,
                                  DMatrixRMaj prevHessInv, //Temp matrix
                                  DMatrixRMaj grad,
                                  DMatrixRMaj prevGrad,
                                  DMatrixRMaj step,
                                  double stepSize,
                                  DMatrixRMaj tempMatrix) {
        System.arraycopy(hessInv.data, 0, prevHessInv.data, 0, n * n);
        CommonOps_DDRM.subtract(grad, prevGrad, prevGrad); //Set to change

        double scale = stepSize * CommonOps_DDRM.dot(step, prevGrad); // s.y
        if (!(scale > 0.0) || !Double.isFinite(scale)) {
            // Curvature s.y must be positive for a well-defined update; skip otherwise.
            return;
        }

        double scale2 = scale * scale;
        double firstTermScale = (scale + MatrixUtils.quadraticForm(prevHessInv, prevGrad.data)) / scale2;
        // hessInv += firstTermScale * s s^T
        CommonOps_DDRM.multAddTransB(firstTermScale * stepSize * stepSize, step, step, hessInv);

        double secondTermsScale = -1.0 / scale;
        // tempMatrix = y s^T
        CommonOps_DDRM.multTransB(stepSize, prevGrad, step, tempMatrix); //Could do this in place
        // hessInv -= (H y s^T) / s.y
        CommonOps_DDRM.multAdd(secondTermsScale, prevHessInv, tempMatrix, hessInv);
        // hessInv -= (s y^T H) / s.y
        CommonOps_DDRM.multAddTransA(secondTermsScale, tempMatrix, prevHessInv, hessInv);
    }
}
// GitHub page footer (not part of the source): "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment"