Zygote simple local min problem
using Zygote
using Flux.Optimise: update!, Descent, ADAM, Momentum

# Fixed 2×1×1 Float32 input/target pair.
X, Y = (Float32[-0.31125240132442067; 0.8163067649323273;;;],
        Float32[5.7064323; 2.599511;;;])

# Fixed initial parameters (random alternatives kept commented out).
# w = randn(1,1,1) .* ones(2,1,1)
w = Float32[0.15980364]
# b = randn(1,1,1) .* ones(2,1,1)
b = Float32[25.510088]

# Simple model and squared-error loss.
modl(X, w, b) = (X .+ b) .* w
loss(Y, y) = sum((y .- Y) .^ 2)

opt = ADAM(0.01, (0.8, 0.99))  # ADAM with learning rate 0.01
opt = Momentum(0.0001, 0.8)    # Momentum with learning rate 0.0001 (overrides the ADAM above)
# opt = Descent(0.0001)        # plain gradient descent with learning rate 0.0001

println("Start")
@time for i in 1:40000
    grad = gradient((w, b) -> loss(modl(X, w, b), Y), w, b)
    @show loss(modl(X, w, b), Y)  # prints the loss every iteration
    update!(opt, w, grad[1])
    update!(opt, b, grad[2])
end
@show loss(modl(X, w, b), Y)
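
Two quick checks (a sketch added here, reusing the variables above; not part of the original gist): the gradient norm at the final parameters tells us whether training stalled at a stationary point, and because there are two data points and two scalar parameters, the global optimum can be computed in closed form for comparison.

# Gradient at the trained parameters: ≈ 0 at any stationary point,
# including a suboptimal one, which is what the gist title refers to.
grad_final = gradient((w, b) -> loss(modl(X, w, b), Y), w, b)
@show sum(abs2, grad_final[1]) + sum(abs2, grad_final[2])

# Closed-form global optimum: solve (x1 + b)*w = y1 and (x2 + b)*w = y2 exactly.
x1, x2 = X[1], X[2]
y1, y2 = Y[1], Y[2]
w_star = (y1 - y2) / (x1 - x2)
b_star = y1 / w_star - x1
@show loss(modl(X, [w_star], [b_star]), Y)  # ≈ 0, the true minimum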
Sixzero commented Jun 25, 2022

ADAM optimizer copied for modification.

using Flux.Optimise: AbstractOptimiser
import Flux.Optimise: apply!  # extend apply! so Flux's update! dispatches to MyOpt

const EPS = 1e-8  # Flux's internal EPS constant is not exported; redefined here

# Copy of Flux's ADAM optimiser, kept as a separate type so it can be modified.
mutable struct MyOpt <: AbstractOptimiser
    eta::Float64
    beta::Tuple{Float64,Float64}
    epsilon::Float64
    state::IdDict{Any, Any}
end
MyOpt(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = MyOpt(η, β, ϵ, IdDict())
MyOpt(η::Real, β::Tuple, state::IdDict) = MyOpt(η, β, EPS, state)

function apply!(o::MyOpt, x, Δ)
    η, β = o.eta, o.beta

    # Per-parameter state: first/second moment estimates and running β powers.
    mt, vt, βp = get!(o.state, x) do
        (zero(x), zero(x), Float64[β[1], β[2]])
    end :: Tuple{typeof(x),typeof(x),Vector{Float64}}

    @. mt = β[1] * mt + (1 - β[1]) * Δ            # biased first-moment update
    @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)  # biased second-moment update
    @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + o.epsilon) * η  # bias-corrected step
    βp .= βp .* β                                 # advance β powers for the next call

    return Δ
end
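
A usage sketch (assuming the data, model, and loss from the gist above are in scope): since update! dispatches on apply! for any AbstractOptimiser, MyOpt drops straight into the same training loop as ADAM.

opt = MyOpt(0.001, (0.9, 0.999))
for i in 1:1000
    grad = gradient((w, b) -> loss(modl(X, w, b), Y), w, b)
    update!(opt, w, grad[1])  # update! calls the apply! defined above
    update!(opt, b, grad[2])
end
@show loss(modl(X, w, b), Y)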
