Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Created June 4, 2021 16:56
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save crowsonkb/a5769d3396029fa749263940d3b336bd to your computer and use it in GitHub Desktop.
Save crowsonkb/a5769d3396029fa749263940d3b336bd to your computer and use it in GitHub Desktop.
Complex momentum SGD and Adam. See https://arxiv.org/abs/2102.08431.
"""Complex momentum SGD and Adam. See https://arxiv.org/abs/2102.08431."""
import math
import torch
from torch import optim
class ComplexSGD(optim.Optimizer):
    """SGD with a complex-valued momentum coefficient.

    Each parameter keeps a complex momentum buffer. On every step the buffer
    is multiplied by ``momentum * (cos(angle) + 1j * sin(angle))``, the
    gradient is added to it, and the real part of the buffer is used as the
    update direction. See https://arxiv.org/abs/2102.08431.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Step size.
        momentum: Magnitude of the complex momentum coefficient.
        angle: Phase of the complex momentum coefficient, in radians.
        weight_decay: When non-zero, each parameter is scaled by
            ``1 - weight_decay * lr`` before the update is applied.
    """

    def __init__(self, params, lr=1e-2, momentum=0.9, angle=math.pi / 8, weight_decay=0.):
        defaults = dict(lr=lr, momentum=momentum, angle=angle, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is invoked with gradients enabled.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so autograd must be re-enabled for
            # the closure to be able to call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            angle = group['angle']
            # Complex momentum coefficient: magnitude `momentum`, phase `angle`.
            momentum_c = group['momentum'] * (math.cos(angle) + math.sin(angle) * 1j)

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data

                if momentum_c:
                    state = self.state[p]
                    if not state:
                        # Adding 0j promotes the zero buffer to a complex dtype.
                        state['momentum_buffer'] = torch.zeros_like(p) + 0j
                    buf = state['momentum_buffer']
                    buf.mul_(momentum_c).add_(d_p)
                    # Only the real part of the buffer drives the update.
                    d_p = buf.real

                if weight_decay:
                    # Weight decay is applied to the parameter directly,
                    # not folded into the gradient.
                    p.data.mul_(1 - (weight_decay * lr))
                p.data.add_(d_p, alpha=-lr)

        return loss
class ComplexAdam(optim.Optimizer):
    """Adam-style optimizer with a complex-valued first-moment coefficient.

    The first-moment accumulator is complex: on every step it is multiplied
    by ``beta_1 * (cos(angle) + 1j * sin(angle))`` and the gradient is added
    to it. Its real part is divided by the square root of the bias-corrected
    second-moment estimate (plus ``eps``) to form the update. See
    https://arxiv.org/abs/2102.08431.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Step size.
        betas: ``(beta_1, beta_2)``; ``beta_1`` is the magnitude of the
            complex first-moment coefficient, ``beta_2`` the decay rate of
            the running average of squared gradients.
        angle: Phase of the complex first-moment coefficient, in radians.
        eps: Term added to the denominator.
        weight_decay: When non-zero, each parameter is scaled by
            ``1 - weight_decay * lr`` before the update is applied.
    """

    def __init__(self, params, lr=1e-3, betas=(0.8, 0.999), angle=math.pi / 8, eps=1e-8,
                 weight_decay=0.):
        defaults = dict(lr=lr, betas=betas, angle=angle, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is invoked with gradients enabled.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so autograd must be re-enabled for
            # the closure to be able to call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta_1, beta_2 = group['betas']
            angle = group['angle']
            # Complex first-moment coefficient: magnitude beta_1, phase `angle`.
            beta_1_c = beta_1 * (math.cos(angle) + math.sin(angle) * 1j)
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data

                state = self.state[p]
                if not state:
                    # Running product of beta_2, used for bias correction.
                    state['decay_2'] = 1.
                    # Adding 0j promotes the zero buffer to a complex dtype.
                    state['exp_avg'] = torch.zeros_like(p) + 0j
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']

                # The gradient is accumulated without a (1 - beta_1) factor.
                exp_avg.mul_(beta_1_c).add_(d_p)
                exp_avg_sq.mul_(beta_2).add_(d_p**2, alpha=1 - beta_2)
                state['decay_2'] *= beta_2

                # Bias-correct the second moment only.
                exp_avg_sq_corr = exp_avg_sq / (1 - state['decay_2'])
                update = exp_avg.real / (exp_avg_sq_corr.sqrt() + eps)

                if weight_decay:
                    # Weight decay is applied to the parameter directly,
                    # not folded into the gradient.
                    p.data.mul_(1 - (weight_decay * lr))
                p.data.add_(update, alpha=-lr)

        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment