

Minimal custom first-order optimizer for Optim.jl
@jonathanBieler · created January 12, 2018 09:20

module A
using Optim, LinearAlgebra

import Optim: FirstOrderOptimizer, initial_state, update_state!, trace!,
              assess_convergence, AbstractOptimizerState, update!, value, gradient
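
# To plug a new first-order method into Optim's generic optimize loop, the
# pieces below implement its optimizer interface: a method type, a state
# type, initial_state, update_state!, trace!, and assess_convergence.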
# The optimizer itself: plain gradient descent with a fixed step size η.
struct MinimalGradientDescent <: FirstOrderOptimizer
    η::Float64
end
MinimalGradientDescent(; η = 1e-1) = MinimalGradientDescent(η)
# Per-iteration state; mutable because update_state! rewrites its fields.
mutable struct MinimalGradientDescentState{T,N} <: AbstractOptimizerState
    x::Array{T,N}
    x_previous::Array{T,N}
    f_x_previous::T
end
function initial_state(method::MinimalGradientDescent, options, d, initial_x)
    # evaluate f and ∇f at the starting point so gradient(d) is valid on the
    # first call to update_state!; copy initial_x so the state does not
    # alias the caller's array
    Optim.value_gradient!(d, initial_x)
    MinimalGradientDescentState(copy(initial_x), copy(initial_x), Inf)
end
function update_state!(d, state::MinimalGradientDescentState{T}, method::MinimalGradientDescent) where T
    # record where we were, so assess_convergence can measure progress
    copyto!(state.x_previous, state.x)
    state.f_x_previous = value(d)
    # fixed-step gradient step; Optim's main loop re-evaluates f and ∇f
    # at the new x before the next iteration
    state.x .-= method.η .* gradient(d)
    false # should the procedure force quit?
end
function trace!(tr, d, state, iteration, method::MinimalGradientDescent, options)
    dt = Dict()
    if options.extended_trace
        dt["x"] = copy(state.x)
        dt["g(x)"] = copy(gradient(d))
    end
    g_norm = norm(gradient(d), Inf)
    update!(tr,
            iteration,
            value(d),
            g_norm,
            dt,
            options.store_trace,
            options.show_trace,
            options.show_every,
            options.callback)
end
function assess_convergence(state::MinimalGradientDescentState, d, options)
    Optim.default_convergence_assessment(state, d, options)
end
end # module A
using Optim # bring optimize and Optim.Options into scope outside module A

f = x -> sum(x.^2) + π
mfit = optimize(f, rand(2), A.MinimalGradientDescent(),
                Optim.Options(iterations = 500, store_trace = true, extended_trace = true))
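
Once the run finishes, the standard Optim accessors can be used to inspect the result. A quick sketch (the exact numbers depend on the random start):

Optim.minimizer(mfit)  # ≈ [0.0, 0.0]
Optim.minimum(mfit)    # ≈ π
Optim.converged(mfit)  # true once the g_tol test in default_convergence_assessment passes
Optim.x_trace(mfit)    # per-iteration iterates, recorded because extended_trace = true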
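
Optim also accepts an analytic gradient via the three-argument optimize(f, g!, x0, ...) form, which avoids the finite differencing used above. A sketch for this objective (η = 0.2 is an arbitrary illustrative step size):

g!(G, x) = (G .= 2 .* x)  # in-place gradient of sum(x.^2) + π, i.e. ∇f(x) = 2x
mfit2 = optimize(f, g!, rand(2), A.MinimalGradientDescent(η = 0.2),
                 Optim.Options(iterations = 500))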