# Design sketch of the scheduler/optimizer API. `Tensor`, `autograd` and
# dtype names such as `singa.int` refer to SINGA's Python package, e.g.
# `from singa import autograd` and `from singa.tensor import Tensor`.

class DecayScheduler:
    # Base class for decaying a hyper-parameter over training steps, e.g. the
    # learning rate, a regularization coefficient, or the momentum.

    def __init__(self, init_value):
        self.init_value = init_value

    def __call__(self, step):
        assert isinstance(step, Tensor)
        return self.call(step)

    def call(self, step):
        # step is a Tensor holding a single scalar value;
        # subclasses return the current value as a Tensor
        raise NotImplementedError

class Constant(DecayScheduler):

    def call(self, step):
        # multiply by step / step so that the returned value is a Tensor
        return self.init_value * step / step

class ExponentialDecay(DecayScheduler):

    def __init__(self, init_value, decay_steps, decay_rate, staircase=False):
        super().__init__(init_value)
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def call(self, step):
        if self.staircase:
            s = step // self.decay_steps
        else:
            s = step / self.decay_steps
        ret = self.decay_rate * step / step  # divide by step to make ret a Tensor
        return self.init_value * ret ** s
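
# Sanity check (added sketch, not part of the API): a pure-Python version of
# the schedule ExponentialDecay computes, with hypothetical numbers.
def _exp_decay_py(init_value, decay_steps, decay_rate, step, staircase=False):
    s = step // decay_steps if staircase else step / decay_steps
    return init_value * decay_rate ** s

# _exp_decay_py(0.1, 100, 0.96, 0)   -> 0.1
# _exp_decay_py(0.1, 100, 0.96, 100) -> 0.096
# _exp_decay_py(0.1, 100, 0.96, 200) -> 0.09216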

class Optimizer:

    def __init__(self, lr):
        """lr can be a constant scalar or a learning rate scheduler"""
        if isinstance(lr, float):
            self.lr = Constant(lr)
        elif isinstance(lr, DecayScheduler):
            self.lr = lr
        else:
            raise TypeError('lr must be a float or a DecayScheduler')
        self.step_counter = Tensor((1,), dtype=singa.int)
        self.step_counter.set_value(0)
        self.lr_value = self.lr(self.step_counter)

    def get_states(self):
        # the DecayScheduler is skipped as it has no persistent state
        return {'step_counter': self.step_counter.get_value(0)}

    def set_states(self, states):
        self.step_counter = Tensor((1,), dtype=singa.int)
        self.step_counter.set_value(states['step_counter'])
        self.lr_value = self.lr(self.step_counter)

    def __call__(self, loss):
        self.call(loss)
        self.step()

    def call(self, loss):
        for p, g in autograd.backward(loss):
            # each tensor can have a name; the name of a param tensor
            # could be set in compile()
            self.apply(p.name, p, g)

    def step(self):
        # increment the step counter and refresh the learning rate value
        self.step_counter += 1
        self.lr_value = self.lr(self.step_counter)

    def apply(self, param_name, param_value, param_grad):
        # subclasses update param_value in place using param_grad
        pass

    @deprecated  # kept for backward compatibility
    def update(self, p, g):
        if p.name is None:
            p.name = str(id(p))
        self.apply(p.name, p, g)
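
# Illustrative checkpoint round trip via the states protocol above (added
# sketch; `save`/`load` are hypothetical stand-ins for whatever
# serialization the framework provides):
#     states = opt.get_states()        # {'step_counter': ...}
#     save('ckpt', states)
#     ...
#     opt.set_states(load('ckpt'))     # restores step_counter and lr_value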

class SGD(Optimizer):

    def __init__(self, lr, momentum=0.0):
        super().__init__(lr)
        if isinstance(momentum, float):
            self.momentum = Constant(momentum)
        elif isinstance(momentum, DecayScheduler):
            self.momentum = momentum
        else:
            raise TypeError('momentum must be a float or a DecayScheduler')
        self.mom_value = self.momentum(self.step_counter)
        self.moments = {}  # 1st-order moment tensor per parameter name

    def apply(self, pname, pvalue, pgrad):
        # update pvalue in place using pgrad, lr_value and mom_value;
        # see the sketch after this class
        pass

    def step(self):
        super().step()
        self.mom_value = self.momentum(self.step_counter)

    def get_states(self):
        states = super().get_states()
        if self.mom_value > 0:
            states['moments'] = self.moments  # a dict of 1st-order moment tensors
        return states

    def set_states(self, states):
        super().set_states(states)
        if 'moments' in states:
            self.moments = states['moments']
        self.mom_value = self.momentum(self.step_counter)
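
# The update SGD.apply is expected to perform (added sketch of classical
# momentum, not the committed implementation; `zeros_like` is a hypothetical
# helper creating a zero Tensor with pvalue's shape):
#     buf = self.moments.setdefault(pname, zeros_like(pvalue))
#     buf = self.mom_value * buf + pgrad       # accumulate the 1st-order moment
#     self.moments[pname] = buf
#     pvalue -= self.lr_value * buf            # in-place parameter update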
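
# Hypothetical end-to-end usage of the API above (`model`, `train_data` and
# `max_epoch` are assumed to exist; names are illustrative):
#     sgd = SGD(lr=ExponentialDecay(0.1, decay_steps=1000, decay_rate=0.96),
#               momentum=0.9)
#     for epoch in range(max_epoch):
#         for x, y in train_data:
#             loss = model(x, y)  # forward pass builds the autograd graph
#             sgd(loss)           # backward(), apply() per param, then step()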