Last active
March 27, 2023 14:21
-
-
Save msakai/fcace38de0eb191d8f5fe9d33b116681 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wall #-} | |
{-# LANGUAGE BangPatterns #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
-- translated from https://github.com/chokkan/liblbfgs | |
module LineSearch | |
( Params (..) | |
, defaultParams | |
, lineSearch | |
, lineSearchMoreThuente | |
) where | |
import qualified Numeric.LinearAlgebra as LA | |
import Numeric.LinearAlgebra ((<.>)) | |
-- | Clamp @x@ into the range delimited by @lo@ and @hi@.
--
-- The upper bound is tested first, so if @hi < lo@ a value above @hi@
-- still yields @hi@.
clip :: Ord a => a -> a -> a -> a
clip lo hi x =
  if hi < x
    then hi
    else
      if x < lo
        then lo
        else x
-- | The point halfway between @x@ and @y@, computed as an offset from @x@.
midpoint :: Fractional a => a -> a -> a
midpoint x y = x + halfSpan
  where
    halfSpan = 0.5 * (y - x)
-- | 'True' when @x@ and @y@ have strictly opposite signs.
--
-- @y@ is reduced to its unit sign by dividing by its absolute value, then
-- multiplied into @x@.
signdiff :: (Fractional a, Ord a) => a -> a -> Bool
signdiff x y = x * unitY < 0
  where
    unitY = y / abs y
-- | The minimizer of the cubic interpolating the values and derivatives
-- at two points @u@ and @v@.
--
-- The square-root argument is not clamped, so the result is NaN when it
-- is negative.
cubicMinimizer
  :: (Ord a, Floating a)
  => a -- ^ One point, @u@.
  -> a -- ^ @f(u)@.
  -> a -- ^ @f'(u)@.
  -> a -- ^ The other point, @v@.
  -> a -- ^ @f(v)@.
  -> a -- ^ @f'(v)@.
  -> a
cubicMinimizer u fu du v fv dv = u + ratio * d
  where
    d = v - u
    theta = (fu - fv) * 3 / d + du + dv
    -- Scale by the largest magnitude before squaring.
    s = maximum [abs theta, abs du, abs dv]
    a = theta / s
    root = s * sqrt (a * a - (du / s) * (dv / s))
    gamma
      | v < u = negate root
      | otherwise = root
    numer = gamma - du + theta
    denom = gamma - du + gamma + dv
    ratio = numer / denom
-- | The minimizer of the interpolated cubic, with fallback bounds.
--
-- When the computed ratio is negative and @gamma@ is non-zero, the result
-- is @v - r * d@. Otherwise it falls back to @xmax@ if the scaled @theta@
-- is negative, and to @xmin@ if not. Unlike 'cubicMinimizer', the
-- square-root argument is clamped at zero.
cubicMinimizer2
  :: (Ord a, Floating a)
  => a -- ^ One point, @u@.
  -> a -- ^ @f(u)@.
  -> a -- ^ @f'(u)@.
  -> a -- ^ The other point, @v@.
  -> a -- ^ @f(v)@.
  -> a -- ^ @f'(v)@.
  -> a -- ^ The minimum value, @xmin@.
  -> a -- ^ The maximum value, @xmax@.
  -> a
cubicMinimizer2 u fu du v fv dv xmin xmax
  | ratio < 0.0 && gamma /= 0.0 = v - ratio * d
  | a < 0 = xmax
  | otherwise = xmin
  where
    d = v - u
    theta = (fu - fv) * 3 / d + du + dv
    s = maximum [abs theta, abs du, abs dv]
    a = theta / s
    root = s * sqrt (max 0 (a * a - (du / s) * (dv / s)))
    gamma
      | u < v = negate root
      | otherwise = root
    numer = gamma - dv + theta
    denom = gamma - dv + gamma + du
    ratio = numer / denom
-- | The minimizer of the quadratic determined by @f(u)@, @f'(u)@ and
-- @f(v)@.
quadMinimizer
  :: (Ord a, Fractional a)
  => a -- ^ One point, @u@.
  -> a -- ^ @f(u)@.
  -> a -- ^ @f'(u)@.
  -> a -- ^ The other point, @v@.
  -> a -- ^ @f(v)@.
  -> a
quadMinimizer u fu du v fv = u + offset
  where
    width = v - u
    denom = (fu - fv) / width + du
    offset = du / denom / 2 * width
-- | The minimizer of the quadratic determined by the two derivatives
-- @f'(u)@ and @f'(v)@ (the secant step).
quadMinimizer2
  :: (Ord a, Fractional a)
  => a -- ^ One point, @u@.
  -> a -- ^ @f'(u)@.
  -> a -- ^ The other point, @v@.
  -> a -- ^ @f'(v)@.
  -> a
quadMinimizer2 u du v dv = v + offset
  where
    width = u - v
    offset = dv / (dv - du) * width
-- | Failure conditions reported by the line search in this module.
data Error
  = -- | The trial value lies outside the bracketing interval.
    ERR_OUTOFINTERVAL
  | -- | The function does not decrease in the search direction.
    ERR_INCREASEGRADIENT
  | -- | The maximum trial value is smaller than the minimum one.
    ERR_INCORRECT_TMINMAX
  | -- | An invalid argument (a negative initial step) was given.
    ERR_INVALIDPARAMETERS
  | -- | The step reached the configured maximum step.
    ERR_MAXIMUMSTEP
  | -- | The step reached the configured minimum step.
    ERR_MINIMUMSTEP
  | -- | Rounding errors prevent further progress.
    ERR_ROUNDING_ERROR
  | -- | The relative width of the interval of uncertainty is at most
    -- the configured tolerance.
    ERR_WIDTHTOOSMALL
  | -- | The maximum number of line-search trials was reached.
    ERR_MAXIMUMLINESEARCH
  deriving (Eq, Ord, Enum, Bounded, Show)
{- | Update a safeguarded trial value and the interval for line search.

  The endpoint @x@ is the step with the least function value found so
  far and @t@ is the current step. When the bracketed flag is true, the
  minimizer is known to lie in the interval of uncertainty whose
  endpoints are @x@ and @y@.

  When bracketed, the inputs are validated first: @t@ must lie strictly
  between @x@ and @y@, the function must decrease from @x@ towards @t@,
  and @tmin@ must not exceed @tmax@.
-}
updateTrialInterval
  :: (Ord a, Floating a)
  => (a, a, a) -- ^ Endpoint @x@ with @f(x)@ and @f'(x)@.
  -> (a, a, a) -- ^ Endpoint @y@ with @f(y)@ and @f'(y)@.
  -> (a, a, a) -- ^ Trial value @t@ with @f(t)@ and @f'(t)@.
  -> a -- ^ The minimum value for the trial value, @t@.
  -> a -- ^ The maximum value for the trial value, @t@.
  -> Bool -- ^ Whether the trial value is bracketed.
  -> Either Error (a, (a, a, a), (a, a, a), Bool)
  -- ^ 'Error', or the new trial value, the updated @x@ triple, the
  -- updated @y@ triple, and the updated bracketed flag.
updateTrialInterval (x, fx, dx) (y, fy, dy) (t, ft, dt) tmin tmax brackt
  | brackt && outsideInterval =
      -- The trial value t is out of the interval.
      Left ERR_OUTOFINTERVAL
  | brackt && 0 <= dx * (t - x) =
      -- The function must decrease from x.
      Left ERR_INCREASEGRADIENT
  | brackt && tmax < tmin =
      -- Incorrect tmin and tmax specified.
      Left ERR_INCORRECT_TMINMAX
  | otherwise = Right (trial, best, other, bracketed)
  where
    outsideInterval = t <= min x y || max x y <= t
    higherValue = fx < ft
    oppositeSlopes = signdiff dt dx

    -- Interpolants shared by several cases; each is only evaluated in
    -- the cases that use it.
    cubic = cubicMinimizer x fx dx t ft dt
    secant = quadMinimizer2 x dx t dt

    -- (unclipped candidate, new bracketed flag, apply the 0.66 bound?)
    (candidate, bracketed, bound)
      | higherValue =
          -- Case 1: a higher function value. The minimum is bracketed.
          -- Take the cubic minimizer if it is closer to x than the
          -- quadratic one, otherwise the average of the two.
          let quad = quadMinimizer x fx dx t ft
              pick
                | abs (cubic - x) < abs (quad - x) = cubic
                | otherwise = midpoint cubic quad
           in (pick, True, True)
      | oppositeSlopes =
          -- Case 2: a lower function value and derivatives of opposite
          -- sign. The minimum is bracketed. Take whichever of the cubic
          -- and secant minimizers is farther from t (secant on ties).
          let pick
                | abs (cubic - t) > abs (secant - t) = cubic
                | otherwise = secant
           in (pick, True, False)
      | abs dt < abs dx =
          -- Case 3: a lower function value, derivatives of the same
          -- sign, and a derivative of decreasing magnitude. The bounded
          -- cubic minimizer falls back to tmin or tmax. If bracketed,
          -- take the minimizer closest to t, else the farthest one.
          let bounded = cubicMinimizer2 x fx dx t ft dt tmin tmax
              toCubic = abs (t - bounded)
              toSecant = abs (t - secant)
              pick
                | brackt = if toCubic < toSecant then bounded else secant
                | otherwise = if toCubic > toSecant then bounded else secant
           in (pick, brackt, True)
      | otherwise =
          -- Case 4: a lower function value, derivatives of the same
          -- sign, and a derivative whose magnitude does not decrease.
          -- Unbracketed steps jump to tmax or tmin; bracketed ones use
          -- the cubic through t and y.
          let pick
                | brackt = cubicMinimizer t ft dt y fy dy
                | x < t = tmax
                | otherwise = tmin
           in (pick, brackt, False)

    -- Update the interval of uncertainty. This does not depend on the
    -- new step or on the case analysis above.
    --   Case a: f(x) < f(t)                    => x <- x, y <- t
    --   Case c: f(t) <= f(x), opposite slopes  => x <- t, y <- x
    --   Case b: f(t) <= f(x), same-sign slopes => x <- t, y <- y
    current = (t, ft, dt)
    (best, other)
      | higherValue = ((x, fx, dx), current)
      | oppositeSlopes = (current, (x, fx, dx))
      | otherwise = (current, (y, fy, dy))
    (bestPos, _, _) = best
    (otherPos, _, _) = other

    -- Clip the new trial value into [tmin, tmax].
    clipped = clip tmin tmax candidate

    -- Keep the trial value from getting too close to the far endpoint.
    limit = bestPos + 0.66 * (otherPos - bestPos)
    pastLimit
      | bestPos < otherPos = limit < clipped
      | otherwise = clipped < limit
    trial
      | bracketed && bound && pastLimit = limit
      | otherwise = clipped
-- | Tuning parameters for the line search.
data Params a
  = Params
  { paramsMinStep :: a
    -- ^ Lower limit for the step length.
    --
    -- The default value is @1e-20@. It rarely needs changing, unless the
    -- exponents are too large for the machine being used or the problem
    -- is extremely badly scaled (in which case the exponents should be
    -- increased).
  , paramsMaxStep :: a
    -- ^ Upper limit for the step length.
    --
    -- The default value is @1e+20@. It rarely needs changing, unless the
    -- exponents are too large for the machine being used or the problem
    -- is extremely badly scaled (in which case the exponents should be
    -- increased).
  , paramsFTol :: a
    -- ^ Accuracy parameter of the sufficient decrease condition.
    --
    -- The default value is @1e-4@. It should be greater than zero and
    -- smaller than @0.5@. This is "μ" in [MoreThuente1994].
  , paramsGTol :: a
    -- ^ Accuracy parameter of the directional derivative condition.
    --
    -- The default value is @0.9@. If function and gradient evaluations
    -- are inexpensive relative to the cost of an iteration (sometimes the
    -- case for very large problems), a small value such as @0.1@ may be
    -- advantageous. It should be greater than 'paramsFTol' (@1e-4@) and
    -- smaller than @1.0@. This is "η" in [MoreThuente1994].
  , paramsXTol :: a
    -- ^ The machine precision for floating-point values.
    --
    -- This must be a positive value set by the client program to estimate
    -- the machine precision. The line search stops with
    -- 'ERR_WIDTHTOOSMALL' once the relative width of the interval of
    -- uncertainty is at most this value.
  , paramsMaxLineSearch :: Int
    -- ^ The maximum number of trials for the line search.
    --
    -- This bounds the number of function and gradient evaluations made
    -- by one line search. The default value is @40@.
  }
-- | The default line-search parameters: steps limited to
-- @[1e-20, 1e+20]@, @ftol = 1e-4@, @gtol = 0.9@, @xtol = 1e-16@ and at
-- most 40 trials.
defaultParams :: Floating a => Params a
defaultParams =
  Params
    { paramsMaxLineSearch = 40
    , paramsXTol = 1.0e-16
    , paramsGTol = 0.9
    , paramsFTol = 1e-4
    , paramsMaxStep = 1e+20
    , paramsMinStep = 1e-20
    }
-- | The module's default line search; currently the Moré–Thuente search
-- ('lineSearchMoreThuente'), with the same arguments and result.
lineSearch
  :: forall a. (Ord a, Floating a, LA.Numeric a)
  => Params a
  -> (LA.Vector a -> (a, LA.Vector a))
  -> (LA.Vector a, a, LA.Vector a)
  -> LA.Vector a
  -> a
  -> (Maybe Error, a, (LA.Vector a, a, LA.Vector a))
lineSearch params evaluate start direction step0 =
  lineSearchMoreThuente params evaluate start direction step0
-- | Moré–Thuente line search along the direction @s@ from @x0@.
--
-- Arguments: the parameters; a function returning the objective value
-- and gradient at a point; the starting point with its value and
-- gradient; the search direction; and the initial step.
--
-- The result holds an optional 'Error', the final step, and the point,
-- value and gradient at that step. A negative initial step or a
-- direction with positive directional derivative is rejected
-- immediately, returning step @0@ and the starting triple.
lineSearchMoreThuente
  :: forall a. (Ord a, Floating a, LA.Numeric a)
  => Params a
  -> (LA.Vector a -> (a, LA.Vector a))
  -> (LA.Vector a, a, LA.Vector a)
  -> LA.Vector a
  -> a
  -> (Maybe Error, a, (LA.Vector a, a, LA.Vector a))
lineSearchMoreThuente params evaluate (x0, f0, g0) s step0
  | step0 < 0 = (Just ERR_INVALIDPARAMETERS, 0, (x0, f0, g0))
  | 0 < dg0 = (Just ERR_INCREASEGRADIENT, 0, (x0, f0, g0))
  | otherwise =
      seq dgtest $
        go 0 Nothing False True width0 (2.0 * width0) (0, f0, dg0) (0, f0, dg0) step0
  where
    minStep = paramsMinStep params
    maxStep = paramsMaxStep params
    maxTrials = paramsMaxLineSearch params
    xtol = paramsXTol params

    -- φ(step) = f(x0 + step * s)
    dg0 = g0 <.> s -- φ'(0)
    dgtest = paramsFTol params * dg0 -- μ φ'(0)
    width0 = maxStep - minStep

    -- One trial per call. Arguments: trial count, error from the previous
    -- interval update, bracketed flag, stage-1 flag, interval width,
    -- previous width, best endpoint, other endpoint, proposed step.
    go
      :: Int
      -> Maybe Error
      -> Bool -> Bool
      -> a
      -> a
      -> (a, a, a)
      -> (a, a, a)
      -> a
      -> (Maybe Error, a, (LA.Vector a, a, LA.Vector a))
    go count uinfo brackt stage1 width prevWidth (stx, fx, dgx) (sty, fy, dgy) proposed
      | brackt && (outOfBracket || failed) =
          -- Rounding errors prevent further progress.
          stop ERR_ROUNDING_ERROR
      | step == maxStep && sufficientDecrease && dg <= dgtest =
          -- The step is the maximum value.
          stop ERR_MAXIMUMSTEP
      | step == minStep && (not sufficientDecrease || dgtest <= dg) =
          -- The step is the minimum value.
          stop ERR_MINIMUMSTEP
      | brackt && tooNarrow =
          -- Relative width of the interval of uncertainty is at most xtol.
          stop ERR_WIDTHTOOSMALL
      | outOfTrials =
          -- Maximum number of iterations.
          stop ERR_MAXIMUMLINESEARCH
      | sufficientDecrease && abs dg <= paramsGTol params * (-dg0) =
          -- Both the sufficient decrease condition and the directional
          -- derivative condition hold.
          (Nothing, step, (x, fval, g))
      | otherwise =
          go (count + 1) uinfoNext bracktNext stage1Next widthNext prevWidthNext bestNext otherNext stepNext
      where
        stop e = (Just e, step, (x, fval, g))
        failed = uinfo /= Nothing
        outOfTrials = maxTrials <= count + 1

        -- Minimum and maximum steps for the present interval of
        -- uncertainty.
        (stmin, stmax)
          | brackt = (min stx sty, max stx sty)
          | otherwise = (stx, proposed + 4 * (proposed - stx))
        tooNarrow = stmax - stmin <= xtol * stmax

        -- Clip the step into [minStep, maxStep]; if an unusual
        -- termination is about to occur, use the lowest point so far.
        clipped = clip minStep maxStep proposed
        step
          | brackt && (clipped <= stmin || stmax <= clipped || outOfTrials || failed) = stx
          | brackt && tooNarrow = stx
          | otherwise = clipped
        outOfBracket = step <= stmin || stmax <= step

        -- Current point, function value and gradient.
        x = x0 `LA.add` LA.scale step s
        (fval, g) = evaluate x
        dg = g <.> s

        -- φ(α) <= φ(0) + α μ φ'(0)
        sufficientDecrease :: Bool
        sufficientDecrease = fval <= f0 + step * dgtest

        {-
          In the first stage we seek a step for which the modified
          function has a nonpositive value and nonnegative derivative:
            ψ(α) <= 0  ⇔  φ(α) <= φ(0) + α μ φ'(0)  ⇔  sufficientDecrease
            ψ'(α) >= 0 ⇔  φ'(α) >= μ φ'(0)
          Open questions kept from the original author's notes: is the
          left-hand side of min(μ,η) φ'(0) <= φ'(α) deliberately given a
          little slack? Why take the min when ftol (μ) is supposed to be
          smaller than gtol (η) anyway?
        -}
        stage1Next
          | stage1 && sufficientDecrease && min (paramsFTol params) (paramsGTol params) * dg0 <= dg = False
          | otherwise = stage1

        {- The modified function is used to predict the step only while
           stage 1 is still active and a lower function value has been
           obtained without sufficient decrease.

           With dgtest = μ φ'(0), the modified values below are
           φ(α) - μ φ'(0) α. Original author's question: is the -φ(0)
           term of ψ dropped because it is only a constant offset?
        -}
        useModified = stage1Next && not sufficientDecrease && fval <= fx
        toModified (st, fv, dv) = (st, fv - st * dgtest, dv - dgtest)
        fromModified (st, fv, dv) = (st, fv + st * dgtest, dv + dgtest)
        encode
          | useModified = toModified
          | otherwise = id
        decode
          | useModified = fromModified
          | otherwise = id

        (uinfoNext, stepRaw, bestNext, otherNext, bracktNext) =
          case updateTrialInterval (encode (stx, fx, dgx)) (encode (sty, fy, dgy)) (encode (step, fval, dg)) stmin stmax brackt of
            Left e ->
              -- Keep the interval; the error is reported next iteration.
              (Just e, step, (stx, fx, dgx), (sty, fy, dgy), brackt)
            Right (t, newBest, newOther, newBrackt) ->
              (Nothing, t, decode newBest, decode newOther, newBrackt)

        -- Force a sufficient decrease in the interval of uncertainty.
        (bestPos, _, _) = bestNext
        (otherPos, _, _) = otherNext
        gap = abs (otherPos - bestPos)
        (stepNext, prevWidthNext, widthNext)
          | bracktNext =
              ( if 0.66 * prevWidth <= gap
                  then midpoint bestPos otherPos
                  else stepRaw
              , width
              , gap
              )
          | otherwise = (stepRaw, prevWidth, width)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wall #-} | |
{-# LANGUAGE BangPatterns #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
module QuadProg | |
( QuadProg (..) | |
, evalQuadProgObj | |
, evalQuadProg | |
, solveQuadProg | |
, test_unconstrained_qp | |
, test_unconstrained_qp_2 | |
, test_no_minimum | |
, test_nonbinding_constraints | |
, test_some_binding_constraints | |
, test_some_binding_constraints_2 | |
) where | |
import Control.Monad | |
import Control.Exception (assert) | |
import Data.IntSet (IntSet) | |
import qualified Data.IntSet as IntSet | |
import qualified Data.Vector.Generic as VG | |
import Numeric.LinearAlgebra | |
-- | A quadratic programming problem:
-- $min \{\frac{1}{2} x^T Q x + c^T x \mid A x \le b\}$.
--
-- All four fields are strict.
data QuadProg a
  = QuadProg
      !(Matrix a) -- ^ $Q \in R^{n\times n}$, the quadratic term.
      !(Vector a) -- ^ $c \in R^n$, the linear term.
      !(Matrix a) -- ^ $A \in R^{m\times n}$, the constraint matrix.
      !(Vector a) -- ^ $b \in R^m$, the constraint right-hand side.
  deriving (Show)
-- | Evaluate the objective $\frac{1}{2} x^T Q x + c^T x$ at @x@,
-- ignoring the constraints.
evalQuadProgObj
  :: forall a. (Field a, Ord a, Normed (Vector a), Show a)
  => QuadProg a -> Vector a -> a
evalQuadProgObj (QuadProg qs c _ _) x = quadratic / 2 + linear
  where
    quadratic = x <.> (qs #> x)
    linear = c <.> x
-- | Evaluate the objective at @x@ if @x@ is feasible.
--
-- Returns 'Nothing' when some constraint is violated by more than
-- @tol@, i.e. when a component of @b - A x@ is below @-tol@.
evalQuadProg
  :: forall a. (Field a, Ord a, Normed (Vector a), Show a)
  => QuadProg a -> Vector a -> a -> Maybe a
evalQuadProg qp@(QuadProg _ _ as b) x tol =
  evalQuadProgObj qp x <$ guard feasible
  where
    slack = b `sub` (as #> x)
    feasible = VG.all (>= - tol) slack
-- | Solve a Quadratic Programming (QP) problem using the active set method.
--
-- <http://www.fujilab.dnj.ynu.ac.jp/lecture/system5.pdf>
--
-- The result is the list of visited iterates, starting with the initial
-- solution; its last element is the solution found.
--
-- Calls 'error' if the problem dimensions are inconsistent or if the
-- initial solution violates a constraint by more than the internal
-- tolerance (@1e-8@).
solveQuadProg
  :: forall a. (Field a, Ord a, Normed (Vector a), Show a)
  => QuadProg a -- ^ QP problem
  -> Vector a -- ^ initial solution
  -> [Vector a]
solveQuadProg (QuadProg qs c as b) x0
  | not (and [size qs == (n,n), size c == n, size as == (m,n), size b == m, size x0 == n]) = error "dimension mismatch"
  -- TODO: check positive-definiteness of Q
  | VG.any (< -tol) slack0 = error "infeasible initial solution"
  | otherwise = go ws0 x0
  where
    tol = 1e-8
    n = VG.length x0
    m = rows as
    slack0 = b `sub` (as #> x0)
    wsAll = IntSet.fromList [0 .. m-1]
    -- Initial working set: the constraints that are (nearly) tight at x0.
    ws0 = IntSet.fromList [i | i <- [0..m-1], slack0 VG.! i < tol]

    -- One active-set iteration with working set ws at the point x.
    go :: IntSet -> Vector a -> [Vector a]
    go !ws !x = assert (size as' == (m',n)) $ assert (size x' == n) $ assert (size y' == m') $ (x :) $
      if any (\(_, alpha) -> alpha < 1) alphas then
        -- A constraint outside the working set blocks the full step:
        -- move as far as allowed and add every blocking constraint.
        let alpha = minimum (map snd alphas)
            ws'' = ws `IntSet.union` IntSet.fromList [i | (i, alpha') <- alphas, alpha' <= alpha]
            x'' = x `add` scale alpha d
        in go ws'' x''
      else if VG.all (> - tol) y' then
        -- Converged: no multiplier is below -tol.
        [x']
      else
        -- Drop the constraint with the smallest multiplier.
        go (IntSet.delete (wsV VG.! VG.minIndex y') ws) x
      where
        m' = IntSet.size ws
        wsV :: Vector Int
        wsV = VG.fromListN m' (IntSet.toAscList ws)
        -- Rows of A and entries of b belonging to the working set.
        as' = as ? IntSet.toAscList ws
        b' = VG.fromListN m' [b VG.! i | i <- IntSet.toAscList ws]
        -- Solve the block system for the equality-constrained problem:
        -- x' is its minimizer and y' holds the multipliers.
        (x',y') = VG.splitAt n (mat <\> (scale (-1) c VG.++ b'))
          where
            mat =
              qs ||| tr' as'
              ===
              as' ||| konst 0 (m',m')
        d = x' `sub` x
        -- Step lengths at which each constraint outside the working set
        -- becomes tight along d (only for constraints d moves towards).
        alphas =
          [ (i, alpha)
          | i <- IntSet.toList (wsAll IntSet.\\ ws)
          , let as_i = as ! i
          , let v_i = as_i <.> d
          , v_i > 0
          , let alpha = ((b VG.! i) - as_i <.> x) / v_i
          ]
-- | Element-wise difference @x - y@, built from 'add' and 'scale'.
sub :: (Additive (c t), Linear t c, Num t) => c t -> c t -> c t
sub x y = add x negated
  where
    negated = scale (-1) y
-- https://www.fsb.miamioh.edu/lij14/400_slide_qp.pdf
-- | Example 1: unconstrained QP, encoded with a single all-zero
-- constraint row.
--
-- Yields the last iterate of 'solveQuadProg' and its objective value;
-- expected: ((4,2), -32).
test_unconstrained_qp :: (Vector Double, Double)
test_unconstrained_qp = (x, evalQuadProgObj prob x)
  where
    prob :: QuadProg Double
    prob = QuadProg
      ((2 >< 2) [2,0,0,8])
      (VG.fromList [-8,-16])
      ((1 >< 2) [0,0])
      (VG.fromList [0])
    x0 = VG.fromList [0,0]
    x = last $ solveQuadProg prob x0
-- | Example 1 again, this time with an empty (0-row) constraint system.
--
-- Expected: ((4,2), -32).
test_unconstrained_qp_2 :: (Vector Double, Double)
test_unconstrained_qp_2 = (x, evalQuadProgObj prob x)
  where
    prob :: QuadProg Double
    prob = QuadProg
      ((2 >< 2) [2,0,0,8])
      (VG.fromList [-8,-16])
      ((0 >< 2) [])
      (VG.fromList [])
    x0 = VG.fromList [0,0]
    x = last $ solveQuadProg prob x0
-- | Example 2: a QP problem that has no minimum.
--
-- This should be an error, but 'solveQuadProg' does not detect it.
test_no_minimum :: (Vector Double, Double)
test_no_minimum = (x, evalQuadProgObj prob x)
  where
    prob :: QuadProg Double
    prob = QuadProg
      ((2 >< 2) [2,4,4,8])
      (VG.fromList [-8,-16])
      ((1 >< 2) [0,0])
      (VG.fromList [0])
    x0 = VG.fromList [0,0]
    x = last $ solveQuadProg prob x0
-- | Example 3: constrained QP with non-binding constraints.
--
-- Expected: ((4,2), -32).
test_nonbinding_constraints :: (Vector Double, Double)
test_nonbinding_constraints = (x, evalQuadProgObj prob x)
  where
    prob :: QuadProg Double
    prob = QuadProg
      ((2 >< 2) [2,0,0,8])
      (VG.fromList [-8,-16])
      ((2 >< 2) [-1,-1,-1,0])
      (VG.fromList [-5,-3])
    x0 = VG.fromList [10,10]
    x = last $ solveQuadProg prob x0
-- | Example 4: some constraints are binding.
--
-- Expected: ((4.5, 2), -31.75).
test_some_binding_constraints :: (Vector Double, Double)
test_some_binding_constraints = (x, evalQuadProgObj prob x)
  where
    prob :: QuadProg Double
    prob = QuadProg
      ((2 >< 2) [2,0,0,8])
      (VG.fromList [-8,-16])
      ((2 >< 2) [-1,-1,-1,0])
      (VG.fromList [-5, -4.5])
    x0 = VG.fromList [10, 10]
    x = last $ solveQuadProg prob x0
-- | Example 5: some constraints are binding.
--
-- Expected: ((4.8, 2.2), -31.2).
test_some_binding_constraints_2 :: (Vector Double, Double)
test_some_binding_constraints_2 = (x, evalQuadProgObj prob x)
  where
    prob :: QuadProg Double
    prob = QuadProg
      ((2 >< 2) [2,0,0,8])
      (VG.fromList [-8,-16])
      ((2 >< 2) [-1,-1,-1,0])
      (VG.fromList [-7, -3])
    x0 = VG.fromList [10, 10]
    x = last $ solveQuadProg prob x0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wall #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
module SecondOrder | |
( newtonMethod | |
, gaussNewton | |
, levenbergMarquardt | |
, bfgs | |
, bfgsV | |
, lbfgs | |
, lbfgsV | |
, rosenbrock | |
) where | |
import qualified Data.Foldable as F | |
import Data.Reflection (Reifies) | |
import Data.Sequence (Seq, ViewL (..), (<|)) | |
import qualified Data.Sequence as Seq | |
import qualified Data.Traversable as T | |
import qualified Data.Vector.Generic as VG | |
import Foreign.Storable | |
import Numeric.AD | |
import Numeric.AD.Internal.Reverse (Reverse, Tape) | |
import Numeric.AD.Rank1.Sparse (Sparse) | |
import Numeric.LinearAlgebra | |
import qualified LineSearch as LS | |
-- | Zip a traversable container with a vector, position by position.
--
-- The i-th element of the container (in traversal order, counting from
-- 0) is combined with the i-th entry of the vector; the shape of the
-- container is kept.
zipWithTV :: (Traversable f, Storable b) => (a -> b -> c) -> f a -> Vector b -> f c
zipWithTV u x v = snd (T.mapAccumL step 0 x)
  where
    step i x_i = (i + 1, u x_i (v VG.! i))
-- | Newton's method: the infinite list of iterates starting at @x0@.
--
-- Each step subtracts the solution of @H d = g@, where @g@ and @H@ are
-- the gradient and Hessian of the objective at the current point.
--
-- >>> let sq x = x * x
-- >>> let rosenbrock [x,y] = sq (1 - x) + 100 * sq (y - sq x)
-- >>> rosenbrock [0,0]
-- 1
-- >>> rosenbrock (newtonMethod rosenbrock [0, 0] !! 2) < 0.1
-- True
newtonMethod
  :: forall f a. (Traversable f, Field a)
  => (forall s. f (AD s (Sparse a)) -> AD s (Sparse a))
  -> f a -> [f a]
newtonMethod f x0 = iterate step x0
  where
    n = length x0

    step :: f a -> f a
    step x = zipWithTV (-) x (h <\> g)
      where
        -- Per coordinate: the partial derivative and its Hessian row.
        gh :: f (a, f a)
        gh = snd (hessian' f x)
        entries :: [(a, f a)]
        entries = F.toList gh
        g :: Vector a
        g = fromList (map fst entries)
        h :: Matrix a
        h = (n >< n) (concatMap (F.toList . snd) entries)
-- | Gauss–Newton iteration for a vector-valued residual function: the
-- infinite list of iterates starting at @x0@.
--
-- Each step subtracts the pseudo-inverse of the Jacobian applied to the
-- residual vector.
gaussNewton
  :: forall f g a. (Traversable f, Traversable g, Field a)
  => (forall s. Reifies s Tape => f (Reverse s a) -> g (Reverse s a))
  -> f a -> [f a]
gaussNewton f x0 = iterate step x0
  where
    m = length x0

    step :: f a -> f a
    step x = zipWithTV (-) x (pinv j #> r)
      where
        -- Per residual: its value and its gradient (a Jacobian row).
        rj :: g (a, f a)
        rj = jacobian' f x
        entries :: [(a, f a)]
        entries = F.toList rj
        r :: Vector a
        r = fromList (map fst entries)
        j :: Matrix a
        j = (length rj >< m) (concatMap (F.toList . snd) entries)
-- | 'gaussNewton' iterates for the curve-fitting example from
-- https://en.wikipedia.org/wiki/Gauss%E2%80%93Newton_algorithm
test_gaussNewton :: [[Double]]
test_gaussNewton = gaussNewton f x0
  where
    -- Residuals of the model beta1 * x / (beta2 + x) on the sample data.
    f :: Fractional a => [a] -> [a]
    f [beta1,beta2] = [y - beta1*x / (beta2 + x) | (x,y) <- zip xs ys]
      where
        xs = [0.038, 0.194, 0.425, 0.626, 1.253, 2.500, 3.740]
        ys = [0.050, 0.127, 0.094, 0.2122, 0.2729, 0.2665, 0.3317]
    f _ = error "test_gaussNewton: expected exactly two parameters"
    x0 :: [Double]
    x0 = [0.9, 0.2]
-- | Levenberg–Marquardt algorithm with Tikhonov damping: the infinite
-- list of iterates starting at @x0@.
--
-- The first argument is the initial damping factor. After every step the
-- factor is multiplied by @2/3@ when the ratio of actual to predicted
-- loss change exceeds @3/4@, by @3/2@ when it is below @1/4@, and kept
-- otherwise. The step itself is always taken.
levenbergMarquardt
  :: forall f g a. (Traversable f, Traversable g, Field a, Ord a)
  => a
  -> (forall s. Reifies s Tape => f (Reverse s a) -> g (Reverse s a))
  -> f a -> [f a]
levenbergMarquardt lambda0 f x0 = go lambda0 x0
  where
    m = length x0

    go :: a -> f a -> [f a]
    go lambda x = x : go lambdaNext xNext
      where
        -- Per residual: its value and its gradient (a Jacobian row).
        rj :: g (a, f a)
        rj = jacobian' f x
        entries :: [(a, f a)]
        entries = F.toList rj
        n = length rj
        r :: Vector a
        r = fromList (map fst entries)
        j :: Matrix a
        j = (n >< m) (concatMap (F.toList . snd) entries)

        -- Damped Gauss–Newton matrix J^T J + lambda I and the step.
        gnMat :: Matrix a
        gnMat = unSym (mTm j) `add` diag (VG.replicate m lambda)
        delta :: Vector a
        delta = gnMat <\> scale (-1) (r <# j)
        xNext :: f a
        xNext = zipWithTV (+) x delta

        -- Mean squared residual before and after the step.
        loss, lossNext :: a
        loss = sum [r_j * r_j | r_j <- toList r] / fromIntegral n
        lossNext = sum [r_j * r_j | (r_j, _) <- F.toList (jacobian' f xNext)] / fromIntegral n

        -- Quadratic model of the loss around x.
        approx :: Vector a -> a
        approx d
          = loss
          + (scale (2 / fromIntegral n) r <# j) `dot` d
          + (d `dot` (scale (2 / fromIntegral n) gnMat #> d)) / 2

        -- Ratio of the actual to the predicted change in loss.
        rho :: a
        rho = (lossNext - loss) / (approx delta - approx (VG.replicate m 0))

        lambdaNext :: a
        lambdaNext
          | rho > 3/4 = lambda * 2 / 3
          | rho < 1/4 = lambda * 3 / 2
          | otherwise = lambda
-- | 'levenbergMarquardt' iterates (initial damping 1.0) for the
-- curve-fitting example from
-- https://en.wikipedia.org/wiki/Gauss%E2%80%93Newton_algorithm
test_levenbergMarquardt :: [[Double]]
test_levenbergMarquardt = levenbergMarquardt 1.0 f x0
  where
    -- Residuals of the model beta1 * x / (beta2 + x) on the sample data.
    f :: Fractional a => [a] -> [a]
    f [beta1,beta2] = [y - beta1*x / (beta2 + x) | (x,y) <- zip xs ys]
      where
        xs = [0.038, 0.194, 0.425, 0.626, 1.253, 2.500, 3.740]
        ys = [0.050, 0.127, 0.094, 0.2122, 0.2729, 0.2665, 0.3317]
    f _ = error "test_levenbergMarquardt: expected exactly two parameters"
    x0 :: [Double]
    x0 = [0.9, 0.2]
-- | Broyden–Fletcher–Goldfarb–Shanno algorithm
--
-- https://en.wikipedia.org/wiki/Broyden%E2%80%93Fletcher%E2%80%93Goldfarb%E2%80%93Shanno_algorithm
--
-- Wrapper around 'bfgsV' for any traversable container: the gradient
-- comes from reverse-mode automatic differentiation and each iterate is
-- returned in the shape of @x0@.
bfgs
  :: forall f a. (Traversable f, Field a, Ord a, Normed (Vector a), Show a)
  => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
  -> f a -> [f a]
bfgs f x0 = map fromVector (bfgsV evaluate (toVector x0))
  where
    -- Refill the shape of x0 with the entries of a vector.
    fromVector :: Vector a -> f a
    fromVector = zipWithTV (flip const) x0

    toVector :: f a -> Vector a
    toVector = fromList . F.toList

    -- Objective value and gradient at a point given as a vector.
    evaluate :: Vector a -> (a, Vector a)
    evaluate = fmap toVector . grad' f . fromVector
-- | Limited-memory BFGS
--
-- https://en.wikipedia.org/wiki/Limited-memory_BFGS
--
-- Wrapper around 'lbfgsV' for any traversable container. The first
-- argument is the number of correction pairs to remember; the gradient
-- comes from reverse-mode automatic differentiation and each iterate is
-- returned in the shape of @x0@.
lbfgs
  :: forall f a. (Traversable f, Field a, Ord a, Normed (Vector a), Show a)
  => Int
  -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
  -> f a -> [f a]
lbfgs m f x0 = map fromVector (lbfgsV m evaluate (toVector x0))
  where
    -- Refill the shape of x0 with the entries of a vector.
    fromVector :: Vector a -> f a
    fromVector = zipWithTV (flip const) x0

    toVector :: f a -> Vector a
    toVector = fromList . F.toList

    -- Objective value and gradient at a point given as a vector.
    evaluate :: Vector a -> (a, Vector a)
    evaluate = fmap toVector . grad' f . fromVector
-- | BFGS on plain vectors.
--
-- The first argument returns the objective value and gradient at a
-- point. The result is the list of iterates starting at @x0@; it ends
-- once the gradient norm, relative to @max (norm x) 1@, is at most
-- @1e-5@.
--
-- Calls 'error' if the line search reports an error or if the curvature
-- condition @s . y > 0@ fails.
bfgsV
  :: forall a. (Field a, Ord a, Normed (Vector a), Show a)
  => (Vector a -> (a, Vector a))
  -> Vector a -> [Vector a]
bfgsV f x0 = go (ident n) alpha0 (x0, o0, g0)
  where
    n = VG.length x0
    (o0, g0) = f x0
    -- The first line search starts from the step 1 / |g0|.
    alpha0 :: a
    alpha0 = realToFrac (1 / norm_2 g0)
    epsilon :: Double
    epsilon = 1e-5

    -- Arguments: inverse Hessian approximation, initial step for the
    -- line search, and the current point with its value and gradient.
    go :: Matrix a -> a -> (Vector a, a, Vector a) -> [Vector a]
    go bInv alpha_ (x, o, g) = x : rest
      where
        rest
          | converged = []
          | otherwise =
              case err of
                Just e -> error (show e)
                Nothing
                  | sy > 0 -> go bInvNext 1.0 (x', o', g')
                  | otherwise -> error ("curvature condition failed: " ++ show sy)

        converged :: Bool
        converged = norm_2 g / max (norm_2 x) 1 <= epsilon

        -- Quasi-Newton search direction.
        p :: Vector a
        p = scale (-1) (bInv #> g)
        (err, alpha, (x', o', g')) = LS.lineSearch LS.defaultParams f (x, o, g) p alpha_

        -- Step taken and change in gradient.
        s, y :: Vector a
        s = scale alpha p
        y = g' `add` scale (-1) g
        sy :: a
        sy = s <.> y

        -- BFGS update of the inverse Hessian approximation.
        hy, yh :: Vector a
        hy = bInv #> y
        yh = y <# bInv
        rankOne, cross :: Matrix a
        rankOne = scale ((sy + y <.> hy) / sy**2) (s `outer` s)
        cross = scale (-1 / sy) ((hy `outer` s) `add` (s `outer` yh))
        bInvNext :: Matrix a
        bInvNext = (bInv `add` rankOne) `add` cross
-- | Limited-memory BFGS on plain vectors.
--
-- The first argument is the number of most recent @(s, y, rho)@
-- correction pairs to keep; the second returns the objective value and
-- gradient at a point. The result is the list of iterates starting at
-- @x0@; it ends once the gradient norm, relative to @max (norm x) 1@,
-- is at most @1e-5@.
--
-- Calls 'error' if the line search reports an error or if the curvature
-- condition @s . y > 0@ fails.
lbfgsV
  :: forall a. (Field a, Ord a, Normed (Vector a), Show a)
  => Int
  -> (Vector a -> (a, Vector a))
  -> Vector a -> [Vector a]
lbfgsV m f x0 = go Seq.empty alpha0 (x0, o0, g0)
  where
    (o0, g0) = f x0
    -- The first line search starts from the step 1 / |g0|.
    alpha0 :: a
    alpha0 = realToFrac (1 / norm_2 g0)
    epsilon :: Double
    epsilon = 1e-5

    -- Arguments: correction history (newest first), initial step for the
    -- line search, and the current point with its value and gradient.
    go :: Seq (Vector a, Vector a, a) -> a -> (Vector a, a, Vector a) -> [Vector a]
    go hist alpha_ (x, o, g) = x : rest
      where
        rest
          | converged = []
          | otherwise =
              case err of
                Just e -> error (show e)
                Nothing
                  | sy > 0 -> go (Seq.take m ((s, y, rho) <| hist)) 1.0 (x', o', g')
                  | otherwise -> error ("curvature condition failed: " ++ show sy)

        converged :: Bool
        converged = norm_2 g / max (norm_2 x) 1 <= epsilon

        -- Initial inverse Hessian: identity when there is no history,
        -- else scaled by (s . y) / (y . y) of the newest pair.
        initial :: Vector a -> Vector a
        initial q =
          case Seq.viewl hist of
            EmptyL -> q
            (sNew, yNew, _) :< _ -> scale (sNew <.> yNew / yNew <.> yNew) q

        -- Two-loop recursion, written recursively: the way down handles
        -- pairs newest to oldest, the way back up oldest to newest.
        twoLoop :: [(Vector a, Vector a, a)] -> Vector a -> Vector a
        twoLoop [] q = initial q
        twoLoop ((si, yi, rhoi) : older) q = z `add` scale (ai - bi) si
          where
            ai = rhoi * (si <.> q)
            z = twoLoop older (q `add` scale (- ai) yi)
            bi = rhoi * (yi <.> z)

        -- Quasi-Newton search direction.
        p :: Vector a
        p = scale (-1) (twoLoop (F.toList hist) g)
        (err, alpha, (x', o', g')) = LS.lineSearch LS.defaultParams f (x, o, g) p alpha_

        -- Step taken, change in gradient, and the new correction pair.
        s, y :: Vector a
        s = scale alpha p
        y = g' `add` scale (-1) g
        sy, rho :: a
        sy = s <.> y
        rho = 1 / sy
-- | 'bfgs' iterates minimizing the sum of squared residuals of the
-- curve-fitting example from
-- https://en.wikipedia.org/wiki/Gauss%E2%80%93Newton_algorithm
test_BFGS :: [[Double]]
test_BFGS = bfgs f x0
  where
    -- Sum of squared residuals of beta1 * x / (beta2 + x).
    f :: (Fractional a, Floating a) => [a] -> a
    f [beta1,beta2] = sum [(y - beta1*x / (beta2 + x))**2 | (x,y) <- zip xs ys]
      where
        xs = [0.038, 0.194, 0.425, 0.626, 1.253, 2.500, 3.740]
        ys = [0.050, 0.127, 0.094, 0.2122, 0.2729, 0.2665, 0.3317]
    f _ = error "test_BFGS: expected exactly two parameters"
    x0 :: [Double]
    x0 = [0.9, 0.2]
-- | 'lbfgs' iterates (history size 10) minimizing the sum of squared
-- residuals of the curve-fitting example from
-- https://en.wikipedia.org/wiki/Gauss%E2%80%93Newton_algorithm
test_LBFGS :: [[Double]]
test_LBFGS = lbfgs 10 f x0
  where
    -- Sum of squared residuals of beta1 * x / (beta2 + x).
    f :: (Fractional a, Floating a) => [a] -> a
    f [beta1,beta2] = sum [(y - beta1*x / (beta2 + x))**2 | (x,y) <- zip xs ys]
      where
        xs = [0.038, 0.194, 0.425, 0.626, 1.253, 2.500, 3.740]
        ys = [0.050, 0.127, 0.094, 0.2122, 0.2729, 0.2665, 0.3317]
    f _ = error "test_LBFGS: expected exactly two parameters"
    x0 :: [Double]
    x0 = [0.9, 0.2]
-- | The two-dimensional Rosenbrock function
-- @(1 - x)^2 + 100 * (y - x^2)^2@.
--
-- The argument must be a list of exactly two elements; any other length
-- calls 'error'.
rosenbrock :: Floating a => [a] -> a
rosenbrock [x,y] = sq (1 - x) + 100 * sq (y - sq x)
  where
    -- Note that 'sq x = x * x' did not work with Kahn mode.
    -- https://github.com/ekmett/ad/pull/84
    sq v = v ** 2
rosenbrock _ = error "rosenbrock: expected exactly two coordinates"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment