from abc import ABC

from norch.tensor import Tensor


class Optimizer(ABC):
    """
    Abstract class for optimizers
    """

    def __init__(self, parameters):
        # A bare Tensor is not a valid parameter collection.
        if isinstance(parameters, Tensor):
            raise TypeError("parameters should be an iterable but got {}".format(type(parameters)))
        elif isinstance(parameters, dict):
            parameters = parameters.values()

        # Each parameter is a (module, attribute_name, tensor) triple.
        self.parameters = list(parameters)

    def step(self):
        raise NotImplementedError

    def zero_grad(self):
        for module, name, parameter in self.parameters:
            parameter.zero_grad()


class SGD(Optimizer):
    """
    Stochastic gradient descent with optional momentum.
    """

    def __init__(self, parameters, lr=1e-1, momentum=0):
        super().__init__(parameters)
        self.lr = lr
        self.momentum = momentum
        # One velocity buffer per parameter, initialized to zeros.
        self._cache = {'velocity': [p.zeros_like() for (_, _, p) in self.parameters]}

    def step(self):
        for i, (module, name, _) in enumerate(self.parameters):
            parameter = getattr(module, name)

            velocity = self._cache['velocity'][i]

            # Momentum update: v <- momentum * v - lr * grad
            velocity = self.momentum * velocity - self.lr * parameter.grad

            # Parameter update: p <- p + v
            updated_parameter = parameter + velocity
            setattr(module, name, updated_parameter)

            self._cache['velocity'][i] = velocity

            # Detach the stale tensors from the autograd graph so the
            # update itself is not tracked.
            parameter.detach()
            velocity.detach()
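

# A minimal training-loop sketch of how these optimizers are driven,
# assuming a norch nn.Module whose .parameters() yields the
# (module, attribute_name, tensor) triples Optimizer expects; `model`,
# `criterion`, and `dataset` below are hypothetical placeholders.
def train(model, criterion, dataset, num_epochs=10):
    optimizer = SGD(model.parameters(), lr=1e-2, momentum=0.9)
    for _ in range(num_epochs):
        for x, y in dataset:
            optimizer.zero_grad()          # clear gradients from the previous step
            loss = criterion(model(x), y)  # forward pass
            loss.backward()                # populate parameter.grad
            optimizer.step()               # apply the momentum update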