Skip to content

Instantly share code, notes, and snippets.

View dejanbatanjac's full-sized avatar
🏠

Dejan Batanjac dejanbatanjac

🏠
View GitHub Profile
@dejanbatanjac
dejanbatanjac / ralamb.py
Created August 28, 2019 19:04 — forked from redknightlois/ralamb.py
Ralamb optimizer (RAdam + LARS trick)
class Ralamb(Optimizer):
    """Ralamb optimizer (RAdam + LARS trick).

    This block defines construction and state restoration only.
    NOTE(review): no ``step`` implementation is visible in this excerpt --
    confirm that it is defined in the full file.

    Args:
        params: iterable of parameters (or dicts defining parameter groups),
            forwarded unchanged to ``Optimizer.__init__``.
        lr: learning rate stored in the group defaults; must be >= 0.
        betas: pair of coefficients stored in the group defaults; each must
            lie in ``[0, 1)``.
        eps: epsilon term stored in the group defaults; must be >= 0.
        weight_decay: weight-decay coefficient stored in the group defaults;
            must be >= 0.

    Raises:
        ValueError: if any hyperparameter is outside the ranges above.
    """

    # Number of ``[None, None, None]`` slots kept in ``self.buffer``.
    _BUFFER_SLOTS = 10

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        # Reject hyperparameters that can never be meaningful, the same way
        # the built-in torch optimizers do, instead of accepting them silently.
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if eps < 0.0:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Instance-level scratch storage: ten independent three-element slots.
        # NOTE(review): presumably a cache consumed by the update step, which
        # is not visible here -- confirm against the rest of the file.
        self.buffer = self._empty_buffer()
        super(Ralamb, self).__init__(params, defaults)

    @classmethod
    def _empty_buffer(cls):
        """Return a fresh buffer of independent ``[None, None, None]`` slots."""
        return [[None, None, None] for _ in range(cls._BUFFER_SLOTS)]

    def __setstate__(self, state):
        """Restore optimizer state, recreating ``buffer`` if it is absent.

        ``Optimizer.__getstate__`` serializes only ``defaults``, ``state`` and
        ``param_groups``, so an instance rebuilt by pickle/deepcopy never runs
        ``__init__`` and would otherwise come back without ``self.buffer``.
        """
        super(Ralamb, self).__setstate__(state)
        if not hasattr(self, "buffer"):
            self.buffer = self._empty_buffer()