@simeneide
Created October 21, 2020 11:36
Two simple implementations of SG-MCMC (stochastic-gradient MCMC) samplers as PyTorch optimizers: SGLD and SGHMC.
import torch
from torch.optim.optimizer import Optimizer

class OptimizerSGLD(Optimizer):
    """Stochastic Gradient Langevin Dynamics: plain SGD plus Gaussian noise."""

    def __init__(self, net, alpha=1e-4, sgmcmc=True):
        super().__init__(net.parameters(), {})
        self.net = net
        self.sgmcmc = sgmcmc  # if False, falls back to plain SGD
        self.alpha = alpha
        # Langevin noise scale: std = sqrt(2 * step size)
        self.noise_std = (2 * self.alpha) ** 0.5

    @torch.no_grad()
    def step(self):
        for name, par in self.net.named_parameters():
            if par.grad is None:
                continue
            # Gradient step on the (unnormalized) negative log-posterior
            newpar = par - self.alpha * par.grad
            if self.sgmcmc:
                # Injected noise turns the SGD trajectory into an SGLD chain
                newpar += torch.normal(torch.zeros_like(par), std=self.noise_std)
            par.copy_(newpar)
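
Each SGLD step above is the update theta <- theta - alpha * grad U(theta) + eps with eps ~ N(0, 2*alpha*I), i.e. gradient descent on a negative log-posterior U plus matched Gaussian noise. As a quick sanity check (not part of the original gist; the toy module and numbers are illustrative), running SGLD with U(theta) = theta^2 / 2 as the loss should produce draws close to a standard Gaussian:

# Sanity check (illustrative): SGLD on U(theta) = theta^2 / 2 targets N(0, 1)
toy = torch.nn.Linear(1, 1, bias=False)  # a single scalar parameter
sgld = OptimizerSGLD(toy, alpha=1e-2)
draws = []
for t in range(20000):
    sgld.zero_grad()
    loss = 0.5 * toy.weight.pow(2).sum()  # negative log-density of N(0, 1)
    loss.backward()
    sgld.step()
    if t >= 5000:  # discard burn-in
        draws.append(toy.weight.item())
print(torch.tensor(draws).std())  # should be close to 1.0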
class OptimizerSGHMC(Optimizer):
    """Stochastic Gradient Hamiltonian Monte Carlo: SGD with momentum plus noise."""

    def __init__(self, net, alpha=1e-4, nu=1.0, sgmcmc=True):
        super().__init__(net.parameters(), {})
        self.net = net
        self.sgmcmc = sgmcmc  # if False, falls back to SGD with momentum
        self.alpha = alpha
        self.nu = nu  # friction; (1 - nu) is the momentum decay factor
        self.noise_std = (2 * self.alpha * self.nu) ** 0.5
        self.momentum = {
            name: torch.zeros_like(par) for name, par in self.net.named_parameters()
        }

    @torch.no_grad()
    def step(self):
        for name, par in self.net.named_parameters():
            if par.grad is None:
                continue
            # Move parameters along the current momentum
            par.copy_(par + self.momentum[name])
            # Update momentum: friction, gradient force, and (optionally) noise
            self.momentum[name] = (1 - self.nu) * self.momentum[name] - self.alpha * par.grad
            if self.sgmcmc:
                self.momentum[name] += torch.normal(
                    torch.zeros_like(par), std=self.noise_std
                )
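
A sketch of how either optimizer might be used end to end, here SGHMC on a toy linear regression (the data, model, and hyperparameters are illustrative assumptions, not from the gist):

# Illustrative usage: SGHMC on a toy linear-regression posterior
torch.manual_seed(0)
x = torch.randn(256, 3)
true_w = torch.tensor([[1.0, -2.0, 0.5]])
y = x @ true_w.t() + 0.1 * torch.randn(256, 1)

net = torch.nn.Linear(3, 1)
opt = OptimizerSGHMC(net, alpha=1e-3, nu=0.1)

samples = []
for t in range(5000):
    opt.zero_grad()
    # Summed squared error plays the role of an unnormalized negative log-posterior
    loss = (net(x) - y).pow(2).sum()
    loss.backward()
    opt.step()
    if t >= 2000 and t % 10 == 0:  # burn-in, then thin the chain
        samples.append(net.weight.detach().clone())
print(torch.stack(samples).mean(dim=0))  # posterior mean, close to true_w

Averaging predictions over the collected parameter snapshots, rather than using a single point estimate, is what makes these optimizers approximate Bayesian posterior samplers.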