Skip to content

Instantly share code, notes, and snippets.

@snowyday
Last active January 16, 2018 04:26
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save snowyday/19b959b268d3af7785b2dd0e2f37f6bb to your computer and use it in GitHub Desktop.
Save snowyday/19b959b268d3af7785b2dd0e2f37f6bb to your computer and use it in GitHub Desktop.
Eve: Improving Stochastic Gradient Descent with Feedback
import math
from torch.optim import Optimizer
class Eve(Optimizer):
    """Implements the Eve algorithm (Adam with feedback from the objective).

    It has been proposed in `Improving Stochastic Gradient Descent with Feedback`_.
    Eve scales Adam's denominator by a smoothed relative change of the loss
    between steps, so :meth:`step` must be given a closure returning the loss.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        betas (Tuple[float, float, float], optional): coefficients used for
            computing running averages of the gradient, of its square, and of
            the relative change of the loss (default: (0.9, 0.999, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        thr (Tuple[float, float], optional): lower and upper thresholds used to
            clip the relative change of the loss (default: (0.1, 10))
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    .. _Improving Stochastic Gradient Descent with Feedback:
        https://arxiv.org/abs/1611.01505
    """

    def __init__(self, params, lr=1e-2, betas=(0.9, 0.999, 0.999), eps=1e-8,
                 thr=(0.1, 10), weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, thr=thr,
                        weight_decay=weight_decay)
        super(Eve, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Unlike most optimizers it is mandatory here,
                because the loss value drives the feedback term. The loss is
                expected to stay strictly positive: past loss values are used
                as divisors below.

        Returns:
            Whatever ``closure`` returned.

        Raises:
            ValueError: if ``closure`` is None.
        """
        if closure is None:
            raise ValueError("Eve requires a value of the loss function.")
        loss = closure()
        # float() accepts 0-dim / one-element tensors as well as plain numbers;
        # indexing with ``loss.data[0]`` fails on 0-dim tensors.
        loss_val = float(loss)

        for group in self.param_groups:
            beta1, beta2, beta3 = group['betas']
            thl, thu = group['thr']
            for p in group['params']:
                # Parameters that took no part in this backward pass have no
                # gradient; leave them (and their state) untouched.
                if p.grad is None:
                    continue
                grad = p.grad.data
                # Key the state by the parameter itself (not id(p)) so that
                # state_dict()/load_state_dict() can map it back to params.
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = grad.new_zeros(grad.size())
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = grad.new_zeros(grad.size())
                    # Smoothed (clipped) loss from the previous step
                    state['loss_hat_prev'] = loss_val
                    # Feedback coefficient d_t; starts at 1 (plain Adam)
                    state['decay_rate'] = 1

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                loss_hat_prev = state['loss_hat_prev']
                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                if state['step'] > 1:
                    # Clip the loss ratio to [thl + 1, thu + 1] when the loss
                    # did not decrease, or to its reciprocal range otherwise,
                    # so a single noisy loss cannot swing d_t arbitrarily.
                    if loss_val >= loss_hat_prev:
                        lower_bound = thl + 1
                        upper_bound = thu + 1
                    else:
                        lower_bound = 1 / (thu + 1)
                        upper_bound = 1 / (thl + 1)
                    clip = min(max(lower_bound, loss_val / loss_hat_prev), upper_bound)
                    loss_hat = clip * loss_hat_prev
                    relative_change = abs(loss_hat - loss_hat_prev) / min(loss_hat, loss_hat_prev)
                    state['decay_rate'] = beta3 * state['decay_rate'] + (1 - beta3) * relative_change
                    state['loss_hat_prev'] = loss_hat

                # Adam update with the denominator scaled by the feedback d_t.
                denom = exp_avg_sq.sqrt().mul_(state['decay_rate']).add_(group['eps'])
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment