@ceshine
Created December 6, 2017 01:51
Alternative SGD implementation
# Reference: http://pytorch.org/docs/master/_modules/torch/optim/sgd.html#SGD
# Variant of torch.optim.SGD: the learning rate is applied to the gradient
# *before* the momentum update, so the momentum buffer accumulates
# lr-scaled steps instead of raw gradients.
import torch
from torch.optim.optimizer import Optimizer, required


class SGD(Optimizer):
    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        ...  # identical to the reference implementation

    def __setstate__(self, state):
        ...  # identical to the reference implementation

    def step(self, closure=None):
        # Standard closure handling, as in the reference implementation
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    # L2 penalty: d_p += weight_decay * p
                    d_p.add_(p.data, alpha=weight_decay)
                # Apply the learning rate here, before the momentum update,
                # rather than at the final parameter update as in the reference
                d_p.mul_(group['lr'])
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                # The learning rate is already folded into d_p, so step by -d_p
                p.data.add_(d_p, alpha=-1)

        return loss
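
A minimal usage sketch of the class above, assuming the elided __init__ and __setstate__ are copied verbatim from the linked reference implementation; the model, data, and hyperparameters below are arbitrary placeholders:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()  # zero_grad comes from the Optimizer base class
loss.backward()
optimizer.step()       # applies the lr-scaled momentum update shown above

One practical consequence of this variant: changing the learning rate mid-training (e.g. via a scheduler) only affects new gradient contributions, because the momentum buffer already carries steps scaled by the earlier learning rate.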