tmabraham/SAM.py
Created March 11, 2021
from collections import defaultdict
import torch
from fastai.callback.core import Callback
from fastcore.basics import store_attr

class SAM(Callback):
    "Sharpness-Aware Minimization"
    def __init__(self, zero_grad=True, rho=0.05, eps=1e-12, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        self.state = defaultdict(dict)
        store_attr()

    def params(self): return self.learn.opt.all_params(with_grad=True)

    def _grad_norm(self):
        # L2 norm of the full gradient, computed from the per-parameter gradient norms
        return torch.norm(torch.stack([p.grad.norm(p=2) for p,*_ in self.params()]), p=2)

    @torch.no_grad()
    def first_step(self):
        # Perturb each parameter along its gradient, scaled so the step has radius rho
        scale = self.rho / (self._grad_norm() + self.eps)
        for p,*_ in self.params():
            self.state[p]["e_w"] = e_w = p.grad * scale
            p.add_(e_w)  # climb to the local maximum "w + e(w)"
        if self.zero_grad: self.learn.opt.zero_grad()

    @torch.no_grad()
    def second_step(self):
        # Undo the perturbation so the optimizer updates the original weights "w"
        for p,*_ in self.params(): p.sub_(self.state[p]["e_w"])

    def before_step(self, **kwargs):
        # Recompute the loss and gradients at the perturbed weights "w + e(w)"
        self.first_step()
        self.learn.pred = self.model(*self.xb); self.learn('after_pred')
        self.loss_func(self.learn.pred, *self.yb).backward()
        self.second_step()
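
A minimal usage sketch, assuming a standard fastai Learner; the MNIST_SAMPLE dataset, resnet18 architecture, and one-epoch schedule below are placeholders for illustration, not part of the gist.

from fastai.vision.all import *

# Hypothetical example pipeline; any fastai Learner can be used the same way
path = untar_data(URLs.MNIST_SAMPLE)
dls = ImageDataLoaders.from_folder(path)
learn = cnn_learner(dls, resnet18, metrics=accuracy)

# Pass the callback via cbs; before_step re-runs the forward/backward pass
# at the perturbed weights before each optimizer step
learn.fit_one_cycle(1, cbs=SAM(rho=0.05))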
@tmabraham (Author):
@rsomani95 I messaged you in Discord but didn't get a response.

I would love to hear more about your success with SAM. Does it allow better performance in fewer epochs?

SAM (not just in fastai, but in PyTorch more generally) tends not to play nicely with other modifications to the training loop, such as mixed precision or gradient accumulation, because it needs a second forward/backward pass at the perturbed weights before the optimizer step. I think your solution should work in this case, but try it out and let me know if you run into any issues.
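
For reference, a minimal sketch of the same two-step update written against a plain PyTorch optimizer; the names model, loss_fn, loader, and base_opt are hypothetical. It makes the second forward/backward pass at the perturbed weights explicit, which is the step that interleaves awkwardly with gradient scaling or accumulation.

import torch

rho, eps = 0.05, 1e-12
for xb, yb in loader:
    # First pass: gradients at the current weights w
    loss_fn(model(xb), yb).backward()
    with torch.no_grad():
        grads = [p for p in model.parameters() if p.grad is not None]
        grad_norm = torch.norm(torch.stack([p.grad.norm(p=2) for p in grads]), p=2)
        scale = rho / (grad_norm + eps)
        e_ws = []
        for p in grads:
            e_w = p.grad * scale
            p.add_(e_w)               # climb to "w + e(w)"
            e_ws.append((p, e_w))
    base_opt.zero_grad()
    # Second pass: gradients at the perturbed weights
    loss_fn(model(xb), yb).backward()
    with torch.no_grad():
        for p, e_w in e_ws:
            p.sub_(e_w)               # return to the original weights w
    base_opt.step()                   # sharpness-aware update of w
    base_opt.zero_grad()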
