Created
April 23, 2013 00:16
-
-
Save anonymous/5439754 to your computer and use it in GitHub Desktop.
Strong wolfe line search
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import breeze.linalg._ | |
import breeze.numerics._ | |
import breeze.optimize.DiffFunction | |
/**
 * Base class for one-dimensional line searches.
 *
 * Provides the `Bracket` sample type and a safeguarded cubic interpolation
 * step shared by concrete line searches.
 */
abstract class LineSearch extends Log {

  /**
   * A sample of the one-dimensional line-search function.
   *
   * @param t    1d line search parameter
   * @param dd   directional derivative at t
   * @param fval function value at t
   */
  case class Bracket(
    t: Double,
    dd: Double,
    fval: Double
  )

  /**
   * Invoke line search, returning stepsize.
   */
  def search(f: DiffFunction[Double], props: OptProps): Double

  /**
   * Cubic interpolation to find the minimum inside the bracket l and r.
   * Uses the fval and gradient at the left and right side, which gives
   * the four bits of information required to interpolate a cubic.
   *
   * The result is "safe-guarded": a step closer than 10% of the bracket
   * width to either end is moved onto that 10% / 90% limit. If the cubic
   * formula yields NaN (negative discriminant under the square root, or a
   * degenerate bracket with l.t == r.t), the midpoint of the bracket is
   * returned instead, so a NaN step is never handed back to the caller
   * merely because the cubic model has no minimiser.
   *
   * @param l bracket end with the smaller t (callers are expected to order the ends)
   * @param r bracket end with the larger t
   * @return the safeguarded trial step
   */
  def interp(l: Bracket, r: Bracket): Double = {
    // See Nocedal & Wright, "Numerical Optimization", section 3.5 for the math.
    val width = r.t - l.t
    val d1 = l.dd + r.dd - 3 * (l.fval - r.fval) / (l.t - r.t)
    val d2 = sqrt(d1 * d1 - l.dd * r.dd)
    val t = r.t - width * (r.dd + d2 - d1) / (r.dd - l.dd + 2 * d2)

    // If t is too close to either end bracket, move it closer to the middle
    val lbound = l.t + 0.1 * width
    val ubound = l.t + 0.9 * width

    if (t.isNaN) {
      // Both comparisons below are false for NaN, so without this branch the
      // NaN would slip through the safeguard untouched. Bisect instead.
      val mid = l.t + 0.5 * width
      debug("Cubic " + t + " is not a number, bisecting to: " + mid)
      mid
    } else if (t < lbound) {
      debug("Cubic " + t + " below LHS limit: " + lbound)
      lbound
    } else if (t > ubound) {
      debug("Cubic " + t + " above RHS limit: " + ubound)
      ubound
    } else {
      t
    }
  }
}
/**
 * This line search will attempt steps larger than step length one,
 * unlike back-tracking line searches. It selects step lengths using
 * safeguarded cubic interpolation and returns a step satisfying the
 * strong Wolfe conditions.
 * Based on Nocedal & Wright, "Numerical Optimization".
 */
object StrongWolfeLineSearch extends LineSearch {

  // Constant of the "sufficient decrease" condition.
  val c1 = 1e-4
  // Constant of the "curvature" condition.
  val c2 = 0.9

  /**
   * Performs a line search on the function f, returning a point satisfying
   * the Strong Wolfe conditions. Based on the line search detailed in
   * Nocedal & Wright Numerical Optimization p58.
   *
   * @param f     one-dimensional function; `f.calculate(t)` yields the
   *              function value and directional derivative at step t
   * @param props supplies `maxLineSearchIter` (bracketing iterations) and
   *              `maxZoomIter` (zoom iterations)
   * @return a step length satisfying the strong Wolfe conditions
   * @throws Exception if the directional derivative at t = 0 is positive,
   *                   or if either iteration budget is exhausted
   */
  def search(f: DiffFunction[Double],
             props: OptProps): Double = {

    // Evaluate the objective at step t and package it as a Bracket.
    def phi(t: Double): Bracket = {
      val (pval, pdd) = f.calculate(t)
      Bracket(t = t, dd = pdd, fval = pval)
    }

    // True for NaN and for either infinity.
    def nonFinite(x: Double): Boolean = x.isNaN || x.isInfinite

    var t = 1.0 // Search's current multiple of pk
    var low = phi(0.0)
    val fval = low.fval
    val dd = low.dd

    if (dd > 0) {
      throw new Exception("Line search invoked with non-descent direction: " + dd)
    }

    /**
     * Assuming a point satisfying the strong wolfe conditions exists within
     * the passed interval, this method finds it by iteratively refining the
     * interval. Nocedal & Wright give the following invariants for zoom's loop:
     *
     *  - The interval bounded by low.t and hi.t contains a point satisfying the
     *    strong Wolfe conditions.
     *  - Among all points evaluated so far that satisfy the "sufficient decrease"
     *    condition, low.t is the one with the smallest fval.
     *  - hi.t is chosen so that low.dd * (hi.t - low.t) < 0.
     *
     * Throws an Exception once props.maxZoomIter trial points have been used.
     */
    def zoom(linit: Bracket, rinit: Bracket): Double = {
      var low = linit
      var hi = rinit
      val maxIter = props.maxZoomIter

      // A plain while loop keeps the `return`s below local to zoom; inside a
      // for-comprehension they would be non-local returns from a closure.
      var i = 0
      while (i < maxIter) {
        // Interp assumes left less than right in t value, so flip if needed
        val cubic = if (low.t > hi.t) interp(hi, low) else interp(low, hi)
        // A non-finite cubic step (e.g. when one end of the bracket carries a
        // non-finite function value) cannot be evaluated; bisect instead.
        val t = if (nonFinite(cubic)) 0.5 * (low.t + hi.t) else cubic

        // Evaluate objective at t, and build bracket
        val c = phi(t)
        info("Line search t: " + t + " fval: " + c.fval +
          " rhs: " + (fval + c1 * c.t * dd) + " cdd: " + c.dd)

        // Update left or right bracket, or both.
        // A non-finite function value is treated like a failed "sufficient
        // decrease" test: every comparison with NaN is false, so without the
        // explicit check a NaN point would be installed as the new `low`.
        if (nonFinite(c.fval) || c.fval > fval + c1 * c.t * dd || c.fval >= low.fval) {
          // "Sufficient decrease" condition not satisfied by c. Shrink interval at right
          hi = c
          debug("hi=c")
        } else {
          // Zoom exit condition is the "curvature" condition
          // Essentially that the directional derivative is large enough
          if (abs(c.dd) <= c2 * abs(dd)) {
            return c.t
          }

          // If the signs don't coincide, flip left to right before updating l to c
          if (c.dd * (hi.t - low.t) >= 0) {
            debug("flipping")
            hi = low
          }

          debug("low=c")
          // If curvature condition not satisfied, move the left hand side of the
          // interval further away from t=0.
          low = c
        }
        i += 1
      }

      throw new Exception("Line search zoom failed")
    }

    // Bracketing phase: grow t until an interval containing an acceptable
    // step is found, then hand that interval to zoom.
    val maxIter = props.maxLineSearchIter
    var i = 0
    while (i < maxIter) {
      val c = phi(t)

      // If phi has a bounded domain, inf or nan usually indicates we took
      // too large a step.
      if (nonFinite(c.fval)) {
        t /= 2.0
        error("Encountered bad values in function evaluation. Decreasing step size to " + t)
      } else {
        // Zoom if "sufficient decrease" condition is not satisfied
        if ((c.fval > fval + c1 * t * dd) ||
            (c.fval >= low.fval && i > 0)) {
          debug("Line search t: " + t + " fval: " + c.fval + " cdd: " + c.dd)
          return zoom(low, c)
        }

        // We don't need to zoom at all
        // if the strong wolfe condition is satisfied already.
        if (abs(c.dd) <= c2 * abs(dd)) {
          return c.t
        }

        // If c.dd is positive, we zoom on the inverted interval.
        // Occurs if we skipped over the nearest local minimum
        // over to the next one.
        if (c.dd >= 0) {
          debug("Line search t: " + t + " fval: " + c.fval +
            " rhs: " + (fval + c1 * t * dd) + " cdd: " + c.dd)
          return zoom(c, low)
        }

        low = c
        t *= 1.5
        debug("Sufficent Decrease condition but not curvature condition satisfied. Increased t to: " + t)
      }
      i += 1
    }

    throw new Exception("Line search failed")
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment