Skip to content

Instantly share code, notes, and snippets.

Last active January 1, 2020 15:26
Show Gist options
  • Save colllin/0b146b154c4351f9a40f741a28bff1e3 to your computer and use it in GitHub Desktop.
Save colllin/0b146b154c4351f9a40f741a28bff1e3 to your computer and use it in GitHub Desktop.
PyTorch AdamW optimizer
# Based on pytorch/pytorch#3740 (original AdamW implementation)
import torch
import math
class AdamW(torch.optim.Optimizer):
    """Implements the AdamW algorithm (Adam with decoupled weight decay).

    It has been proposed in `Fixing Weight Decay Regularization in Adam`_.
    Unlike Adam with an L2 penalty, the weight decay is *not* added to the
    gradient. Instead, each step first shrinks the weights directly
    (``w = w - weight_decay * lr * w``) and then applies the usual
    bias-corrected Adam update (``w = w - step_size * exp_avg / denom``).

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay coefficient;
            it is multiplied by the learning rate each step (default: 0)

    Raises:
        ValueError: if ``lr``, ``eps``, ``weight_decay`` or either beta is
            out of range.

    .. _Fixing Weight Decay Regularization in Adam:
        https://arxiv.org/abs/1711.05101
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super(AdamW, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # The parameter updates below run under no_grad, but the closure
            # must be able to recompute gradients.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization (first time this parameter is stepped)
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                # Decay the first and second moment running average coefficient.
                # Note: no L2 term is folded into `grad` here -- decay is decoupled.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(eps)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = lr * math.sqrt(bias_correction2) / bias_correction1

                # w = w - wd * lr * w   (decoupled weight decay)
                if weight_decay != 0:
                    p.add_(p, alpha=-weight_decay * lr)

                # w = w - step_size * exp_avg / denom   (Adam step)
                p.addcdiv_(exp_avg, denom, value=-step_size)

                # Combined: w = w - wd * lr * w - lr * adam_direction

        return loss
Copy link

colllin commented Jul 31, 2018

Original implementation from pytorch/pytorch#3740
Fixed per the AdamW (decoupled weight decay) description:

  • Compute weight decay before applying gradient step.
  • Multiply the weight decay by the learning rate.

Copy link

This is great, thank you! I'm using this until it is merged into master. Nice catch in the Lua implementation with the copy 👍

Copy link

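Hi, the original paper (the green region in Algorithm 2) gives the update as $\theta_t \leftarrow \theta_{t-1} - \eta_t\left(\alpha\,\hat m_t/(\sqrt{\hat v_t}+\epsilon) + \lambda\,\theta_{t-1}\right)$. Here $\eta_t$ is the schedule multiplier and $\alpha$ is the learning rate. So the decay term should be scaled as

`group['weight_decay'] * ScheduleMultiplier` --- (1)

rather than `group['weight_decay'] * group['lr']` --- (2)

Do you know why the weight decay in AdamW is not multiplied by the learning rate?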

Copy link

Thanks for this explanation.
Could you please check my C++ implementation?
// AdamW update for three weight tensors that share one optimizer state.
//
// dW1..dW3 : gradients; W1..W3 : weights, updated in place. Each W must be at
//            least as long as its matching dW.
// LR       : learning rate.
// lambda   : decoupled weight-decay coefficient, applied as W -= LR * lambda * W
//            (on the pre-update weight) before the Adam step.
//
// The moment buffers m_t / v_t and the step counter t are function-local
// statics, so optimizer state persists across calls (one state per template
// instantiation). They are re-initialised whenever the total parameter count
// changes, which avoids indexing past the end of the buffers.
// NOTE(review): the template header was reconstructed (the original's <T>
// arguments were lost); drop it if this is a member of a class template.
template <typename T>
void Adam(std::vector<T> &dW1, std::vector<T> &dW2, std::vector<T> &dW3,
          std::vector<T> &W1, std::vector<T> &W2, std::vector<T> &W3,
          T LR, T lambda) {
    const T beta_1 = 0.9;
    const T beta_2 = 0.999;
    const T epsilon = 1e-8;

    const std::size_t wsize = dW1.size() + dW2.size() + dW3.size();
    static std::vector<T> m_t;
    static std::vector<T> v_t;
    static std::size_t t = 1;
    if (m_t.size() != wsize) {
        m_t.assign(wsize, T(0));
        v_t.assign(wsize, T(0));
        t = 1;
    }

    // Bias corrections depend only on t, so compute them once per call.
    const T m_cap = 1 - std::pow(beta_1, t);
    const T v_cap = 1 - std::pow(beta_2, t);
    const T step_size = LR * (std::sqrt(v_cap) / m_cap);

    // n indexes the shared moment buffers and keeps running across all three
    // tensors; k indexes within the current tensor.
    std::size_t n = 0;
    auto update = [&](const std::vector<T> &dW, std::vector<T> &W) {
        for (std::size_t k = 0; k < dW.size(); ++k, ++n) {
            const T g = dW[k];
            m_t[n] = beta_1 * m_t[n] + (1 - beta_1) * g;
            v_t[n] = beta_2 * v_t[n] + (1 - beta_2) * (g * g);
            // w = w - lambda * LR * w - step_size * m / (sqrt(v) + eps)
            W[k] = (W[k] - (LR * W[k] * lambda))
                   - (step_size * (m_t[n] / (std::sqrt(v_t[n]) + epsilon)));
        }
    };
    update(dW1, W1);
    update(dW2, W2);
    update(dW3, W3);

    t += 1;
}


I also update my learning rate and weight decay after each epoch, as follows:
lr *= (1. / (1. + (0.001 * iter_num))); lambda = weight_decay*(sqrt(double(BATCH_SIZE)/double((10000*iter_num)))); iter_num ++;

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment