Last active
December 18, 2019 12:26
-
-
Save jonathanBieler/7000344c67dc715c8a52601db7985b5a to your computer and use it in GitHub Desktop.
CMAES.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module CMAES

using Optim, Distributions, Parameters
import Optim: ZerothOrderOptimizer, ZerothOrderState, initial_state, update_state!,
    trace!, assess_convergence, AbstractOptimizerState, update!, value, value!, pick_best_x,
    pick_best_f

"""
    CMA(λ, σ)

Covariance Matrix Adaptation Evolution Strategy (CMA-ES) solver for Optim.jl.

# Fields
- `λ::Int`: population size, i.e. number of candidates sampled per generation. Must be
  ≥ 2 so that at least one parent (`μ = floor(λ/2)`) is selected.
- `σ::Float64`: initial step size (overall scale of the search distribution).
"""
struct CMA <: Optim.ZerothOrderOptimizer
    λ::Int
    σ::Float64
end

"""
    CMAState{T}

Mutable CMA-ES state carried between generations by Optim's iteration loop.
"""
mutable struct CMAState{T} <: Optim.ZerothOrderState
    x::Array{T,1}            # best candidate of the latest generation
    x_previous::Array{T,1}   # used by Optim's default convergence test
    f_x::T                   # objective value at `x`
    f_x_previous::T
    # strategy constants (set once by `init_constants`)
    w::Vector{T}             # recombination weights of the μ best candidates
    D::Int                   # problem dimension
    μ_w::T                   # variance-effective selection mass, 1/∑w²
    c_σ::T                   # learning rate of the step-size path p_σ
    d_σ::T                   # damping of the step-size update
    c_c::T                   # learning rate of the covariance path p_c
    c_1::T                   # learning rate of the rank-one update
    c_μ::T                   # learning rate of the rank-μ update
    # dynamic state
    p_σ::Vector{T}           # step-size evolution path
    p_c::Vector{T}           # covariance evolution path
    C::Matrix{T}             # covariance matrix of the search distribution
    m::Vector{T}             # mean of the search distribution
    chi_D::T                 # approximation of E‖N(0, I)‖ in dimension D
    σ::T                     # current step size
    σ_0::T                   # initial step size
    λ::Int                   # population size
    t::Int                   # generation counter; 1 before the first update
    xs::Vector{Vector{T}}    # candidates of the latest generation, best first
end

∑(x) = sum(x)

"""
    weights(μ)

Return the log-decreasing recombination weights `log(μ+1) - log(i)` for `i = 1:μ`,
normalized so that they sum to one.
"""
function weights(μ)
    raw = [log(μ + 1) - log(i) for i = 1:μ]
    return raw / sum(raw)
end

##
# Build the initial CMA-ES state for `method` starting from `xinit`.
# Throws `ArgumentError` when the population is too small to select a parent or when
# `xinit` is empty (the strategy constants are undefined for D = 0).
function initial_state(method::CMA, options, d, xinit)
    λ = method.λ
    λ >= 2 || throw(ArgumentError("CMA population size λ must be ≥ 2, got $λ"))
    isempty(xinit) && throw(ArgumentError("CMA requires a non-empty initial point"))
    μ = floor(Int, λ/2)
    w = weights(μ)
    D, μ_w, c_σ, d_σ, c_c, c_1, c_μ = init_constants(xinit, λ, w, μ)
    p_σ, p_c = zeros(D), zeros(D)
    C = diagm(ones(D))
    chi_D = √D*(1 - 1/(4*D) + 1/(21*D^2))
    xs = [zeros(D) for i = 1:μ]
    # x, x_previous and m each get their own copy: update_state! writes into
    # x_previous with copy!, which must never clobber the caller's xinit.
    CMAState(
        copy(xinit), copy(xinit), Inf, Inf, w,
        D, μ_w, c_σ, d_σ, c_c, c_1, c_μ,
        p_σ, p_c, C, copy(xinit), chi_D, method.σ, method.σ, λ, 1, xs
    )
end

# Record one trace entry. The reported objective is the best value of the latest
# generation (`state.f_x`); `value(d)` would be the value at whichever candidate was
# evaluated last, which is not the best one.
function trace!(tr, d, state, iteration, method::CMA, options)
    dt = Dict{String,Any}()
    if options.extended_trace
        dt["x"] = state.xs
    end
    g_norm = 0.0  # zeroth-order method: no gradient available
    update!(tr,
            iteration,
            state.f_x,
            g_norm,
            dt,
            options.store_trace,
            options.show_trace,
            options.show_every,
            options.callback)
end

# Run one CMA-ES generation: sample λ candidates from N(m, σ²C), rank them by objective
# value, and update the mean, both evolution paths, C and σ.
# Returns `true` to ask Optim to stop (step-size path collapsed), `false` otherwise.
function update_state!(d, state::CMAState{T}, method::CMA) where {T}
    @unpack w, D, μ_w, c_σ, d_σ, c_c, c_1, c_μ, p_σ, p_c, C, m, chi_D, σ, t, xs = state
    state.f_x_previous = state.f_x
    copy!(state.x_previous, state.x)
    λ = method.λ
    μ = floor(Int, λ/2)

    # Sample the population. The distribution (which factorizes C) is built once per
    # generation instead of once per candidate; the random draws are unchanged.
    dist = MultivariateNormal(zeros(D), C)
    x = [m + σ*rand(dist) for i = 1:λ]
    fx = zeros(T, λ)
    for i = 1:λ
        value!(d, x[i])
        fx[i] = value(d)
    end

    # rank candidates, best first
    idx = sortperm(fx)
    x, fx = x[idx], fx[idx]

    # recombination: the new mean is the weighted average of the μ best candidates
    m_old = m
    m = ∑(w[i] * x[i] for i = 1:μ)
    Δm = m - m_old

    # step-size path, made isotropic by C^(-1/2)
    p_σ = (1 - c_σ)*p_σ + √(c_σ*(2 - c_σ)*μ_w) * C^(-1/2) * Δm/σ

    # Stall the p_c update while p_σ is long. Threshold (Hansen's tutorial):
    # ‖p_σ‖ / √(1-(1-c_σ)^(2g)) < (1.4 + 2/(D+1)) E‖N(0,I)‖, where g is the number of
    # generations completed including this one (= t, since t starts at 1).
    h_σ = norm(p_σ) < √(1 - (1 - c_σ)^(2*t)) * (1.4 + 2/(D+1)) * chi_D ? 1.0 : 0.0

    # covariance path; normalized with its own learning rate c_c
    p_c = (1 - c_c)*p_c + h_σ*√(c_c*(2 - c_c)*μ_w) * Δm/σ

    # Covariance update: decay + rank-one (p_c) + rank-μ. The rank-μ steps are measured
    # from the OLD mean and scaled by the OLD step size.
    C = (1 - c_1 - c_μ + (1 - h_σ)*c_1*c_c*(2 - c_c)) * C +
        c_1 * p_c * p_c' +
        c_μ * ∑(w[i]/σ^2 * ((x[i] - m_old) * (x[i] - m_old)') for i = 1:μ)
    C = (C + C')/2  # keep the symmetric part so C stays a valid covariance

    # cumulative step-size adaptation
    σ = σ * exp(c_σ/d_σ*(norm(p_σ)/chi_D - 1))
    t += 1

    state.x = x[1]
    state.f_x = fx[1]
    xs = x
    @pack state = w, D, μ_w, c_σ, d_σ, c_c, c_1, c_μ, p_σ, p_c, C, m, chi_D, σ, t, xs
    norm(p_σ) < 1e-150 && return true
    false # should the procedure force quit?
end

# Convergence check: Optim's default x/f/g tests, with two overrides.
# - `f_increased` is always false, because the best value can legitimately rise
#   between CMA-ES generations.
# - `converged` is always false. Hansen-style stopping criteria (MaxIter, TolX,
#   TolUpSigma, ConditionCov) are intentionally disabled; the previous version computed
#   them here, including a per-iteration eigendecomposition of C, but never used the
#   results.
function assess_convergence(state::CMAState, d, options)
    x_converged, f_converged, g_converged, converged, f_increased =
        Optim.default_convergence_assessment(state, d, options)
    f_increased = false
    converged = false
    x_converged, f_converged, g_converged, converged, f_increased
end

# Report the best candidate of the latest generation.
# NOTE(review): this is not necessarily the best point seen over the whole run.
pick_best_f(f_increased, state::CMAState, d) = state.f_x
pick_best_x(f_increased, state::CMAState) = state.x

"""
    init_constants(xinit, λ, w, μ)

Return the default CMA-ES strategy constants `(D, μ_w, c_σ, d_σ, c_c, c_1, c_μ)` for the
dimension `D = length(xinit)` and recombination weights `w`. `λ` and `μ` are accepted
but not used.
"""
function init_constants(xinit, λ, w, μ)
    D = length(xinit)
    μ_w = 1.0 / sum(abs2, w)
    c_σ = (μ_w + 2.0) / (D + μ_w + 5.0)
    d_σ = 1.0 + c_σ + 2.0*max(0.0, √((μ_w - 1)/(D + 1)) - 1)
    c_c = (4 + μ_w/D)/(D + 4 + 2μ_w/D)
    c_1 = 2/((D + 1.3)^2 + μ_w)
    c_μ = min(1 - c_1, 2*(μ_w - 2 + 1/μ_w)/((D + 2)^2 + μ_w))
    D, μ_w, c_σ, d_σ, c_c, c_1, c_μ
end

end
## | |
## | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment