@vankesteren
Last active February 25, 2023 05:35
Julia implementation of Adam optimizer
module Adamopt
# This is a module implementing vanilla Adam (https://arxiv.org/abs/1412.6980).

export Adam, step!

# Struct containing all necessary info
mutable struct Adam
  theta::AbstractArray{Float64} # Parameter array
  loss::Function                # Loss function
  grad::Function                # Gradient function
  m::AbstractArray{Float64}     # First moment
  v::AbstractArray{Float64}     # Second moment
  b1::Float64                   # Exponential decay rate for the first moment
  b2::Float64                   # Exponential decay rate for the second moment
  a::Float64                    # Step size (learning rate)
  eps::Float64                  # Epsilon for numerical stability
  t::Int                        # Time step (iteration)
end

# Outer constructor with the default hyperparameters from the paper
function Adam(theta::AbstractArray{Float64}, loss::Function, grad::Function)
  m   = zeros(size(theta))
  v   = zeros(size(theta))
  b1  = 0.9
  b2  = 0.999
  a   = 0.001
  eps = 1e-8
  t   = 0
  Adam(theta, loss, grad, m, v, b1, b2, a, eps, t)
end

# Step function with optional keyword arguments for the data passed to grad()
function step!(opt::Adam; data...)
  opt.t += 1
  gt    = opt.grad(opt.theta; data...)
  opt.m = opt.b1 .* opt.m + (1 - opt.b1) .* gt
  opt.v = opt.b2 .* opt.v + (1 - opt.b2) .* gt .^ 2
  mhat  = opt.m ./ (1 - opt.b1^opt.t) # Bias-corrected first moment
  vhat  = opt.v ./ (1 - opt.b2^opt.t) # Bias-corrected second moment
  opt.theta -= opt.a .* (mhat ./ (sqrt.(vhat) .+ opt.eps))
end

end # module
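
As a quick sanity check of the module (not part of the original gist), here is a minimal sketch that minimizes a simple quadratic; it assumes the Adamopt module has already been loaded, for example as shown in the comments below:

using .Adamopt

# Toy problem: f(theta) = sum(theta .^ 2), gradient 2 .* theta, minimum at the origin.
f(theta; kwargs...) = sum(theta .^ 2)
g(theta; kwargs...) = 2 .* theta

opt   = Adam([1.0, -2.0, 3.0], f, g)
opt.a = 0.01                # larger step size than the default for this toy problem
for i in 1:2000
  step!(opt)
end
opt.theta                   # should end up close to [0.0, 0.0, 0.0]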

vankesteren commented Nov 21, 2019

To use the Adamopt module directly from this gist, you can fetch and include it like this:

using HTTP: request

"https://gist.githubusercontent.com/vankesteren/96207abcd16ecd01a2491bcbec12c73f/raw/1b59af6962a1107db5873eba59054acc3f9a8aac/Adamopt.jl" |>
  url -> request("GET", url) |> 
  res -> String(res.body) |> 
  str -> include_string(Main, str)

using .Adamopt
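
Alternatively (not in the original gist), on Julia 1.6 or later the Downloads standard library can be used instead of HTTP.jl; a minimal sketch, where the local filename Adamopt.jl is an arbitrary choice:

using Downloads

url = "https://gist.githubusercontent.com/vankesteren/96207abcd16ecd01a2491bcbec12c73f/raw/1b59af6962a1107db5873eba59054acc3f9a8aac/Adamopt.jl"
Downloads.download(url, "Adamopt.jl") # save the module next to your script

include("Adamopt.jl")                 # defines the Adamopt module
using .Adamopt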


vankesteren commented Nov 21, 2019

#### EXAMPLE ####
# let's use this Adam implementation for linear regression
using .Adamopt
using Random
N = 100
x = randn(N, 2)
y = x * [9, -3] + randn(N)

# loss function with data kwargs
function mse(b; x = x, y = y)
  res = y - x*b
  res'res
end

# gradient function with data kwargs
function grad(b; x = x, y = y)
  -2 .* x' * (y - x * b)
end
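# Note: this gradient follows from expanding the squared-error loss:
#   d/db (y - x*b)'(y - x*b) = -2 * x' * (y - x*b)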

# deterministic adam
dopt   = Adam([0.0, 0.0], mse, grad)
dopt.a = 0.01
for i = 1:5000
  step!(dopt)
  println("Step: ", dopt.t, " | Loss: ", dopt.loss(dopt.theta))
end

# stochastic adam
sopt   = Adam([0.0, 0.0], mse, grad)
sopt.a = 0.01
batch  = 12
epochs = 350
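# Each epoch shuffles the row indices and consumes them in minibatches of at most
# `batch` rows, so every observation is used exactly once per epoch.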
for e = 1:epochs
  pidx = Random.randperm(N)
  while length(pidx) > 0
    # take up to `batch` indices; the last minibatch of an epoch may be smaller
    idx = [pop!(pidx) for i in 1:batch if length(pidx) > 0]
    step!(sopt; x = x[idx, :], y = y[idx]) # y is a vector, so index it as y[idx]
  end
  println("Step: ", sopt.t, " | Loss: ", sopt.loss(sopt.theta))
end
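
Since this is ordinary linear regression, the Adam estimates can be checked against the closed-form least-squares solution; a short sketch added here for illustration, using base Julia's `\` solve:

# Compare against the closed-form OLS estimate; all three should be near [9, -3].
b_ols = x \ y
println("OLS:           ", b_ols)
println("Deterministic: ", dopt.theta)
println("Stochastic:    ", sopt.theta)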
