@redknightlois
Last active August 9, 2023
Ralamb optimizer (RAdam + LARS trick)
import math

import torch
from torch.optim.optimizer import Optimizer


class Ralamb(Optimizer):
    """RAdam (rectified Adam) combined with the layer-wise trust ratio of LARS."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Cache of (step, N_sma, step_size) keyed by step % 10, so the
        # rectification term is computed once per step, not once per parameter.
        self.buffer = [[None, None, None] for _ in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Decay the first and second moment running average coefficients
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, radam_step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    # Length of the approximated SMA, as in the RAdam paper.
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        # Rectified step size: variance rectification term times
                        # the first-moment bias correction.
                        radam_step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        # Variance is not yet tractable: fall back to an unadapted step.
                        radam_step_size = group['lr'] / (1 - beta1 ** state['step'])
                    buffered[2] = radam_step_size

                if group['weight_decay'] != 0:
                    # Decoupled weight decay: p <- p * (1 - lr * wd)
                    p_data_fp32.mul_(1 - group['weight_decay'] * group['lr'])

                # Apply the would-be RAdam step to a copy of the weights; the
                # LARS trust ratio below uses the norm of this result (note:
                # the norm of the updated weights, not of the update itself).
                radam_step = p_data_fp32.clone()
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(exp_avg, denom, value=-radam_step_size)
                else:
                    radam_step.add_(exp_avg, alpha=-radam_step_size)

                radam_norm = radam_step.pow(2).sum().sqrt()
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1.0
                else:
                    trust_ratio = (weight_norm / radam_norm).item()

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                # Take the actual step, rescaled by the trust ratio.
                if N_sma >= 5:
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-radam_step_size * trust_ratio)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-radam_step_size * trust_ratio)

                p.data.copy_(p_data_fp32)

        return loss
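A minimal usage sketch (editor's addition; the tiny model and random data below are placeholders, since Ralamb drops in wherever a stock torch.optim optimizer would):

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(10, 2)
    optimizer = Ralamb(model.parameters(), lr=1e-3, weight_decay=1e-2)

    inputs = torch.randn(8, 10)           # placeholder batch
    targets = torch.randint(0, 2, (8,))
    for _ in range(5):
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()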
@redknightlois (Author)

@mgrankin Have you seen such huge back-and-forth swings in the loss with the 80-epoch annealing schedule?

[image: training loss curve showing the swings]

@Tony-Y commented Aug 28, 2019

https://github.com/ymcui/LAMB_Optimizer_TF/blob/a804c2f2995cda9a4f6b804ab445e19fc4a1036f/optimization.py#L259-L265

      # Note: Here are two choices for scaling function \phi(z)
      # minmax:   \phi(z) = min(max(z, \gamma_l), \gamma_u)
      # identity: \phi(z) = z
      # The authors does not mention what is \gamma_l and \gamma_u
      # UPDATE: after asking authors, they provide me the code below.
      # ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(
      #      math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)

The authors might not use that clipping. The code they provided is equivalent to the following code:

                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / radam_norm

Update:
https://github.com/tensorflow/tensorflow/blob/66bb198acad21260038805e02960b791cb467177/tensorflow/contrib/opt/python/training/lars_optimizer.py#L108-L114
TensorFlow's LARSOptimizer has similar code:

trust_ratio = array_ops.where(
          math_ops.greater(w_norm, 0),
          array_ops.where(
              math_ops.greater(g_norm, 0),
              (self._eeta * w_norm /
               (g_norm + self._weight_decay * w_norm + self._epsilon)), 1.0),
          1.0)

Update2:
https://github.com/borisgin/nvcaffe/blob/8896d3303cfdfb575923d6bb108a99ebda728855/src/caffe/solvers/sgd_solver.cpp#L321-L326
Caffe's LARC uses the same method:

      float rate = 1.F;
      if (w_norm > 0.F && wgrad_norm > 0.F) {
        //float weight_decay = this->param_.weight_decay();
        //rate = gw_ratio * w_norm / (wgrad_norm + weight_decay * w_norm);
        rate = gw_ratio * w_norm / wgrad_norm ;
      }

I suspect the authors didn't provide code for that clipping.

@oguiza commented Aug 28, 2019

This is just anecdotal, but I've used the updated version with my own dataset, and no clipping is better than clipping to 10.
Thanks for releasing this code, Federico, and to all contributors for making it better!

@mgrankin

@mgrankin Have you seen such huge back-and-forth swings in the loss with the 80-epoch annealing schedule?

@redknightlois Haven't run the 80-epoch schedule so far, but those swings are worth digging into.

@redknightlois (Author)

We have definitely done something wrong here... Version 3 is broken; I am not entirely sure it converges at all. Back to the drawing board. @r1ckya Any idea what might be wrong?

@frgfm commented Aug 28, 2019

@redknightlois Yes, same here.
I went back and reimplemented both papers (with LARS optional, and the above-mentioned clipping optional as well) and it seems to be working for now: https://github.com/frgfm/Holocron/blob/master/holocron/optim/radam.py

@r1ckya commented Aug 28, 2019

@frgfm You do not use bias correction on the step size. I thought about removing that too, but didn't have time to test it yet (they ditched it in the paper as well, at least last time I checked https://arxiv.org/abs/1904.00962v3); maybe that affects something in a weird way.
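For concreteness, a hedged sketch of the two choices (editor's illustration; the helper name is hypothetical, and rect stands in for the RAdam rectification term):

    def step_size(lr, rect, beta1, step, bias_correct=True):
        # bias_correct=True matches the gist above; False matches dropping the
        # first-moment bias correction, as in v3 of the LAMB paper.
        return lr * rect / (1 - beta1 ** step) if bias_correct else lr * rect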

@redknightlois (Author) commented Aug 28, 2019

I am testing it now and got better results, the problem was that we were doing it wrong. Real wrong... will update in a few minutes

python train.py --run 20 --woof 0 --size 128 --bs 64 --mixup 0 --sa 0 --epoch 5 --lr 1e-2 --gpu 0 --opt ralamb                                                                                                               
lr: 0.01; eff_lr: 0.01; size: 128; alpha: 0.99; mom: 0.9; eps: 1e-06
\.fastai\data\imagenette-160
epoch     train_loss  valid_loss  accuracy  top_k_accuracy  time
0         1.796517    1.931205    0.484000  0.890000        02:26
1         1.509954    1.727263    0.536000  0.912000        02:11
2         1.276914    1.195526    0.734000  0.958000        02:11
3         1.128250    1.019652    0.792000  0.978000        02:13
4         1.009607    0.911938    0.834000  0.984000        02:11

@redknightlois (Author)

@frgfm I really like your implementation, far cleaner than mine. I would suggest that @mgrankin use yours instead.

@frgfm commented Aug 28, 2019

@r1ckya If you are referring to the bias correction of the first moment, I noticed I had an unpushed commit!
https://github.com/frgfm/Holocron/blob/master/holocron/optim/radam.py#L90-L91

For the rest, I actually stuck with the idea of the initial LARS paper https://arxiv.org/pdf/1708.03888.pdf. I'll check the difference in performance against the paper you mention.

Thanks @redknightlois!

@redknightlois (Author)

@frgfm I cannot make yours converge on ImageWoof.

@frgfm commented Aug 28, 2019

@redknightlois Sorry about that, the unpushed commit had an issue in it (it only corrected the bias for the first moment, not the second).

I updated it. Even with the above, though I'm not training on ImageWoof, it's performing quite well: no weight_decay in my case, and I had to scale up the learning rate, but it's definitely converging!

@frgfm commented Aug 28, 2019

But your first implementation is still working for me.
It's just that when I apply the changes you both mentioned earlier, I cannot make it work, for unclear reasons.

@redknightlois (Author)

That's because it is broken... I just updated to the proper one. I also tried your newer version, but no luck either. If you are playing with it together with Lookahead, be careful: based on what I have seen, it can do well even if the base optimizer is crap.

@redknightlois (Author) commented Aug 28, 2019

These are a few results from the current version:

lr: 0.001; eff_lr: 0.001; size: 128; alpha: 0.99; mom: 0.9; eps: 1e-06
epoch     train_loss  valid_loss  accuracy  top_k_accuracy  time
0         2.117527    2.215535    0.244000  0.736000        02:48
1         1.949242    2.063208    0.320000  0.830000        02:12
2         1.745239    1.941830    0.356000  0.874000        02:12
3         1.559608    1.542116    0.524000  0.938000        02:08
4         1.452923    1.492080    0.562000  0.940000        02:10
epoch     train_loss  valid_loss  accuracy  top_k_accuracy  time
0         2.112155    2.267176    0.240000  0.724000        02:13
1         1.976614    2.071130    0.300000  0.778000        02:20
2         1.753372    1.787471    0.414000  0.880000        02:14
3         1.570473    1.574903    0.518000  0.930000        02:16
4         1.450896    1.522807    0.524000  0.940000        02:11

@mgrankin

Great, I'll update the notebooks in the repo soon.

@oguiza commented Aug 29, 2019

Thanks a lot, @redknightlois. I've tried v4 of Ralamb and it works much better! I got a new SOTA on a problem I'm working on, and training is much smoother.

@redknightlois (Author)

@oguiza Did you try it with Lookahead or just as it is?

@oguiza commented Aug 29, 2019

I have also tested with Lookahead, but my results are a bit worse. Please bear in mind this is with my own dataset.

@redknightlois (Author)

Same here; it is as if LARS is somehow interacting with Lookahead. I have a few ideas on that front that might be worth exploring.

@frgfm commented Aug 29, 2019

After spending hours locating the issue in my implementation, I found out that I was wrongly accumulating the bias correction of the momentum. I sometimes forget the mutability of some Python objects...
It was also confirmed by @r1ckya.

Anyway, here is the fix. So far, my tests seem to indicate that it's holding its own compared to the first revision from here!

Two differences between @redknightlois's last revision and my implementation:

  1. Minor difference: on lines 72-73, it is subtle, but since you add epsilon to the 2nd-moment term before multiplying by its bias correction (inside radam_step_size), you obtain:
(sqrt(exp_avg_sq) + group['eps']) / sqrt(bias_correction2)

I stuck to the paper and used the following instead:

sqrt(exp_avg_sq / bias_correction2) + group['eps']
  2. Major difference: according to the paper, the denominator of the local_lr should be the norm of:
expected_update = adaptive_momentum + group['weight_decay'] * p.data

where the adaptive momentum equals r_t * exp_avg_hat / (sqrt(exp_avg_sq_hat) + eps) if sma > 4, and equals exp_avg_hat otherwise.
But according to line 70 and the following ones, you actually take the norm of:

p.data - group['lr'] * expected_update

And in your case, when sma > 4, your adaptive momentum is actually r_t * exp_avg_hat / (sqrt(exp_avg_sq_hat) + eps / sqrt(bias_correction2))

Interestingly, your revision is performing quite well, so I guess you somehow made a finding of your own! It rescales the update by the norm of the expected updated params (without LARS):

local_lr = phi(norm(p)) / norm(expected_p)

where expected_p = p - group['lr'] * expected_update instead of rescaling the update by its own norm:

local_lr = phi(norm(p)) / norm(expected_update)
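To make the minor difference concrete, here is an editor's sketch of the two denominators under discussion, reusing the gist's variable names (with bias_correction2 = 1 - beta2 ** step):

    import math

    def denom_gist(exp_avg_sq, eps, bias_correction2):
        # eps is added before the bias correction that is folded into
        # radam_step_size, so eps effectively gets rescaled by it as well.
        return (exp_avg_sq.sqrt() + eps) / math.sqrt(bias_correction2)

    def denom_paper(exp_avg_sq, eps, bias_correction2):
        # The paper corrects the second moment first, then adds eps.
        return (exp_avg_sq / bias_correction2).sqrt() + eps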

@r1ckya commented Aug 30, 2019

I agree with @frgfm on this: v3 was very similar to other LAMB or Ralamb implementations I've seen, and it looked more or less like what the papers describe, but v4 is very different from that (the major difference pointed out by @frgfm).

I am testing it now and got better results, the problem was that we were doing it wrong. Real wrong... will update in a few minutes

I am wondering what was wrong with v3 in your opinion, and how you came up with this different trust_ratio calculation rule.

@redknightlois (Author) commented Aug 30, 2019

I am still waiting for results on the 20-epoch run, but the 5-epoch runs on Ralars, Ralamb (v4), Ralamb (v1), RAdam and Lookahead show the following:

  • Lookahead and RAdam are on top.
  • Ralamb v4 comes next.
  • Ralars and Ralamb v1 are comparable

This is what we know so far:

  • V1 was flawed (yet somehow got good results) and, AFAIK (@mgrankin should comment on that), produced the results on the main page.
  • V2 and V3 are even worse... They do not converge at all when run in isolation (which makes sense, because we are not calculating the proper complete step before taking the actual step forward).
  • In V4 those differences are completely unintended; they are bugs. However, if what @frgfm is saying is right, we should probably dig there, because there may be something we don't yet know about the behavior.
  • I was not able to locally reproduce the results that @mgrankin's repo has in the readme. (Weird, I know.)

I have been working on a variant of Lookahead, because it is as if LARS-style optimizers do not play along with Lookahead (which is why Ranger beat all the others). But I am stuck; I need someone with a stronger math background to figure out the formula.

The idea is to incorporate the basic idea of LARS directly into Lookahead. Put it this way: you probably want to overshoot if the gradient has generally the same direction in all intermediate steps... therefore, if you take steps t0, ..., tk/2, ..., tk, the vectors t0->tk/2 and t0->tk should have a cosine similarity of 1... in that case you are probably looking into a deep-dive kind of landscape... however, if that cosine similarity is going to 0, we are probably in exploration mode, and the buddy should stay a bit away in case the explorer trips into a hole (a local minimum).

If the cosine similarity is 1, there is no harm in overshooting (trust ratio > 1), even though the buddy is going to be updated proportionally to alpha in the fast direction. While Lookahead today updates the slow weights but doesn't change the fast gradient, I am saying that, from the buddy's point of view, we could push the fast gradient into overshooting mode, because it believes it is safe to do so. And if it fails, you just wasted k batches.
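A minimal sketch of that test (editor's addition; w0, w_half and wk are hypothetical flattened weight snapshots taken at t0, tk/2 and tk):

    import torch
    import torch.nn.functional as F

    def overshoot_signal(w0, w_half, wk):
        # Cosine similarity between the half-cycle and full-cycle displacements:
        # near 1 -> consistent direction, overshooting is probably safe;
        # near 0 -> exploration mode, keep the slow ("buddy") weights back.
        return F.cosine_similarity((w_half - w0).flatten(), (wk - w0).flatten(), dim=0)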

What do you think?

@frgfm commented Aug 30, 2019

Interesting thought @redknightlois
If I get your idea correctly, you want to rescale the interpolation/update of slow weights in Lookahead based on a distance metric between the gradients of intermediate fast steps? (and suggesting the distance metric should be cosine similarity)

And you remove LARS from the actual base optimizer?

I can draft a few options on the math of this and implement it, but I'll probably need something to run the training on, as I believe I don't have the hardware to meet the same training conditions as you guys!

@frgfm commented Aug 30, 2019

For the gradient, I can actually picture cases where cosine similarity of the gradients wouldn't help (if that's what you meant). Since it's a wrapper, we don't control the way the base optimizer performs the fast steps. So imagine an optimizer that gives a lot of importance to momentum or some other inertia factor, such that the gradient of the fast weights at each step is very similar (in cosine distance).

The projection of the fast weights onto a 2D scatter plot may then actually be a curve rather than a line (optimizer magic). In this case, we would overshoot past the last fast weights, but in a direction that does not fit the curve. I would have to test the hypothesis (perhaps it doesn't occur that often, but my intuition is that optimizers whose updates are not driven mainly by the gradient might not work well with this approach).

While implementing Lookahead, I actually had a somewhat similar idea, @redknightlois, but more from a forward/backward efficiency perspective.
I didn't consider overshooting, but rather selecting alpha automatically based on the variance of the fast updates:
Let s be the slow weights, and f1, ..., fk be the fast weights.

  • I evaluate the variance of [f2 - f1, ..., fk - f(k-1)] which would characterize the consistency of the fast updates' direction.
  • I would normalize it / squeeze it into [0, 1] and use 1 - squeezed_variance as a synchronization rate (alpha)

The issue I had is that, memory-wise, I would have to store k-1 fast-weight snapshots to perform this. So, hopefully, less computation required, but higher memory usage (growing linearly with the synchronization period k). I'll check if I can avoid that memory overhead and try to implement this. (A rough sketch follows below.)
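A rough sketch of that selection rule (editor's addition; fast_weights is a hypothetical list of the k flattened fast-weight snapshots f1, ..., fk):

    import torch

    def auto_alpha(fast_weights):
        # Differences between consecutive fast weights: f2-f1, ..., fk-f(k-1).
        deltas = torch.stack([b - a for a, b in zip(fast_weights, fast_weights[1:])])
        variance = deltas.var(dim=0).mean()   # scalar consistency measure
        squeezed = variance / (1 + variance)  # squeeze into [0, 1)
        return 1 - squeezed                   # consistent updates -> alpha near 1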

@mgrankin

I re-ran the notebooks with the latest Ralamb and RangerLars and updated the main page with the results.

@redknightlois (Author) commented Aug 30, 2019

And you remove LARS from the actual base optimizer?

Yes, you take out LARS from the base optimizer.

If I get your idea correctly, you want to rescale the interpolation/update of slow weights in Lookahead based on a distance metric between the gradients of intermediate fast steps? (and suggesting the distance metric should be cosine similarity)

Yes, I suggested cosine similarity because you can picture it in your head in two dimensions. As you mentioned, when the 'curve' is an actual curve, the similarity of the vectors is poor, so you have to fall back to the normal schedule...

In 2D this is what I have in mind.
[image: 2D sketch of the overshoot vs. exploration idea]

I didn't consider overshooting, but selecting alpha automatically based on the variance of the fast updates:

I did a LARS-style version of Lookahead which uses the 'trust_ratio' between the norms of the fast and slow weights... and at 5 epochs you don't see a noticeable change... but I haven't had the time to run it further or with an annealing schedule, which looks like it makes a lot of difference based on the results just published by @mgrankin (a rough sketch of the variant follows at the end of this comment). Quick question: do those numbers already include the change to use the Mish activation?

EDIT: A quick caveat... when I say you can overshoot, what I mean is that you update the fast weights into overshooting range, while the slow weights are modified in such a way that they do not move into overshooting range.
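Here is the rough sketch of that LARS-style Lookahead variant (editor's addition; entirely hypothetical, the actual experiment may have used a different formulation):

    import torch

    def sync_slow_weights(slow_params, fast_params, alpha=0.5):
        # Rescale the usual Lookahead interpolation by a trust ratio between
        # the norms of the slow and fast weights.
        with torch.no_grad():
            for s, f in zip(slow_params, fast_params):
                s_norm, f_norm = s.norm(), f.norm()
                trust = (s_norm / f_norm).item() if s_norm > 0 and f_norm > 0 else 1.0
                s.add_(f - s, alpha=alpha * trust)  # slow <- slow + a*trust*(fast - slow)
                f.copy_(s)                          # reset fast weights to the slow ones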

@mgrankin

Quick question: do those numbers already include the change to use the Mish activation?

No, I haven't looked into Mish yet.

@VirajBagal

@redknightlois Hi, thanks for the implementation. I have a question: how do I save the Ralamb optimizer state_dict? There is no function for that, and there is no load_state_dict function either. Thanks.

@redknightlois (Author)

No, this was a prototype that I knocked up in a few hours' time. Feel free to add those and I will update it.
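For what it's worth, since Ralamb subclasses torch.optim.Optimizer, the inherited state_dict() / load_state_dict() should already cover checkpointing (the self.buffer cache is transient and is recomputed from the stored step counts). A minimal sketch, with an illustrative file name:

    torch.save(optimizer.state_dict(), 'ralamb_checkpoint.pt')

    # Later / elsewhere: rebuild the optimizer, then restore its state.
    optimizer = Ralamb(model.parameters(), lr=1e-3)
    optimizer.load_state_dict(torch.load('ralamb_checkpoint.pt'))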
