Skip to content

Instantly share code, notes, and snippets.

@emilemathieu
Last active August 9, 2018 12:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save emilemathieu/46fb760060ea4b28bad28f3035c26a21 to your computer and use it in GitHub Desktop.
Save emilemathieu/46fb760060ea4b28bad28f3035c26a21 to your computer and use it in GitHub Desktop.
class Optimizer:
    """Base class for per-parameter update rules.

    Every instance owns a fresh ``state`` dictionary that subclasses may use
    to carry data from one call to the next. The base class provides no
    update rule of its own: invoking it raises ``NotImplementedError``, so
    concrete optimizers must override ``__call__``.
    """

    def __init__(self):
        # Per-instance storage; the key scheme is left to subclasses.
        self.state = dict()

    def __call__(self, layer_id, weight_type, value, grad):
        """Update one parameter from its gradient; subclasses must override this."""
        raise NotImplementedError()
class SGD(Optimizer):
    """Stochastic gradient descent with optional momentum.

    For each parameter, identified by ``(layer_id, weight_type)``, a call computes

        v     = lr * grad + momentum * v_prev
        value = value - v

    Here ``v_prev`` is the ``v`` stored by the previous call for the same
    parameter. It is 0 on the first call, and it is always treated as 0 when
    ``momentum == 0``.
    """

    def __init__(self, lr=0.1, momentum=0):
        """Create the optimizer.

        Args:
            lr: Learning rate; scales the gradient.
            momentum: Coefficient applied to the previous velocity. 0 disables it.
        """
        super().__init__()
        self.lr = lr
        self.momentum = momentum

    def get_state(self, key):
        """Return the stored velocity for ``key``.

        Returns 0 when nothing is stored yet or when momentum is disabled.
        """
        return self.state[key] if self.momentum != 0 and key in self.state else 0

    def __call__(self, layer_id, weight_type, value, grad):
        """Apply one SGD step to ``value`` and remember the velocity used.

        Args:
            layer_id: Layer identifier. It is passed through ``str()`` to build the state key.
            weight_type: Suffix naming the parameter within the layer. It must be a
                ``str`` because it is concatenated onto the key.
            value: Current parameter value.
            grad: Gradient with respect to ``value``.

        Returns:
            The updated parameter value, ``value - v``.
        """
        # Bug fix: the key was previously only computed inline for the read,
        # so the write below raised NameError. Bind it once so that the read
        # and the write address the same slot.
        # NOTE(review): plain string concatenation can collide, e.g.
        # (1, "1W") and (11, "W") both map to "11W". Confirm that callers'
        # ids and suffixes cannot overlap this way.
        key = str(layer_id) + weight_type
        old_v = self.get_state(key)
        new_v = self.lr * grad
        new_v += self.momentum * old_v
        self.state[key] = new_v
        return value - new_v
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment