tmabraham/SAM.py
Created March 11, 2021 02:26
import torch
from collections import defaultdict
from fastai.vision.all import *  # provides Callback, store_attr

class SAM(Callback):
    "Sharpness-Aware Minimization"
    def __init__(self, zero_grad=True, rho=0.05, eps=1e-12, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        self.state = defaultdict(dict)
        store_attr()

    def params(self): return self.learn.opt.all_params(with_grad=True)

    def _grad_norm(self):
        return torch.norm(torch.stack([p.grad.norm(p=2) for p,*_ in self.params()]), p=2)

    @torch.no_grad()
    def first_step(self):
        # Scale the gradient so the perturbation has norm rho, then ascend.
        scale = self.rho / (self._grad_norm() + self.eps)
        for p,*_ in self.params():
            self.state[p]["e_w"] = e_w = p.grad * scale
            p.add_(e_w)  # climb to the local maximum "w + e(w)"
        if self.zero_grad: self.learn.opt.zero_grad()  # clear gradients before the second backward pass

    @torch.no_grad()
    def second_step(self):
        # Undo the perturbation, returning to the original weights "w".
        for p,*_ in self.params(): p.sub_(self.state[p]["e_w"])

    def before_step(self, **kwargs):
        self.first_step()                                     # perturb to w + e(w)
        self.learn.pred = self.model(*self.xb); self.learn('after_pred')
        self.loss_func(self.learn.pred, *self.yb).backward()  # gradient at the perturbed point
        self.second_step()                                     # restore w; the following optimizer step uses this gradient
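
A minimal usage sketch (not part of the original gist; the dataset and architecture below are illustrative, and any fastai Learner should work the same way): attach the callback through cbs, so that before_step runs between the backward pass and the optimizer step.

from fastai.vision.all import *

# Illustrative data/model choices, not from the gist.
path = untar_data(URLs.IMAGENETTE_160)
dls = ImageDataLoaders.from_folder(path, valid='val', item_tfms=Resize(160))
learn = cnn_learner(dls, resnet18, metrics=accuracy)

# Passing SAM via cbs hooks before_step between loss.backward() and opt.step(),
# so the optimizer updates with the gradient taken at the perturbed weights.
learn.fit_one_cycle(3, cbs=SAM(rho=0.05))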
@rsomani95

Hey @tmabraham, thanks for sharing this. I've been able to use this callback to much success, especially with larger models (I've been using efficientnetv2_rw_m from timm).

I'm curious what you think is the best way to use gradient accumulation with SAM. My presumption is that it should work out of the box, as long as SAM is called after GradientAccumulation.

If that is indeed the case, perhaps the safest way to use the SAM callback would be something like this:

from fastai.vision.all import GradientAccumulation

class SAM(Callback):
    order = GradientAccumulation.order + 1
    ...

Curious to hear your thoughts.
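
For concreteness, a sketch of how that ordering could be wired up without editing the original class (the subclass name, dataset, and n_acc value are illustrative, and the combination is untested):

from fastai.vision.all import *

# Hypothetical subclass: behaviour identical to the SAM callback above, but
# explicitly ordered to run after GradientAccumulation has finished accumulating.
class SAMAfterAccum(SAM):
    order = GradientAccumulation.order + 1

path = untar_data(URLs.IMAGENETTE_160)
dls = ImageDataLoaders.from_folder(path, valid='val', item_tfms=Resize(160))
learn = cnn_learner(dls, resnet18, metrics=accuracy)
# n_acc is the number of samples to accumulate before each optimizer step.
learn.fit_one_cycle(3, cbs=[GradientAccumulation(n_acc=64), SAMAfterAccum()])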

@tmabraham (Author)

@rsomani95 I messaged you on Discord but didn't get a response.

I would love to hear more about your success with SAM. Does it allow better performance in fewer epochs?

SAM (not just in fastai, but in PyTorch generally) tends not to play nicely with additional changes to the training loop, such as mixed precision or gradient accumulation. I think your solution should work in this case, but try it out and let me know if there are any issues.
