Skip to content

Instantly share code, notes, and snippets.

@kashif
Created March 16, 2018 11:32
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save kashif/ecbe62c34a026b7d10f3312d0300a29d to your computer and use it in GitHub Desktop.
AccSGD optimizer for Keras
class AccSGD(Optimizer):
    """AccSGD optimizer.

    Every parameter `p` is paired with an auxiliary buffer `m` that starts
    as a copy of `p`. With gradient `g`, one step computes:

        beta = smallConst * smallConst * xi / kappa
        zeta = smallConst / (smallConst + beta)
        m_t  = (1 - beta) * m + beta * (p - (lr * kappa / smallConst) * g)
        p_t  = zeta * (p - lr * g) + (beta / (smallConst + beta)) * m_t

    Arguments:
        lr (float): learning rate (default: 0.1)
        kappa (float, optional): ratio of long to short step (default: 1000)
        xi (float, optional): statistical advantage parameter (default: 10)
        smallConst (float, optional): any value <=1 (default: 0.7)

    # References
        - [Accelerating Stochastic Gradient Descent](https://arxiv.org/abs/1704.08227)
    """

    def __init__(self, lr=0.1, kappa=1000.0, xi=10.0, smallConst=0.7, **kwargs):
        """Store the hyperparameters as backend variables.

        Extra keyword arguments are forwarded unchanged to the base
        `Optimizer` constructor.
        """
        super(AccSGD, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.lr = K.variable(lr, name='lr')
            self.kappa = K.variable(kappa, name='kappa')
            self.xi = K.variable(xi, name='xi')
            self.smallConst = K.variable(smallConst, name='smallConst')

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        """Build the list of update ops for one AccSGD step.

        Arguments:
            loss: loss tensor to differentiate.
            params: list of variables to optimize.

        Returns:
            The list of update ops (also stored on `self.updates`): for each
            parameter, one op for its auxiliary buffer followed by one op
            for the parameter itself.
        """
        grads = self.get_gradients(loss, params)

        # Start from a fresh list so that calling this method more than once
        # does not pile up duplicate update ops from earlier calls.
        self.updates = []

        # Step-size and mixing coefficients shared by all parameters.
        large_lr = (self.lr * self.kappa) / self.smallConst
        beta = (self.smallConst * self.smallConst * self.xi) / self.kappa
        alpha = 1.0 - beta
        zeta = self.smallConst / (self.smallConst + beta)
        # Weight of the auxiliary buffer in the parameter update (== 1 - zeta);
        # loop-invariant, so it is computed once here.
        buffer_weight = beta / (self.smallConst + beta)

        # One auxiliary buffer per parameter, initialised to a copy of it.
        # NOTE(review): this relies on the backend accepting a tensor as the
        # initial value of `K.variable` -- confirm for non-TensorFlow backends.
        ms = [K.variable(K.identity(p), dtype=K.dtype(p)) for p in params]

        # Register the buffers as optimizer weights, as the built-in Keras
        # optimizers do with their slot variables, so that the optimizer
        # state is exposed through the base class's weight accessors.
        self.weights = ms

        for p, g, m in zip(params, grads, ms):
            m_t = alpha * m + beta * (p - large_lr * g)
            p_t = zeta * (p - self.lr * g) + buffer_weight * m_t

            self.updates.append(K.update(m, m_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        """Return the optimizer configuration as a plain dict.

        The current values of `lr`, `kappa`, `smallConst` and `xi` are read
        back from their backend variables and merged over the base class
        configuration.
        """
        config = {'lr': float(K.get_value(self.lr)),
                  'kappa': float(K.get_value(self.kappa)),
                  'smallConst': float(K.get_value(self.smallConst)),
                  'xi': float(K.get_value(self.xi))}
        base_config = super(AccSGD, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment