-
-
Save tmabraham/62cc1839e1dbb280cb80a79df856ec81 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SAM(Callback):
    """Sharpness-Aware Minimization.

    Before the optimizer step, the weights are moved along the current
    gradient by `rho * grad / (||grad|| + eps)`, a second forward/backward
    pass is run at that perturbed point, and the weights are then moved back,
    so the optimizer steps from the original weights using the gradients of
    the perturbed point.

    Parameters:
        zero_grad: if true, `self.learn.opt.zero_grad()` is called after the
            perturbation so that the second backward pass starts from cleared
            gradients; otherwise the second pass accumulates onto the first.
        rho: size of the perturbation; must be non-negative.
        eps: added to the gradient norm to avoid dividing by zero.
    """
    def __init__(self, zero_grad=True, rho=0.05, eps=1e-12, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        # Per-parameter scratch space; "e_w" holds the perturbation that is
        # currently applied to that parameter (absent when none is pending).
        self.state = defaultdict(dict)
        store_attr()

    def params(self):
        "Entries of the learner's optimizer, requested with `with_grad=True`; the first item of each entry is the parameter."
        return self.learn.opt.all_params(with_grad=True)

    def _grad_norm(self):
        "Global L2 norm over the gradients of all `params()`."
        norms = [p.grad.norm(p=2) for p,*_ in self.params()]
        # Gather every per-parameter norm on one device so that stacking also
        # works when the parameters live on different devices. With an empty
        # list `norms[0]` is never evaluated and `torch.stack` raises exactly
        # as it did before.
        return torch.norm(torch.stack([n.to(norms[0].device) for n in norms]), p=2)

    @torch.no_grad()
    def first_step(self):
        "Perturb each parameter by `e_w = grad * rho / (||grad|| + eps)` and remember `e_w`."
        scale = self.rho / (self._grad_norm() + self.eps)
        for p,*_ in self.params():
            # `scale` lives on the device chosen in `_grad_norm`; move it next
            # to this gradient before multiplying.
            self.state[p]["e_w"] = e_w = p.grad * scale.to(p.grad.device)
            p.add_(e_w)  # climb to the local maximum "w + e(w)"
        if self.zero_grad: self.learn.opt.zero_grad()

    @torch.no_grad()
    def second_step(self):
        "Undo every perturbation applied by `first_step`."
        # Walk the recorded perturbations rather than the parameters that
        # currently have a gradient: the two sets can differ after the second
        # backward pass (e.g. when gradients were reset to None and a
        # parameter was unused), and every perturbed weight must be restored
        # exactly once. Popping also releases the tensor until the next step.
        for p, st in self.state.items():
            e_w = st.pop("e_w", None)
            if e_w is not None: p.sub_(e_w)

    def before_step(self, **kwargs):
        "Perturb the weights, recompute predictions and gradients there, then restore the weights."
        # NOTE(review): relies on being invoked right before the optimizer
        # step, after the regular backward pass has populated the gradients.
        self.first_step()
        self.learn.pred = self.model(*self.xb); self.learn('after_pred')
        self.loss_func(self.learn.pred, *self.yb).backward()
        self.second_step()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@rsomani95 I messaged you on Discord but didn't get a response.
I would love to hear more about your success with SAM. Does it allow better performance in fewer epochs?
SAM (in PyTorch generally, not just in fastai) tends not to play nicely with other changes to the training loop, such as mixed precision or gradient accumulation. I think in this case your solution should work, but try it out and let me know if there are any issues.