// Gist sslipchenko/b923c9d2cac8692e614daeb4dd1910b8 (last active August 11, 2016 13:36)
// Optimization stages of the Sequential Minimal Optimization (SMO) algorithm in F#
// Note (GitHub): this file contains bidirectional Unicode text that may be interpreted or compiled
// differently than it appears; review it in an editor that reveals hidden Unicode characters.
// Copyright 2016 Serge Slipchenko (Serge.Slipchenko@gmail.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace Semagle.MachineLearning.SVM

open LanguagePrimitives

/// Implementation of Sequential Minimal Optimization (SMO) algorithm
module SMO =
    /// Value substituted for a non-positive quadratic coefficient K_ii + K_jj - 2 K_ij
    [<Literal>]
    let private tau = 1e-12

    /// Sentinel index returned when no candidate element exists
    [<Literal>]
    let private not_found = -1

    /// Working set selection strategy: first order (maximal violating pair)
    /// or second order information
    type WSSStrategy = MaximalViolatingPair | SecondOrderInformation

    /// C-SVM training parameters: iteration limit, stopping tolerance,
    /// penalty parameter C and working set selection strategy
    type C_SVM = { iterations : int; epsilon : float; C : float; strategy : WSSStrategy }

    /// <summary>Trains a C-support vector classifier with the SMO algorithm.</summary>
    /// <param name="X">training samples</param>
    /// <param name="Y">labels of the samples; every label must be +1.0 or -1.0</param>
    /// <param name="K">kernel function applied to pairs of samples</param>
    /// <param name="parameters">iteration limit, tolerance epsilon, penalty C and strategy</param>
    /// <returns>SVM built from the kernel, the support vectors, their coefficients
    /// Y.[i]*A.[i] and the bias</returns>
    /// <exception cref="System.ArgumentException">X and Y have different lengths,
    /// or Y contains a label other than +1.0 / -1.0</exception>
    /// <exception cref="System.Exception">the optimization did not stop within
    /// parameters.iterations iterations</exception>
    let C_SVC (X : 'X[]) (Y : float[]) (K : Kernel<'X>) (parameters : C_SVM) =
        if Array.length X <> Array.length Y then
            invalidArg "X and Y" "have different lengths"
        // I_up / I_low below only recognize +1.0 and -1.0; any other label would be silently ignored
        if Array.exists (fun y -> y <> +1.0 && y <> -1.0) Y then
            invalidArg "Y" "must contain only +1.0 and -1.0 labels"

        let info = printfn

        let N = Array.length X
        let C = Array.create N parameters.C
        let epsilon = parameters.epsilon

        let A : float[] = Array.zeroCreate N // optimization variables
        let G = Array.create N (-1.0) // gradient of 1/2 A'QA - e'A at A = 0

        // helper functions
        let inline _y_gf i = -G.[i]*Y.[i]
        let inline Q i j = (K X.[i] X.[j])*Y.[i]*Y.[j]
        // K_ii + K_jj - 2 K_ij, replaced by tau when it is not strictly positive
        // (a zero coefficient, e.g. for duplicate samples, would otherwise divide by zero)
        let inline quad i j =
            let a = (K X.[i] X.[i]) + (K X.[j] X.[j]) - 2.0*(K X.[i] X.[j])
            if a > 0.0 then a else tau
        let I_up () = seq { 0..N-1 } |> Seq.filter (fun i -> (Y.[i] = +1.0 && A.[i] < C.[i]) || (Y.[i] = -1.0 && A.[i] > 0.0))
        let I_low () = seq { 0..N-1 } |> Seq.filter (fun i -> (Y.[i] = -1.0 && A.[i] < C.[i]) || (Y.[i] = +1.0 && A.[i] > 0.0))

        // m(A) = max of -y_t*G_t over I_up and M(A) = min over I_low;
        // an empty set yields -infinity / +infinity respectively
        let violationBounds () =
            let m = I_up () |> Seq.fold (fun acc t -> max acc (_y_gf t)) (-infinity)
            let M = I_low () |> Seq.fold (fun acc t -> min acc (_y_gf t)) infinity
            m, M

        let inline maxUp () =
            let ts = I_up ()
            if Seq.isEmpty ts then not_found else Seq.maxBy _y_gf ts
        let inline minLow () =
            let ts = I_low ()
            if Seq.isEmpty ts then not_found else Seq.minBy _y_gf ts
        // element of I_low that forms a violating pair with s and gives the largest
        // second order decrease -b_ts^2 / a_ts of the objective
        let inline minLowTo s =
            let ts = I_low () |> Seq.filter (fun t -> _y_gf t < _y_gf s)
            let objective t =
                let b_ts = _y_gf t - _y_gf s
                -b_ts*b_ts/(quad t s)
            if Seq.isEmpty ts then not_found else Seq.minBy objective ts

        // Maximal violating pair working set selection strategy;
        // None when no pair violates the optimality condition
        let maximalViolatingPair () =
            let i = maxUp ()
            if i = not_found then None
            else
                let j = minLow ()
                if j = not_found || _y_gf j >= _y_gf i then None else Some (i, j)

        // Second order information working set selection strategy;
        // None when no element of I_low violates the optimality condition together with i
        let secondOrderInformation () =
            let i = maxUp ()
            if i = not_found then None
            else
                let j = minLowTo i
                if j = not_found then None else Some (i, j)

        let selectWorkingSet =
            match parameters.strategy with
            | MaximalViolatingPair -> maximalViolatingPair
            | SecondOrderInformation -> secondOrderInformation

        // Solve the two-variable sub-problem for (i, j) and clip the result to the box [0, C]
        let inline solve i j =
            let a_ij = quad i j
            if Y.[i] <> Y.[j] then
                // A.[i] - A.[j] stays constant
                let delta = (-G.[i]-G.[j])/a_ij
                let diff = A.[i] - A.[j]
                match (A.[i] + delta, A.[j] + delta) with
                | _, a_j when diff > 0.0 && a_j < 0.0 -> (diff, 0.0)
                | a_i, _ when diff <= 0.0 && a_i < 0.0 -> (0.0, -diff)
                | _, a_j when diff <= C.[i] - C.[j] && a_j > C.[j] -> (C.[j] + diff, C.[j])
                | a_i, _ when diff > C.[i] - C.[j] && a_i > C.[i] -> (C.[i], C.[i] - diff)
                | a_i, a_j -> (a_i, a_j)
            else
                // A.[i] + A.[j] stays constant
                let delta = (G.[i]-G.[j])/a_ij
                let sum = A.[i] + A.[j]
                match (A.[i] - delta, A.[j] + delta) with
                | a_i, _ when sum > C.[i] && a_i > C.[i] -> (C.[i], sum - C.[i])
                | _, a_j when sum <= C.[i] && a_j < 0.0 -> (sum, 0.0)
                | _, a_j when sum > C.[j] && a_j > C.[j] -> (sum - C.[j], C.[j])
                | a_i, _ when sum <= C.[j] && a_i < 0.0 -> (0.0, sum)
                | a_i, a_j -> (a_i, a_j)

        // Update gradient for the new values a_i, a_j (A still holds the old ones)
        let inline updateG i j a_i a_j =
            let d_i = a_i - A.[i]
            let d_j = a_j - A.[j]
            for t in 0..N-1 do
                G.[t] <- G.[t] + (Q t i)*d_i + (Q t j)*d_j

        // Stop when m(A) - M(A) < epsilon (also true when I_up or I_low is empty)
        let stopCriterion () =
            let m, M = violationBounds ()
            m - M < epsilon

        // Sequential Minimal Optimization (SMO) Algorithm; k = iterations performed so far
        let rec optimize k =
            if k < parameters.iterations then
                // 1. Find a pair of elements that violate the optimality condition
                match selectWorkingSet () with
                | Some (i, j) ->
                    // 2. Solve the optimization sub-problem
                    let a_i, a_j = solve i j
                    // 3. Update the gradient and the solution
                    updateG i j a_i a_j
                    A.[i] <- a_i
                    A.[j] <- a_j
                    if stopCriterion () then k + 1 else optimize (k + 1)
                | None -> k
            else
                failwith "Exceeded iterations limit"

        let iterations = optimize 0
        info "#iterations = %d" iterations

        // Reconstruction of hyperplane bias: average of -y_i*G_i over free variables,
        // or the midpoint (m(A) + M(A))/2 when no variable is strictly inside (0, C)
        let bias =
            let free = [| for i in 0 .. N-1 do if 0.0 < A.[i] && A.[i] < C.[i] then yield _y_gf i |]
            if free.Length > 0 then
                DivideByInt (Array.sum free) free.Length
            else
                let m, M = violationBounds ()
                match m > -infinity, M < infinity with
                | true, true -> (m + M) / 2.0
                | true, false -> m
                | false, true -> M
                | false, false -> 0.0

        // Keep only support vectors (A.[i] <> 0.0) with coefficients Y.[i]*A.[i]
        let sv = Array.filter (fun i -> A.[i] <> 0.0) [| 0 .. N-1 |]
        let X' = Array.map (fun i -> X.[i]) sv
        let A' = Array.map (fun i -> Y.[i]*A.[i]) sv
        info "support vectors = %d" sv.Length
        SVM(K, X', A', bias)
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.