--[[ An experimental quasi-Newton optimizer.
Incorporates Hessian damping, momentum, and per-feature learning rate scaling.
Also implements optional polynomial-decay averaging (similar to ASGD).
ARGS:
- 'opfunc' : a function that takes a single input (X), the point of
evaluation, and returns f(X) and df/dX
- 'x' : the initial point
- 'config' : a table with configuration parameters for the optimizer
- 'config.alwaysDivG2' : if true (the default), always precondition the step
with the Adagrad scaling matrix; if false, use the standard L-BFGS initial
Hessian scaling once curvature pairs are available
- 'config.averagingDecay' : if >= 0, the averaging decay exponent; if < 0,
averaging is disabled
- 'config.epsilon' : a small constant for numerical stability
- 'config.learningRate' : the learning rate
- 'config.momentum' : the momentum coefficient
- 'config.nCorrection' : the maximum number of stored L-BFGS corrections
- 'config.phi' : the Hessian damping coefficient
RETURN:
- 'x' : the new x vector
- 'f(x)' : the function, evaluated after the update
- 'average' : the averaged parameter vector (x itself if averaging is
disabled)
(Katherine Crowson, 2016)
]]
function dmsqn(opfunc, x, config, state)
   -- Configuration
   local config = config or {}
   local state = state or config
   -- Default alwaysDivG2 to true, checking against nil so that an explicit
   -- false is honored
   local always_div_g2 = config.alwaysDivG2
   if always_div_g2 == nil then always_div_g2 = true end
   local avg_decay = config.averagingDecay or -1
   local eps = config.epsilon or 1e-8
   local lr = config.learningRate or 1e-4
   local momentum = config.momentum or 0.9
   local nCorrection = config.nCorrection or 10
   local phi = config.phi or 0.2
   -- Initialization
   state.t = state.t or 0
   -- L-BFGS memory
   state.sk = state.sk or {}
   state.yk = state.yk or {}
   -- Gradient first moment accumulator
   state.g1 = state.g1 or x.new(x:size()):zero()
   -- Gradient second moment accumulator
   state.g2 = state.g2 or x.new(x:size()):fill(eps)
   -- Parameter vector first moment accumulator
   state.p1 = state.p1 or x.new(x:size()):zero()
   -- Reusable buffers for s and y
   state.s = state.s or x.new(x:size())
   state.y = state.y or x.new(x:size())
   local s, y = state.s, state.y
   -- Reusable temporary buffer
   state.tmp = state.tmp or x.new(x:size())
   local tmp = state.tmp
   -- Reusable buffers for s's and y's: allocate the whole history as one
   -- nCorrection x nElement tensor, then split it into per-pair row views
   if not state.s_bufs then
      state.s_bufs = x.new(nCorrection, x:nElement()):split(1)
      state.y_bufs = x.new(nCorrection, x:nElement()):split(1)
      for i = 1, #state.s_bufs do
         state.s_bufs[i] = state.s_bufs[i]:squeeze(1)
         state.y_bufs[i] = state.y_bufs[i]:squeeze(1)
      end
   end
   -- First step: set initial state
   if not state.g then
      local _, g0 = opfunc(x)
      -- Clone: opfunc may return a reference to a reused gradient buffer
      state.g = g0:clone()
      state.g1:add(state.g)
      state.g2:addcmul(state.g, state.g)
   end
   -- Decay first moment of gradient
   state.g1:mul(momentum)
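   --[[ The two-loop recursion below approximates an inverse-Hessian-vector
   product from the stored curvature pairs (s_i, y_i): the backward loop
   subtracts alpha_i * y_i from the (momentum-adjusted) gradient, an initial
   inverse Hessian approximation is applied in the middle, and the forward
   loop adds (alpha_i - beta_i) * s_i back, all without ever forming the
   Hessian explicitly. ]]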
   -- Compute step with L-BFGS two-loop recursion
   s:add(state.g1, state.g) -- Nesterov momentum
   local k = #state.sk
   local rho = torch.zeros(nCorrection)
   for i = 1, k do
      rho[i] = 1 / state.sk[i]:dot(state.yk[i])
   end
   local alpha = torch.zeros(nCorrection)
   for i = k, 1, -1 do
      alpha[i] = state.sk[i]:dot(s) * rho[i]
      s:add(-alpha[i], state.yk[i])
   end
   if not always_div_g2 and k > 0 then
      -- Standard L-BFGS initial Hessian scaling: gamma = s'y / y'y
      local sy = state.sk[k]:dot(state.yk[k])
      local yy = state.yk[k]:dot(state.yk[k])
      s:mul(sy / yy)
   else
      -- Precondition with the Adagrad scaling matrix diag(sqrt(g2))^-1
      s:cdiv(tmp:sqrt(state.g2))
   end
   for i = 1, k do
      local beta = state.yk[i]:dot(s) * rho[i]
      s:add(alpha[i] - beta, state.sk[i])
   end
   -- Two-loop recursion done: take step and update moments
   s:mul(-lr)
   x:add(s)
   local fx, g = opfunc(x)
   state.g1:add(g)
   state.g2:addcmul(g, g)
   -- Compute y with Hessian damping: blending a multiple of s into the
   -- gradient difference keeps the curvature pair better conditioned
   y:add(g, -1, state.g) -- y = new gradient - old gradient
   y:mul(1 - phi):add(phi, s) -- Hessian damping
   y:cmul(tmp:sqrt(state.g2)) -- Scale by the Adagrad scaling matrix
   -- Store gradient
   state.g:copy(g)
   -- Store curvature pair
   if #state.sk == nCorrection then
      -- Shift history by one, recycling the oldest buffers
      local removed_s = table.remove(state.sk, 1)
      local removed_y = table.remove(state.yk, 1)
      table.insert(state.s_bufs, removed_s)
      table.insert(state.y_bufs, removed_y)
   end
   table.insert(state.sk, table.remove(state.s_bufs):copy(s))
   table.insert(state.yk, table.remove(state.y_bufs):copy(y))
   -- Return x*, f(x) after step
   state.t = state.t + 1
   if avg_decay < 0 then
      return x, {fx}, x
   end
   -- Polynomial-decay averaging: average_t = (1 - w_t) * average_(t-1)
   -- + w_t * x_t, with w_t = (1 + averagingDecay) / (t + averagingDecay)
   local weight = (1 + avg_decay) / (state.t + avg_decay)
   state.p1:mul(1 - weight):add(weight, x)
   return x, {fx}, state.p1
end
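
--[[ Example usage (an illustrative sketch, not part of the optimizer itself):
minimizes the quadratic f(x) = 0.5 * ||x - target||^2, whose gradient is
x - target. The 'target' tensor and the config values are arbitrary choices
for demonstration. ]]
require 'torch'
local target = torch.randn(10)
local function opfunc(x)
   local diff = x - target -- allocates a fresh gradient tensor each call
   return 0.5 * diff:dot(diff), diff
end
local x = torch.zeros(10)
local config = {learningRate = 1e-2, averagingDecay = 8}
local state = {}
for i = 1, 100 do
   local _, fs, avg = dmsqn(opfunc, x, config, state)
   if i % 20 == 0 then
      print(string.format('step %d: f(x) = %g, |avg - target| = %g',
                          i, fs[1], (avg - target):norm()))
   end
end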