
@khdlr
Created May 18, 2020 07:07
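A quick check of how each optimizer in torch.optim interacts with the closure passed to step(): every optimizer runs 1000 steps on a single scalar parameter initialised to 1, once with a closure that never computes gradients and once with a closure that zeroes the gradients and calls backward().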
import torch
from torch.optim import *


def check_optimizer(optimizer_type, modify_grad=False):
    # Single scalar "parameter"; its gradient with respect to itself is simply 1.
    testvar = torch.ones([])
    testvar.requires_grad = True

    # SGD is the only optimizer in the list without a default learning rate.
    if optimizer_type is SGD:
        opt = optimizer_type([testvar], 1e-3)
    else:
        opt = optimizer_type([testvar])

    def closure():
        # With modify_grad, the closure clears stale gradients and
        # backpropagates before returning the "loss"; without it, the
        # closure only returns the value and never touches the gradients.
        if modify_grad:
            opt.zero_grad()
        if modify_grad:
            testvar.backward()
        return testvar

    for i in range(1000):
        opt.step(closure)
    return testvar.item()


optimizers = [Adadelta, Adagrad, Adam, AdamW, Adamax, ASGD, LBFGS, RMSprop, Rprop, SGD]

for optimizer_type in optimizers:
    opt_name = optimizer_type.__name__
    print(opt_name.ljust(10), "not modifying grad:", check_optimizer(optimizer_type, False))
    print(opt_name.ljust(10), " modifying grad:", check_optimizer(optimizer_type, True))