@pmelchior
Last active June 16, 2022 03:23
Proximal Adam for PyTorch
import math
from typing import Callable, List

import torch
from torch import Tensor
from torch.optim import Optimizer
def adaprox(params: List[Tensor],
            proxes: List[Callable[[Tensor, float], Tensor]],
            grads: List[Tensor],
            exp_avgs: List[Tensor],
            exp_avg_sqs: List[Tensor],
            max_exp_avg_sqs: List[Tensor],
            state_steps: List[int],
            *,
            amsgrad: bool,
            beta1: float,
            beta2: float,
            lr: float,
            weight_decay: float,
            eps: float,
            maximize: bool,
            prox_max_iter: int,
            e_rel: float):
    r"""Functional API that performs the proximal Adam (AdaProx) computation.

    Runs an ordinary Adam update on each parameter and then applies the
    associated proximal operator. ``prox_max_iter`` caps the number of inner
    proximal sub-iterations per parameter, and ``e_rel`` is the relative
    tolerance used to stop them early. See :class:`AdaProx` for details.
    """
    for i, (param, prox) in enumerate(zip(params, proxes)):
        # ordinary Adam update
        grad = grads[i] if not maximize else -grads[i]
        M = exp_avgs[i]
        V = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        M.mul_(beta1).add_(grad, alpha=1 - beta1)
        V.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], V, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            psi = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            psi = (V.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1
        param.addcdiv_(M, psi, value=-step_size)

        # proximal update(s): sub-iterations that approximate the proximal
        # mapping in the metric defined by psi, stopped early once the
        # relative change of the iterate drops below e_rel
        alpha = lr
        gamma = alpha / psi.max()
        x = param.data
        z = torch.clone(x)
        for tau in range(prox_max_iter):
            z_ = prox(z - gamma / alpha * psi * (z - x), gamma)
            if torch.square(z - z_).sum() <= e_rel * e_rel * torch.square(z).sum():
                break
            z = z_
        param.data = z_
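

# --- Example proximal operators (illustrative sketch, not part of the original gist) ---
# Both match the expected signature prox(x, gamma) -> x_ used above and documented
# in the AdaProx class below. The first projects onto the non-negative orthant
# (a pure projection ignores gamma); the second is soft-thresholding, i.e. the
# proximal operator of gamma * ||x||_1. The helper names are made up here.

def prox_nonneg(x: Tensor, gamma: float) -> Tensor:
    # Euclidean projection onto the constraint x >= 0
    return torch.clamp(x, min=0)


def prox_soft_threshold(x: Tensor, gamma: float) -> Tensor:
    # prox of gamma * ||x||_1: shrink every entry towards zero by gamma
    return torch.sign(x) * torch.clamp(x.abs() - gamma, min=0)
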
class AdaProx(Optimizer):
    r"""Implements the proximal Adam (AdaProx) algorithm.

    For further details regarding the algorithm we refer to
    `Proximal Adam: Robust Adaptive Update Scheme for Constrained Optimization`
    (https://arxiv.org/abs/1910.10094).

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        proxes (iterable): iterable of proximal operators, one per parameter
            group, with signature prox(x, gamma) -> x_
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        maximize (bool, optional): maximize the params based on the objective,
            instead of minimizing (default: False)
    """
    def __init__(self, params, proxes, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, *, maximize: bool = False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
        super(AdaProx, self).__init__(params, defaults)

        # one prox per parameter group
        if len(proxes) != len(self.param_groups):
            raise ValueError("Invalid length of argument proxes: {} instead of {}".format(
                len(proxes), len(self.param_groups)))
        for group, prox in zip(self.param_groups, list(proxes)):
            group.setdefault('prox', prox)
    def __setstate__(self, state):
        super(AdaProx, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
            group.setdefault('maximize', False)
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            proxes_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']
            prox = group['prox']

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    proxes_with_grad.append(prox)
                    if p.grad.is_sparse:
                        raise RuntimeError('AdaProx does not support sparse gradients')
                    grads.append(p.grad)

                    state = self.state[p]
                    # Lazy state initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if group['amsgrad']:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    if group['amsgrad']:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # update the step count for each parameter
                    state['step'] += 1
                    # record the step after the update
                    state_steps.append(state['step'])

            # Adam update followed by the proximal update of each parameter;
            # the inner prox sub-iterations use fixed defaults here
            adaprox(params_with_grad,
                    proxes_with_grad,
                    grads,
                    exp_avgs,
                    exp_avg_sqs,
                    max_exp_avg_sqs,
                    state_steps,
                    amsgrad=group['amsgrad'],
                    beta1=beta1,
                    beta2=beta2,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    eps=group['eps'],
                    maximize=group['maximize'],
                    prox_max_iter=100,
                    e_rel=1e-4)

        return loss
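

# --- Minimal usage sketch (illustrative, not part of the original gist) ---
# Fits a single non-negative parameter vector to a noisy target by passing a
# non-negativity projection as the proximal operator. The tensor shapes, data,
# learning rate, and closure below are made up for illustration only.
if __name__ == "__main__":
    torch.manual_seed(0)
    target = torch.randn(10)
    x = torch.zeros(10, requires_grad=True)

    # one prox per parameter group: here, projection onto x >= 0
    optimizer = AdaProx([x], proxes=[lambda z, gamma: torch.clamp(z, min=0)], lr=0.1)

    for _ in range(200):
        def closure():
            optimizer.zero_grad()
            loss = torch.sum((x - target) ** 2)
            loss.backward()
            return loss
        loss = optimizer.step(closure)

    # x remains feasible (non-negative) after every step
    print("final loss:", float(loss), "min entry:", float(x.min()))
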
@pmelchior (Author): Details about the proximal Adam optimizer are in this paper: https://arxiv.org/abs/1910.10094
