AdaBound AMSBound for Keras
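
AdaBound (Luo et al., ICLR 2019) performs Adam-style adaptive updates but clips each per-parameter step size between dynamic bounds that converge to a final learning rate, so training starts out like Adam and gradually transitions toward SGD. AMSBound applies the same bounds on top of the AMSGrad variant, which tracks a running max of the second-moment estimate.
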
# coding: utf-8
"""
Based on Luo et al. (2019). Adaptive Gradient Methods with Dynamic Bound of
Learning Rate. In Proc. of ICLR 2019.
"""
import tensorflow as tf
from tensorflow import keras


class AdaBound(keras.optimizers.Optimizer):
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, final_lr=0.1, gamma=1e-3,
                 epsilon=None, weight_decay=0, amsbound=False, initial_lr=None, **kwargs):
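        """AdaBound/AMSBound optimizer (legacy tf.keras Optimizer API).

        Args (descriptions follow Luo et al., 2019):
            lr: initial (Adam-style) learning rate.
            final_lr: final (SGD-style) learning rate that the bounds converge to.
            gamma: convergence speed of the bound functions.
            epsilon: fuzz factor; defaults to keras.backend.epsilon().
            weight_decay: decoupled (AdaBoundW-style) weight decay, applied if > 0.
            amsbound: if True, use the AMSGrad-style running max of the second
                moment (the AMSBound variant).
            initial_lr: reference lr used to rescale final_lr when lr changes;
                defaults to lr. Accepted here so that get_config() round-trips.
        """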
        super(AdaBound, self).__init__(**kwargs)
        with keras.backend.name_scope(self.__class__.__name__):
            self.iterations = keras.backend.variable(0, dtype='int64', name='iterations')
            self.lr = keras.backend.variable(lr, name='lr')
        self.beta1 = beta1
        self.beta2 = beta2
        self.final_lr = final_lr
        self.gamma = gamma
        if epsilon is None:
            epsilon = keras.backend.epsilon()
        self.epsilon = epsilon
        self.weight_decay = weight_decay
        self.amsbound = amsbound
        self.initial_lr = initial_lr if initial_lr is not None else lr
    @staticmethod
    def zeros_like(p):
        return keras.backend.zeros(keras.backend.int_shape(p), keras.backend.dtype(p))
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [tf.assign_add(self.iterations, 1)]
        step = tf.cast(self.iterations, tf.float32) + 1
        # Adam-style bias-corrected step size.
        bias_correction1 = 1. - tf.pow(self.beta1, step)
        bias_correction2 = 1. - tf.pow(self.beta2, step)
        step_size = self.lr * (tf.sqrt(bias_correction2) / bias_correction1)
        # Scale final_lr in proportion to any external changes to lr (e.g. a schedule).
        final_lr = self.final_lr * self.lr / self.initial_lr
        lower_bound = final_lr * ((self.gamma * step) / (self.gamma * step + 1))
        upper_bound = final_lr * ((self.gamma * step + 1) / (self.gamma * step))
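        # Both bounds converge to final_lr: lower = final_lr * (1 - 1/(gamma*step + 1))
        # and upper = final_lr * (1 + 1/(gamma*step)). Early in training the bounds are
        # loose (near-pure Adam); late in training they pinch the step size toward
        # final_lr, giving SGD-like behavior.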
        exp_avgs = [AdaBound.zeros_like(p) for p in params]
        exp_avg_sqs = [AdaBound.zeros_like(p) for p in params]
        max_exp_avg_sqs = [AdaBound.zeros_like(p) for p in params]
        self.weights = [self.iterations] + exp_avgs + exp_avg_sqs + max_exp_avg_sqs
        for p, g, exp_avg, exp_avg_sq, max_exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs):
            # Update biased first- and second-moment estimates.
            new_exp_avg = (self.beta1 * exp_avg) + (1. - self.beta1) * g
            new_exp_avg_sq = (self.beta2 * exp_avg_sq) + (1. - self.beta2) * tf.square(g)
            self.updates.append(tf.assign(exp_avg, new_exp_avg))
            self.updates.append(tf.assign(exp_avg_sq, new_exp_avg_sq))
            if self.amsbound:
                # AMSBound: normalize by the running max of the second moment.
                new_max_exp_avg_sq = tf.maximum(max_exp_avg_sq, new_exp_avg_sq)
                self.updates.append(tf.assign(max_exp_avg_sq, new_max_exp_avg_sq))
                denom = tf.sqrt(new_max_exp_avg_sq) + self.epsilon
            else:
                denom = tf.sqrt(new_exp_avg_sq) + self.epsilon
            step_size_t = step_size / denom
            # Clip the per-parameter step size into [lower_bound, upper_bound].
            # (Clipping step_size_t, not step_size, so the adaptive denom is kept.)
            step_size_t = tf.clip_by_value(step_size_t, lower_bound, upper_bound)
            step_size_t = step_size_t * new_exp_avg
            # AdaBoundW-style decoupled weight decay.
            if self.weight_decay > 0:
                new_p = p * (1 - self.weight_decay) - step_size_t
            else:
                new_p = p - step_size_t
            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)
            self.updates.append(tf.assign(p, new_p))
        return self.updates
    def get_config(self):
        config = {
            'lr': float(keras.backend.get_value(self.lr)),
            'beta1': self.beta1,
            'beta2': self.beta2,
            'final_lr': self.final_lr,
            'gamma': self.gamma,
            'epsilon': self.epsilon,
            'weight_decay': self.weight_decay,
            'amsbound': self.amsbound,
            'initial_lr': self.initial_lr,
        }
        base_config = super(AdaBound, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
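

A minimal usage sketch. It assumes TensorFlow 1.x, where tf.keras.optimizers.Optimizer still exposes the legacy get_updates API that this class overrides (under TensorFlow 2.x the base class and tf.assign differ). The toy model and random data are hypothetical, only there to exercise the optimizer:

import numpy as np

# Hypothetical toy model; any tf.keras model works the same way.
model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(32,)),
    keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer=AdaBound(lr=1e-3, final_lr=0.1, amsbound=True),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Random data, just to check that the update ops build and run.
x = np.random.rand(256, 32).astype('float32')
y = keras.utils.to_categorical(np.random.randint(10, size=256), 10)
model.fit(x, y, batch_size=32, epochs=2)

The lr=1e-3 and final_lr=0.1 values above simply restate the class defaults, which match the defaults suggested by Luo et al. (2019).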