@seanbenhur
Last active December 27, 2023 06:12
import torch


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2).to(shared_device)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]),
            p=2,
        )
        return norm
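
SAM needs two forward/backward passes per batch: the first gradient is used by first_step to perturb the weights to "w + e(w)", and the second gradient, computed at the perturbed point, is what base_optimizer actually applies after second_step restores "w". Below is a minimal, self-contained usage sketch (not part of the gist above): the toy nn.Linear model, random tensors, and the SGD hyperparameters are illustrative placeholders only.

import torch
import torch.nn as nn

# toy data and model just to make the sketch runnable
model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss()
inputs = torch.randn(32, 10)
targets = torch.randint(0, 2, (32,))

base_optimizer = torch.optim.SGD  # pass the optimizer class, not an instance
optimizer = SAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)

for _ in range(10):
    # first forward/backward pass: gradient at the current weights w
    loss_fn(model(inputs), targets).backward()
    optimizer.first_step(zero_grad=True)   # perturb weights to w + e(w)

    # second forward/backward pass: gradient at the perturbed weights
    loss_fn(model(inputs), targets).backward()
    optimizer.second_step(zero_grad=True)  # restore w, then base_optimizer.step()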