CMAES.jl
module CMAES

using Optim, Distributions, Parameters
using LinearAlgebra # norm, eigvals, I (required on Julia ≥ 0.7)

import Optim: ZerothOrderOptimizer, ZerothOrderState, initial_state, update_state!,
    trace!, assess_convergence, AbstractOptimizerState, update!, value, value!,
    pick_best_x, pick_best_f
# CMA-ES method: λ is the population size, σ the initial step size.
struct CMA <: Optim.ZerothOrderOptimizer
    λ::Int
    σ::Float64
end
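# e.g. CMA(12, 0.5): a population of 12 samples with initial step size 0.5.
# These values are illustrative, not from the original gist; λ = 4 + floor(3*log(D))
# is the usual default population size in Hansen's CMA-ES tutorial.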
# Optimizer state; notation follows Hansen's CMA-ES tutorial.
mutable struct CMAState{T} <: Optim.ZerothOrderState
    x::Array{T,1}
    x_previous::Array{T,1} # used by the default convergence test
    f_x::T
    f_x_previous::T
    # strategy parameters
    w::Vector{T}    # recombination weights
    D::Int          # problem dimension
    μ_w::T          # variance-effective selection mass
    c_σ::T          # step-size path learning rate
    d_σ::T          # step-size damping
    c_c::T          # covariance path learning rate
    c_1::T          # rank-one update learning rate
    c_μ::T          # rank-μ update learning rate
    # evolving quantities
    p_σ::Vector{T}  # step-size evolution path
    p_c::Vector{T}  # covariance evolution path
    C::Matrix{T}    # covariance matrix
    m::Vector{T}    # distribution mean
    chi_D::T        # E‖N(0,I)‖, used to normalize ‖p_σ‖
    σ::T            # current step size
    σ_0::T          # initial step size
    λ::Int          # population size
    t::Int          # iteration counter
    xs::Vector{Vector{T}} # current population
end
∑(x) = sum(x)

# Log-decreasing recombination weights for the μ best samples, normalized to sum to one.
weights(μ) = [(log(μ+1) - log(i)) for i = 1:μ] / ∑(log(μ+1) - log(j) for j = 1:μ)
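# For example (an illustrative check, not part of the original gist):
# weights(3) ≈ [0.586, 0.293, 0.122], strictly decreasing and summing to 1.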
##
function initial_state(method::CMA, options, d, xinit)
    λ = method.λ
    D = length(xinit)
    μ = floor(Int, λ/2)
    w = weights(μ)
    D, μ_w, c_σ, d_σ, c_c, c_1, c_μ = init_constants(xinit, λ, w, μ)
    p_σ, p_c = zeros(D), zeros(D)
    C = Matrix(1.0I, D, D) # initial covariance: identity
    m = xinit
    # expectation of ‖N(0,I)‖, used to normalize the step-size path
    chi_D = √D*(1 - 1/(4*D) + 1/(21*D^2))
    xs = [zeros(D) for i = 1:μ]
    σ_0 = method.σ
    CMAState(
        xinit, xinit, Inf, Inf, w,
        D, μ_w, c_σ, d_σ, c_c, c_1, c_μ,
        p_σ, p_c, C, m, chi_D, method.σ, σ_0, λ, 1, xs
    )
end
function trace!(tr, d, state, iteration, method::CMA, options)
    dt = Dict()
    if options.extended_trace
        dt["x"] = state.xs
    end
    g_norm = 0.0 # no gradient for a zeroth-order method
    update!(tr,
        iteration,
        value(d),
        g_norm,
        dt,
        options.store_trace,
        options.show_trace,
        options.show_every,
        options.callback)
end
function update_state!(d, state::CMAState{T}, method::CMA) where T
    @unpack w, D, μ_w, c_σ, d_σ, c_c, c_1, c_μ, p_σ, p_c, C, m, chi_D, σ, t, xs = state

    state.f_x_previous = state.f_x
    copyto!(state.x_previous, state.x)

    λ = method.λ
    μ = floor(Int, λ/2)

    # sample λ candidates from N(m, σ²C) and evaluate them
    x = [m + σ*rand(MvNormal(zeros(D), C)) for i = 1:λ]
    fx = zeros(T, λ)
    for i = 1:λ
        value!(d, x[i])
        fx[i] = value(d)
    end

    # rank by fitness and recombine the μ best into the new mean
    idx = sortperm(fx)
    x, fx = x[idx], fx[idx]
    m_t = m
    m = ∑(w[i] * x[i] for i = 1:μ)
    Δm = m - m_t

    # update the evolution paths; h_σ stalls the covariance path when ‖p_σ‖ is too large
    p_σ = (1-c_σ)*p_σ + √(c_σ*(2-c_σ)*μ_w) * C^(-1/2) * Δm/σ
    h_σ = norm(p_σ) < √(1 - (1-c_σ)^(2*(t+1)))*(1.4 + 2/(D+1))*chi_D ? 1.0 : 0.0
    p_c = (1-c_c)*p_c + h_σ*√(c_c*(2-c_c)*μ_w) * Δm/σ

    # rank-one and rank-μ covariance update
    C = (1 - c_1 - c_μ + (1-h_σ)*c_1*c_c*(2-c_c)) * C +
        c_1 * p_c * p_c' +
        c_μ * ∑(w[i]/σ^2 * ((x[i]-m) * (x[i]-m)') for i = 1:μ)
    C = (C + C')/2 # keep symmetric part

    # step-size adaptation
    σ = σ * exp(c_σ/d_σ * (norm(p_σ)/chi_D - 1))
    t += 1

    state.x = x[1]
    state.f_x = fx[1]
    xs = x
    @pack! state = w, D, μ_w, c_σ, d_σ, c_c, c_1, c_μ, p_σ, p_c, C, m, chi_D, σ, t, xs

    norm(p_σ) < 1e-150 && return true # numerically dead path: force quit
    false # should the procedure force quit?
end
function assess_convergence(state::CMAState, d, options)
    x_converged, f_converged, g_converged, converged, f_increased =
        Optim.default_convergence_assessment(state, d, options)
    f_increased = false # disabled, since the error can increase with CMA-ES

    @unpack D, λ, p_c, σ, σ_0, t, C = state
    evals = eigvals(C) # FIXME maybe a bit costly
    MaxIter = t > 100 + 50*(D+3)^2/√λ
    TolX = all(p_c * σ/σ_0 .< 1e-12)
    TolUpSigma = σ/σ_0 > 1e20 * √(maximum(evals))
    ConditionCov = abs(maximum(evals)) / abs(minimum(evals)) > 1e14
    converged = false # TolX #|| TolUpSigma || ConditionCov

    x_converged, f_converged, g_converged, converged, f_increased
end
pick_best_f(f_increased, state::CMAState, d) = state.f_x
pick_best_x(f_increased, state::CMAState) = state.x
# Default strategy parameters, following Hansen's CMA-ES tutorial.
function init_constants(xinit, λ, w, μ)
    D = length(xinit)
    μ_w = 1.0 / sum(w.^2)                               # variance-effective selection mass
    c_σ = (μ_w + 2.0) / (D + μ_w + 5.0)                 # step-size path learning rate
    d_σ = 1.0 + c_σ + 2.0*max(0, √((μ_w-1)/(D+1)) - 1)  # step-size damping
    c_c = (4 + μ_w/D) / (D + 4 + 2μ_w/D)                # covariance path learning rate
    c_1 = 2 / ((D+1.3)^2 + μ_w)                         # rank-one learning rate
    c_μ = min(1 - c_1, 2*(μ_w - 2 + 1/μ_w)/((D+2)^2 + μ_w)) # rank-μ learning rate
    D, μ_w, c_σ, d_σ, c_c, c_1, c_μ
end
end
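##
# Example usage: a minimal sketch, not part of the original gist. It assumes that
# extending Optim's internal zeroth-order API (as done above) is enough for
# Optim.optimize to drive the method; the Rosenbrock objective and the parameter
# values (λ = 12, σ = 0.5) are illustrative choices.
using Optim

rosenbrock(x) = (1.0 - x[1])^2 + 100.0*(x[2] - x[1]^2)^2

res = Optim.optimize(rosenbrock, [0.0, 0.0], CMAES.CMA(12, 0.5),
                     Optim.Options(iterations = 1000))
println(Optim.minimizer(res), " ", Optim.minimum(res))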