```python
from collections import defaultdict

import torch
from fastai.callback.core import Callback
from fastcore.basics import store_attr

class SAM(Callback):
    "Sharpness-Aware Minimization"
    def __init__(self, zero_grad=True, rho=0.05, eps=1e-12, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        self.state = defaultdict(dict)
        store_attr()

    def params(self): return self.learn.opt.all_params(with_grad=True)

    def _grad_norm(self):
        # Global L2 norm over the gradients of all trainable parameters
        return torch.norm(torch.stack([p.grad.norm(p=2) for p,*_ in self.params()]), p=2)

    @torch.no_grad()
    def first_step(self):
        scale = self.rho / (self._grad_norm() + self.eps)
        for p,*_ in self.params():
            self.state[p]["e_w"] = e_w = p.grad * scale
            p.add_(e_w)  # climb to the local maximum "w + e(w)"
        if self.zero_grad: self.learn.opt.zero_grad()

    @torch.no_grad()
    def second_step(self):
        # Restore the original weights "w" before the optimizer step
        for p,*_ in self.params(): p.sub_(self.state[p]["e_w"])

    def before_step(self, **kwargs):
        self.first_step()
        # Second forward/backward pass at the perturbed weights "w + e(w)"
        self.learn.pred = self.model(*self.xb); self.learn('after_pred')
        self.loss_func(self.learn.pred, *self.yb).backward()
        self.second_step()
```
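For readers without a fastai setup, the callback's two-step logic can be sketched framework-free. The quadratic loss and `loss_grad` helper below are illustrative stand-ins, not part of the callback:

```python
import math

def sam_step(w, loss_grad, lr=0.1, rho=0.05, eps=1e-12):
    """One SAM update on a list of plain-float parameters `w`.

    `loss_grad(w)` returns the gradient of the loss at `w`.
    Mirrors the callback: perturb to w + e(w), re-evaluate the
    gradient there, restore w, then take the base optimizer step.
    """
    g = loss_grad(w)
    norm = math.sqrt(sum(gi * gi for gi in g))     # _grad_norm
    scale = rho / (norm + eps)
    e_w = [gi * scale for gi in g]                 # first_step: e(w)
    w_pert = [wi + ei for wi, ei in zip(w, e_w)]   # climb to w + e(w)
    g_sharp = loss_grad(w_pert)                    # backward at perturbed point
    # second_step restores w; plain SGD then uses the sharpness-aware gradient
    return [wi - lr * gi for wi, gi in zip(w, g_sharp)]

# Toy quadratic loss L(w) = sum(w_i^2), with gradient 2*w
grad = lambda w: [2.0 * wi for wi in w]
updated = sam_step([1.0, -1.0], grad)
```

The key point the sketch makes is that the optimizer step uses the gradient evaluated at the perturbed weights, while the weights themselves are restored before the update.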
@rsomani95 I messaged you in Discord but didn't get a response.
I would love to hear more about your success with SAM. Does it allow better performance in fewer epochs?
SAM (not just in fastai, but in PyTorch generally) tends not to play nicely with other modifications to the training loop, such as mixed precision or gradient accumulation. I think in this case your solution should work, but try it out and let me know if there are any issues.
Hey @tmabraham, thanks for sharing this. I've been able to use this callback to much success, especially with larger models (I've been using `efficientnetv2_rw_m` from `timm`).

I'm curious what you think the best way to use gradient accumulation with SAM is. My presumption is that it should work out of the box, as long as `SAM` is called after `GradientAccumulation`. If that is indeed the case, perhaps the safest way to use the `SAM` callback would be something like this:

Curious to hear your thoughts.