Last active
February 5, 2018 10:30
-
-
Save jonathanBieler/47d9ae7e95e7ca0f7352de8f84827ae3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module A

using Optim
import Optim: ZerothOrderOptimizer, ZerothOrderState, initial_state, update_state!, trace!, assess_convergence, AbstractOptimizerState, update!, value

"""
    RandomSampler(σ)
    RandomSampler(; σ=1e-1)

Zeroth-order method that perturbs the current point with Gaussian noise and keeps the
perturbed point only when it yields a strictly smaller objective value.

`σ` is the factor applied to the standard-normal noise added to every coordinate.
"""
struct RandomSampler <: ZerothOrderOptimizer
    σ::Float64
end
RandomSampler(; σ=1e-1) = RandomSampler(σ)

"""
    RandomSamplerState{T,N}

Mutable working state of a [`RandomSampler`](@ref) run.

# Fields
- `x`: the point currently kept by the sampler (replaced whenever a proposal is accepted).
- `x_previous`: set at construction; nothing in this module updates it afterwards.
- `f_x_previous`: the objective value of the kept point (`Inf` until the first acceptance).
"""
mutable struct RandomSamplerState{T,N} <: ZerothOrderState
    x::Array{T,N}
    x_previous::Array{T,N}
    f_x_previous::T
end

"""
    initial_state(method::RandomSampler, options, d, initial_x)

Build the starting [`RandomSamplerState`](@ref) from `initial_x`.

`options` and `d` are accepted for interface compatibility and are not used here.
"""
function initial_state(method::RandomSampler, options, d, initial_x)
    T = eltype(initial_x)
    # Independent copies: the two fields must not alias each other, nor the caller's array.
    # `T(Inf)` keeps the sentinel in the element type of `initial_x`, so non-Float64 start
    # points still match the `RandomSamplerState{T,N}` constructor.
    return RandomSamplerState(copy(initial_x), copy(initial_x), T(Inf))
end

"""
    update_state!(d, state::RandomSamplerState, method::RandomSampler)

Perform one sampling step: draw a candidate around `state.x`, evaluate it with
`value(d, candidate)`, and store it in `state` if it beats `state.f_x_previous`.

Always returns `false`, i.e. this method never asks the caller to force-quit.
"""
function update_state!(d, state::RandomSamplerState, method::RandomSampler)
    # Draw noise with the same shape as `state.x` so matrix / N-d start points work too.
    candidate = state.x .+ method.σ .* randn(size(state.x))
    f_candidate = value(d, candidate)
    if f_candidate < state.f_x_previous
        state.f_x_previous = f_candidate
        state.x = candidate
    end
    return false # should the procedure force quit?
end

"""
    trace!(tr, d, state, iteration, method::RandomSampler, options)

Forward the current iteration to `update!`, reporting `value(d)` as the objective value
and a gradient norm of `0.0`. When `options.extended_trace` is set, a copy of `state.x`
is attached under the key `"x"`.
"""
function trace!(tr, d, state, iteration, method::RandomSampler, options)
    dt = Dict()
    if options.extended_trace
        dt["x"] = copy(state.x)
    end
    # No gradient is computed by this method, so a constant zero is reported.
    g_norm = 0.0
    return update!(tr,
                   iteration,
                   value(d),
                   g_norm,
                   dt,
                   options.store_trace,
                   options.show_trace,
                   options.show_every,
                   options.callback)
end

"""
    assess_convergence(state::RandomSamplerState, d, options)

Delegate the convergence decision to `Optim.default_convergence_assessment`.
"""
function assess_convergence(state::RandomSamplerState, d, options)
    return Optim.default_convergence_assessment(state, d, options)
end

end
# `optimize` and `Optim.Options` are used here at top level, outside module `A`, so Optim
# must be loaded into this scope as well.
using Optim

# Example objective: sum of squares shifted by π (smallest value π, attained at the origin).
f(x) = sum(abs2, x) + π

# Run the custom sampler from a random 2-element start point for at most 500 iterations.
mfit = optimize(f, rand(2), A.RandomSampler(),
                Optim.Options(iterations=500, store_trace=false, extended_trace=false))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment