from typing import Callable, Dict, List, Optional, Union, cast

import tensorflow as tf
from tensorflow import keras
from tensorflow import keras as keras_module


class MomentumOptimizer(keras.optimizers.Optimizer):
    # The superclass supports [tf.Variable, LearningRateSchedule] learning
    # rates, but we only support tf.Variable for simplicity.
    _learning_rate: tf.Variable
    momentum: float
    _built: bool
    velocities: List[tf.Variable]

    def __init__(self, learning_rate: float, momentum: float, **kwargs):
        assert 0.0 <= momentum <= 1.0
        super().__init__(name="MomentumOptimizer", **kwargs)
        # _build_learning_rate mainly handles the difference between float and
        # LearningRateSchedule. We currently only support float, but call it
        # anyway for future compatibility.
        self._learning_rate = cast(tf.Variable, self._build_learning_rate(learning_rate))
        self.momentum = momentum
        self._built = False
        self.velocities = []

    def decay_learning_rate(self, decay_rate: float) -> float:
        self._learning_rate.assign(tf.multiply(self._learning_rate, decay_rate))
        # scale the velocities together with the learning rate so the
        # effective step size shrinks consistently
        for v in self.velocities:
            v.assign(tf.multiply(v, decay_rate))
        applied_lr = self._learning_rate.read_value()
        return float(applied_lr)

    def current_velocities(self) -> List[tf.Tensor]:
        return [v.read_value() for v in self.velocities]

    def update_velocity(self, velocities: List[tf.Tensor]):
        assert len(velocities) == len(self.velocities)
        for v, new_v in zip(self.velocities, velocities):
            v.assign(new_v)

    # Overridden method. var_list holds the model's variables. If one of them
    # is named 'a', we create a variable named 'velocity/a' for its momentum.
    def build(self, var_list: List[tf.Variable]):
        super().build(var_list)
        if self._built:
            return
        for var in var_list:
            self.velocities.append(
                self.add_variable_from_reference(model_variable=var, variable_name="velocity")
            )
        self._built = True

    def built(self) -> bool:
        return self._built

    # Overridden method.
    def update_step(self, gradient: Union[tf.Tensor, tf.IndexedSlices], variable: tf.Variable):
        assert self._built
        lr = tf.cast(self.learning_rate, variable.dtype)
        # tf.cast returns the variable itself when the dtype already matches
        assert isinstance(lr, tf.Variable)
        momentum = tf.cast(self.momentum, variable.dtype)
        assert isinstance(momentum, tf.Tensor)
        # In my understanding the velocity must not be None here, because the
        # build method creates velocities for all known variables.
        v = self.velocities[self._index_dict[self._var_key(variable)]]
        assert isinstance(v, tf.Variable)
        # classic momentum: v <- momentum * v - lr * g; variable <- variable + v
        add_value = calc(gradient, lambda g: tf.negative(g) * lr)
        v.assign(tf.multiply(v, momentum))
        assign_add(v, add_value)
        variable.assign_add(v)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(self._learning_rate),
                "momentum": self.momentum,
            }
        )
        return config
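

# A minimal sanity check for update_step (illustrative, not part of the
# original API): it runs two classic-momentum steps on a scalar variable and
# verifies v <- momentum * v - lr * g, var <- var + v by hand. The helper
# name and all numbers here are hypothetical, assuming eager execution and
# float32 defaults.
def _demo_momentum_update_step() -> None:
    opt = MomentumOptimizer(learning_rate=0.1, momentum=0.9)
    var = tf.Variable(1.0, name="a")
    opt.build([var])
    grad = tf.constant(0.5)
    opt.update_step(grad, var)
    # first step: v = 0.9 * 0.0 - 0.1 * 0.5 = -0.05, var = 1.0 - 0.05 = 0.95
    assert abs(float(var.read_value()) - 0.95) < 1e-6
    opt.update_step(grad, var)
    # second step: v = 0.9 * -0.05 - 0.05 = -0.095, var = 0.95 - 0.095 = 0.855
    assert abs(float(var.read_value()) - 0.855) < 1e-6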


class RestoreBestWeightsAndVelocitiesCallback(keras.callbacks.Callback):
    # per-instance configuration
    patience: int
    baseline: float
    # for assertions
    started: bool
    # the model is set before on_train_begin
    model: Optional[keras_module.Model]
    # training state
    wait: int
    stopped_epoch: int
    best: float
    best_epoch: int
    best_weights: Optional[List[tf.Tensor]]
    best_velocities: Optional[List[tf.Tensor]]
    prev_epoch: int
    prev_weights: Optional[List[tf.Tensor]]
    prev_velocities: Optional[List[tf.Tensor]]
    prev_best_weights: Optional[List[tf.Tensor]]
    prev_best_velocities: Optional[List[tf.Tensor]]

    def __init__(self, patience: int, baseline: float):
        super().__init__()
        self.patience = patience
        self.baseline = baseline
        self.started = False
        self.model = None
        self.wait = 0
        self.best = float("inf")
        self.best_epoch = -1
        self.best_weights = None
        self.best_velocities = None
        self.prev_epoch = -1
        self.prev_weights = None
        self.prev_velocities = None
        self.prev_best_weights = None
        self.prev_best_velocities = None

    def on_train_begin(self, logs: Optional[Dict[str, float]] = None):
        assert not self.started
        assert logs is None or isinstance(logs, dict)  # suppress unused warning
        assert isinstance(self.model, keras_module.Model)
        assert isinstance(self.model.optimizer, MomentumOptimizer)
        self.started = True
        self.wait = 0
        self.best = self.baseline
        self.best_epoch = -1
        if self.model.optimizer.built():
            self.best_weights = self.model.get_weights()
            self.best_velocities = self.model.optimizer.current_velocities()
        else:
            self.best_weights = None
            self.best_velocities = None
        self.prev_epoch = -1
        self.prev_weights = None
        self.prev_velocities = None
        self.prev_best_weights = None
        self.prev_best_velocities = None

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None):
        assert self.started
        assert logs is not None
        assert self.prev_epoch == epoch - 1
        assert isinstance(self.model, keras_module.Model)
        assert isinstance(self.model.optimizer, MomentumOptimizer)
        current = logs.get("val_loss")
        assert isinstance(current, float)
        if current < self.best or self.best_weights is None:
            self.best = current
            self.best_epoch = epoch
            self.best_weights = self.model.get_weights()
            self.best_velocities = self.model.optimizer.current_velocities()
            # also remember the state from one epoch before the best epoch
            self.prev_best_weights = self.prev_weights
            self.prev_best_velocities = self.prev_velocities
            self.wait = 0
        else:
            self.wait += 1
            if self.patience <= self.wait:
                self.model.stop_training = True
        self.prev_epoch = epoch
        self.prev_weights = self.model.get_weights()
        self.prev_velocities = self.model.optimizer.current_velocities()

    def on_train_end(self, logs: Optional[Dict[str, float]] = None):
        assert self.started
        assert logs is None or isinstance(logs, dict)  # suppress unused warning
        assert isinstance(self.model, keras_module.Model)
        assert isinstance(self.model.optimizer, MomentumOptimizer)
        assert self.best_weights is not None
        assert self.best_velocities is not None
        self.started = False
        # restore the state from one epoch before the best epoch when
        # available, otherwise the best state itself
        if self.prev_best_weights is not None:
            self.model.set_weights(self.prev_best_weights)
        else:
            self.model.set_weights(self.best_weights)
        if self.prev_best_velocities is not None:
            self.model.optimizer.update_velocity(self.prev_best_velocities)
        else:
            self.model.optimizer.update_velocity(self.best_velocities)
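

# Illustrative timeline (hypothetical numbers): with patience=2 and val_loss
# values 1.0, 0.5, 0.7, 0.8 over epochs 0..3, epoch 1 becomes the best epoch
# and training stops after epoch 3. on_train_end then restores the weights
# and velocities snapshotted at the end of epoch 0 (the prev_best state),
# falling back to the best state itself when no earlier snapshot exists.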


@tf.function
def assign_add(var: tf.Variable, value: Union[tf.Tensor, tf.IndexedSlices]):
    # sparse gradients arrive as tf.IndexedSlices and need scatter_add
    if isinstance(value, tf.IndexedSlices):
        var.scatter_add(value)
    else:
        var.assign_add(value)


@tf.function
def calc(value: Union[tf.Tensor, tf.IndexedSlices], fn: Callable[[tf.Tensor], tf.Tensor]) -> Union[tf.Tensor, tf.IndexedSlices]:
    # apply fn to the dense values while preserving IndexedSlices sparsity
    if isinstance(value, tf.IndexedSlices):
        return tf.IndexedSlices(fn(value.values), value.indices)
    else:
        return fn(value)
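

# Usage sketch with a toy model and random data (the model, shapes, and
# hyperparameters are all hypothetical, assuming TF 2.x eager execution):
if __name__ == "__main__":
    model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer=MomentumOptimizer(learning_rate=0.01, momentum=0.9), loss="mse")
    x = tf.random.normal((256, 4))
    y = tf.random.normal((256, 1))
    callback = RestoreBestWeightsAndVelocitiesCallback(patience=3, baseline=float("inf"))
    model.fit(x, y, validation_split=0.25, epochs=10, callbacks=[callback])
    # decay the learning rate (and velocities) before resuming from the
    # restored pre-best state
    model.optimizer.decay_learning_rate(0.5)
    model.fit(x, y, validation_split=0.25, epochs=10, callbacks=[callback])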