Proximal Gradient Method for PyTorch (minimal extension of torch.optim.SGD)
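For context (a standard statement of the method, not part of the gist itself): proximal gradient descent minimizes a composite objective g(x) + h(x), where g is smooth and h is non-smooth but has a cheap proximal operator, by alternating a gradient step on g with the prox of h:

    x_{k+1} = prox_h( x_k - eta * grad g(x_k) ),   prox_f(v) = argmin_x f(x) + (1/2) * ||x - v||^2,

where eta is the learning rate. The class below does exactly that: it lets SGD take the gradient step, then applies a user-supplied proximal operator to each parameter group.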
from torch.optim.sgd import SGD
from torch.optim.optimizer import required


class PGM(SGD):
    def __init__(self, params, proxs, lr=required,
                 momentum=0, dampening=0, nesterov=False):
        # weight_decay is fixed to 0: regularization is handled by the
        # proximal operators, not by an L2 penalty inside SGD
        kwargs = dict(lr=lr, momentum=momentum, dampening=dampening,
                      weight_decay=0, nesterov=nesterov)
        super().__init__(params, **kwargs)

        if len(proxs) != len(self.param_groups):
            raise ValueError("Invalid length of argument proxs: {} instead of {}".format(
                len(proxs), len(self.param_groups)))

        # attach one proximal operator to each parameter group
        for group, prox in zip(self.param_groups, list(proxs)):
            group.setdefault('prox', prox)

    def step(self, closure=None):
        # this performs a gradient step,
        # optionally with momentum or nesterov acceleration
        loss = super().step(closure=closure)

        # here we apply the proximal operator to each parameter in a group
        for group in self.param_groups:
            prox = group['prox']
            for p in group['params']:
                p.data = prox(p.data)

        return loss
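As a usage sketch (not from the gist; the model, learning rate, and L1 strength below are illustrative assumptions): the proximal operator of the L1 norm is soft-thresholding, so PGM can run ISTA-style sparse training by passing one prox per parameter group and otherwise driving it like plain SGD.

import torch
import torch.nn.functional as F

# minimal sketch, assuming the PGM class above is in scope
model = torch.nn.Linear(10, 1)
lr, lam = 0.1, 1e-2  # illustrative values

def soft_threshold(w):
    # prox of (lr * lam) * ||w||_1; the textbook ISTA step scales
    # the L1 threshold by the step size, so we fold lr in here
    return w.sign() * (w.abs() - lr * lam).clamp(min=0)

# model.parameters() forms a single param group, so proxs has length 1
opt = PGM(model.parameters(), proxs=[soft_threshold], lr=lr)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(100):
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    opt.step()  # SGD step on the smooth loss, then soft-thresholding

Note that the prox is applied to every tensor in its group, bias included; to exempt the bias from shrinkage, put it in its own param group with an identity prox.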
