
@redknightlois
Last active August 9, 2023 20:50
redknightlois/c4023d393eb8f92bb44b2ab582d7ec20
Ralamb optimizer (RAdam + LARS trick)
```python
import math

import torch
from torch.optim import Optimizer


class Ralamb(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Per-step cache of [step, N_sma, lr-free step-size factor], indexed by step % 10
        # and shared by all parameters (assumes every param group uses the same betas).
        self.buffer = [[None, None, None] for ind in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, step_factor = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_factor = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_factor = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_factor

                # The cache stores the lr-free factor so that param groups with
                # different learning rates do not reuse each other's step size.
                radam_step_size = group['lr'] * step_factor

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # radam_step is the candidate updated weight tensor (weights minus the
                # RAdam update), so radam_norm below is the norm of those weights.
                radam_step = p_data_fp32.clone()
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(exp_avg, denom, value=-radam_step_size)
                else:
                    radam_step.add_(exp_avg, alpha=-radam_step_size)

                # LARS trick: scale the step by ||w|| / ||radam_step||, with ||w|| clamped to [0, 10]
                radam_norm = radam_step.pow(2).sum().sqrt()
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1.0
                else:
                    # Python float, so it can be passed as the scalar alpha/value argument
                    trust_ratio = (weight_norm / radam_norm).item()

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                if N_sma >= 5:
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-radam_step_size * trust_ratio)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-radam_step_size * trust_ratio)

                p.data.copy_(p_data_fp32)

        return loss
```
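The rectification check in step() depends on N_sma, the approximate length of the simple moving average (called rho_t in the RAdam paper). Evaluating the same formula on its own shows when Ralamb switches from the momentum-only update to the adaptive one. This is a minimal plain-Python sketch. `n_sma` and `first_rectified_step` are hypothetical helper names, and beta2 = 0.999 is simply the constructor default.

```python
def n_sma(step, beta2=0.999):
    """N_sma at a given step, using the same formula as Ralamb.step()."""
    beta2_t = beta2 ** step
    n_sma_max = 2 / (1 - beta2) - 1
    return n_sma_max - 2 * step * beta2_t / (1 - beta2_t)


def first_rectified_step(beta2=0.999, threshold=5):
    """First step at which the N_sma >= threshold branch (adaptive update) is taken."""
    step = 1
    while n_sma(step, beta2) < threshold:
        step += 1
    return step


for t in (1, 2, 5, 6, 100):
    print(t, round(n_sma(t), 3))
print("first rectified step:", first_rectified_step())
```

With the default beta2, N_sma grows roughly like the step count early on. The first few steps therefore take the un-normalized `else` branch, and the adaptive update switches on only after that.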
@frgfm

frgfm commented Aug 30, 2019

Interesting thought @redknightlois
If I get your idea correctly, you want to rescale the interpolation/update of slow weights in Lookahead based on a distance metric between the gradients of intermediate fast steps? (and suggesting the distance metric should be cosine similarity)

And you remove LARS from the actual base optimizer?

I can draft a few options for the math and implement them, but I'll probably need someone else to run the training, as I don't think I have the hardware to reproduce your training conditions!

@frgfm

frgfm commented Aug 30, 2019

For the gradient, I can actually picture cases where cosine similarity of gradients wouldn't help (if that's what you meant). Since Lookahead is a wrapper, we don't control how the base optimizer performs the fast steps. Imagine an optimizer that gives a lot of weight to momentum or some other inertia term, so that the gradients of the fast weights at each step are very similar (in cosine distance).

If you project the fast weights into a 2D scatter plot, the trajectory may actually be a curve rather than a line (optimizer magic). In that case, we would overshoot past the last fast weights, but in a direction that does not follow the curve. I would have to test the hypothesis: perhaps it doesn't occur that often, but my intuition is that optimizers whose updates are not driven mainly by the gradient might not work well with this approach.
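The curved-trajectory concern can be checked numerically. This minimal plain-Python sketch places the fast weights at 2D points on a circular arc, a hypothetical stand-in for a momentum-driven trajectory rather than data from any run. Consecutive fast updates are then nearly identical by cosine similarity, yet the first and last updates point in quite different directions. `cosine` and `arc_fast_weights` are illustrative names.

```python
import math


def cosine(u, v):
    """Cosine similarity between two 2D vectors."""
    dot = u[0] * v[0] + u[1] * v[1]
    return dot / (math.hypot(*u) * math.hypot(*v))


def arc_fast_weights(k, step_deg):
    """Hypothetical fast weights f_0..f_k on the unit circle, turning step_deg per step."""
    theta = math.radians(step_deg)
    return [(math.cos(i * theta), math.sin(i * theta)) for i in range(k + 1)]


fast = arc_fast_weights(k=6, step_deg=15)
# Fast updates f_(i+1) - f_i along the arc.
deltas = [(b[0] - a[0], b[1] - a[1]) for a, b in zip(fast, fast[1:])]

consecutive = [cosine(d1, d2) for d1, d2 in zip(deltas, deltas[1:])]
end_to_end = cosine(deltas[0], deltas[-1])

print("consecutive:", [round(c, 3) for c in consecutive])
print("first vs last:", round(end_to_end, 3))
```

Every consecutive pair scores cos 15° ≈ 0.966, while the first and last updates are 75° apart (cos ≈ 0.259). A check on consecutive steps alone would call this trajectory consistent, so which pairs of updates are compared matters as much as the metric itself.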

While implementing Lookahead, I actually had a somewhat similar idea @redknightlois, but more from a model forward/backward efficiency perspective.
I didn't consider overshooting, but selecting alpha automatically based on the variance of the fast updates:
Let s be the slow weights and f1, ..., fk the fast weights:

  • I evaluate the variance of [f2 - f1, ..., fk - f(k-1)] which would characterize the consistency of the fast updates' direction.
  • I would normalize it / squeeze it into [0, 1] and use 1 - squeezed_variance as a synchronization rate (alpha)

The issue I had is that, memory-wise, I would have to keep k-1 fast weights in memory to do this. So, hopefully, less computation would be required, but memory usage would be higher (growing linearly with the synchronization period k). I'll check if I can avoid that memory overhead and try to implement this.
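The variance-based alpha described above can be computed in one pass with constant extra memory. This sketch makes one assumption that is not in the comment: the "squeezed" variance is the total variance of the updates d_i = f_(i+1) - f_i divided by their mean squared norm, which always lies in [0, 1]. Under that choice, 1 - squeezed_variance = ||mean(d)||² / mean(||d||²). Because the updates telescope, mean(d) = (f_k - f_1) / (k - 1), so only f_1, the previous fast weights and one running scalar are needed, whatever k is. `StreamingAlpha` and `alpha_for` are hypothetical names, and weights are flattened to lists of floats for illustration.

```python
class StreamingAlpha:
    """Hypothetical sync rate from the fast updates f2-f1, ..., fk-f(k-1),
    computed without storing all fast weights."""

    def __init__(self, first_fast):
        self.first = list(first_fast)  # f_1
        self.prev = list(first_fast)   # most recent fast weights
        self.sum_sq_norm = 0.0         # running sum of ||f_i - f_(i-1)||^2
        self.n = 0                     # number of updates seen

    def update(self, fast):
        self.sum_sq_norm += sum((b - a) ** 2 for a, b in zip(self.prev, fast))
        self.prev = list(fast)
        self.n += 1

    def alpha(self):
        if self.n == 0 or self.sum_sq_norm == 0:
            return 1.0  # no movement: full sync is an arbitrary choice here
        # The updates telescope, so their mean is (f_k - f_1) / n.
        mean_sq = sum(((b - a) / self.n) ** 2 for a, b in zip(self.first, self.prev))
        # ||mean update||^2 / mean ||update||^2 == 1 - variance / mean ||update||^2
        return mean_sq / (self.sum_sq_norm / self.n)


def alpha_for(fast_weights):
    tracker = StreamingAlpha(fast_weights[0])
    for f in fast_weights[1:]:
        tracker.update(f)
    return tracker.alpha()


straight = [[0.0, 0.0], [1.0, 0.5], [2.0, 1.0], [3.0, 1.5]]  # identical updates
zigzag = [[0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0]]  # updates cancel
turn = [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]  # one 90-degree turn
print(alpha_for(straight), alpha_for(zigzag), alpha_for(turn))
```

A straight run of identical updates gives alpha = 1, updates that cancel out give alpha = 0, and a single 90° turn gives 0.5.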

@mgrankin

I re-ran the notebooks with the latest Ralamb and RangerLars and updated the main page with the results.

@redknightlois
Author

redknightlois commented Aug 30, 2019

And you remove LARS from the actual base optimizer?

Yes, you take out LARS from the base optimizer.

If I get your idea correctly, you want to rescale the interpolation/update of slow weights in Lookahead based on a distance metric between the gradients of intermediate fast steps? (and suggesting the distance metric should be cosine similarity)

Yes, I suggested cosine similarity because it is easy to picture in 2 dimensions. As you mentioned, when the 'curve' is an actual curve, the similarity of the vectors is poor, so you have to fall back to the normal schedule...

In 2D this is what I have in mind.
[Image: 2D illustration of the idea, not preserved]

I didn't consider overshooting, but selecting alpha automatically based on the variance of the fast updates:

I did a LARS-style version of Lookahead that uses the 'trust_ratio' between the norms of the fast and slow weights... and at 5 epochs you don't see a noticeable change... but I haven't had the time to run it longer OR with an annealing schedule, which seems to make a lot of difference based on the results just published by @mgrankin. Quick question, those numbers already include the change to use Mish activation?

EDIT: Quick caveat... when I say you can overshoot, what I mean is that the fast weights are updated into the overshooting range, while the slow weights are modified in such a way that they do not move into it.
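A LARS-style Lookahead sync step could look like the following. This is a hypothetical reading of the comment, not the author's code. It assumes trust_ratio = ||slow|| / ||fast||, falls back to 1 when either norm is zero (as Ralamb does), and applies the ratio to the usual Lookahead interpolation s ← s + alpha * (f - s). The function name and the exact placement of the ratio are assumptions.

```python
import math


def lars_lookahead_sync(slow, fast, alpha=0.5):
    """Hypothetical Lookahead sync with a LARS-style trust ratio.

    slow, fast: flattened weights as lists of floats.
    Returns the new slow weights s + alpha * trust_ratio * (f - s).
    """
    slow_norm = math.sqrt(sum(x * x for x in slow))
    fast_norm = math.sqrt(sum(x * x for x in fast))
    if slow_norm == 0 or fast_norm == 0:
        trust_ratio = 1.0
    else:
        trust_ratio = slow_norm / fast_norm
    return [s + alpha * trust_ratio * (f - s) for s, f in zip(slow, fast)]


# Equal norms: reduces to plain Lookahead (trust_ratio = 1).
print(lars_lookahead_sync([1.0, 0.0], [0.0, 1.0], alpha=0.5))
# Fast weights doubled in norm: the sync step is damped by ||s|| / ||f|| = 0.5.
print(lars_lookahead_sync([1.0, 0.0], [2.0, 0.0], alpha=0.5))
```

Because the ratio shrinks the interpolation whenever the fast weights have grown in norm, the slow weights move less than plain Lookahead would in that case.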

@mgrankin

Quick question, those numbers already include the change to use Mish activation?

No, I haven't looked into Mish yet.

@VirajBagal

@redknightlois Hi, thanks for the implementation. I have a question: how do I save the Ralamb optimizer's state_dict? There is no function for that, and no load_state_dict function either. Thanks

@redknightlois
Author

No, this was a prototype that I knocked up in a few hours' time. Feel free to add those and I will update it.
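For reference, Ralamb subclasses torch.optim.Optimizer, which already provides state_dict() and load_state_dict(). The inherited methods should therefore cover saving and restoring, as long as everything kept in self.state is serializable (Ralamb stores only tensors and numbers there). The self.buffer cache is not saved, but it only memoizes per-step values and is recomputed when an entry does not match the current step. To keep the block self-contained, it uses a hypothetical minimal subclass, `TinyOptimizer`, in place of Ralamb; the save/load calls would be the same.

```python
import io

import torch
from torch.optim import Optimizer


class TinyOptimizer(Optimizer):
    """Hypothetical stand-in for Ralamb: plain SGD that also counts steps in self.state."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                state['step'] = state.get('step', 0) + 1
                p.add_(p.grad, alpha=-group['lr'])


model = torch.nn.Linear(3, 1)
opt = TinyOptimizer(model.parameters(), lr=0.1)
model(torch.randn(4, 3)).sum().backward()
opt.step()

# Save: state_dict() is inherited from torch.optim.Optimizer.
buf = io.BytesIO()
torch.save(opt.state_dict(), buf)

# Restore into a fresh optimizer built over the same parameters.
buf.seek(0)
restored = TinyOptimizer(model.parameters(), lr=0.1)
restored.load_state_dict(torch.load(buf))
print(restored.state_dict()['state'])
```

In practice the BytesIO buffer would be a file path passed to torch.save and torch.load, usually saved alongside the model's own state_dict().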
