LARS_SGD_optimizer_Pytorch
from torch.optim.optimizer import Optimizer, required
import torch


# almost a copy-paste from https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py
class LARS(Optimizer):
    r"""Implements LARS (Layer-wise Adaptive Rate Scaling).

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        eta (float, optional): LARS coefficient as used in the paper (default: 1e-3)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        epsilon (float, optional): epsilon to prevent division by zero (default: 0)

    Example:
        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """
    def __init__(self, params, lr=required, momentum=0, eta=1e-3, dampening=0,
                 weight_decay=0, nesterov=False, epsilon=0):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, momentum=momentum, eta=eta, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov, epsilon=epsilon)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(LARS, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(LARS, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
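    # Update rule implemented in step() below, following the LARS paper
    # (You et al., 2017, "Large Batch Training of Convolutional Networks"):
    # for each parameter tensor w with gradient g, a layer-wise trust ratio
    #     local_lr = eta * ||w|| / (||g|| + weight_decay * ||w|| + epsilon)
    # rescales the global lr, and the applied step is
    #     w <- w - lr * local_lr * (g + weight_decay * w)
    # optionally passed through the usual SGD momentum / Nesterov buffer.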
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            eta = group['eta']
            dampening = group['dampening']
            nesterov = group['nesterov']
            epsilon = group['epsilon']

            for p in group['params']:
                if p.grad is None:
                    continue

                # Layer-wise trust ratio: eta * ||w|| / (||g|| + wd * ||w|| + eps),
                # falling back to 1.0 when either norm is zero.
                w_norm = torch.norm(p.data)
                g_norm = torch.norm(p.grad.data)
                if w_norm * g_norm > 0:
                    local_lr = eta * w_norm / (g_norm +
                                               weight_decay * w_norm + epsilon)
                else:
                    local_lr = 1.0

                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # Scale the step by both the global lr and the layer-wise lr.
                p.data.add_(d_p, alpha=-local_lr * group['lr'])

        return loss
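For completeness, here is a minimal usage sketch against the class above. The tiny model, dummy batch, and hyperparameter values are illustrative placeholders rather than recommendations from the gist, and the import assumes the file is saved as `lars.py`.

```python
# Minimal usage sketch; model, data and hyperparameters are placeholders.
import torch
import torch.nn as nn

from lars import LARS  # assumes the class above lives in lars.py

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()

# eta is the LARS trust coefficient; 1e-3 matches the default above.
optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9,
                 weight_decay=1e-4, eta=1e-3, epsilon=1e-8)

# One training step on a dummy batch.
images = torch.randn(32, 1, 28, 28)
labels = torch.randint(0, 10, (32,))

optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()
```

Since `LARS` subclasses `torch.optim.Optimizer`, a standard `torch.optim.lr_scheduler` (typically warmup followed by decay in large-batch training) can be attached to it like to any other optimizer.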