My modification of tensor2tensor/utils/multistep_optimizer.py: Adam with accumulated gradients, applying one optimizer update every n batches.
import tensorflow as tf


class MultistepAdamOptimizer(tf.train.AdamOptimizer):
  """Adam with SGD updates every n steps with accumulated gradients."""

  def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
               use_locking=False, name="Adam", n=1):
    super(MultistepAdamOptimizer, self).__init__(
        learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon,
        use_locking=use_locking, name=name)
    self._n = n  # Call Adam optimizer every n batches with accumulated grads
    self._n_t = None  # n as tensor
    tf.logging.info("MultistepAdamOptimizer with step: %d", self._n)

  def _create_slots(self, var_list):
    """Create slot variables for Adam with accumulated gradients."""
    super(MultistepAdamOptimizer, self)._create_slots(var_list)
    first_var = min(var_list, key=lambda x: x.name)
    iter_var = tf.get_variable(
        name="iter", shape=[], dtype=tf.int32,
        initializer=tf.zeros_initializer, trainable=False)
    for v in var_list:
      self._zeros_slot(v, "grad_acc", self._name)

  def _get_iter_variable(self):
    tf.get_variable_scope().reuse_variables()
    return tf.get_variable(name="iter", shape=[], dtype=tf.int32)

  def _prepare(self):
    super(MultistepAdamOptimizer, self)._prepare()
    self._n_t = tf.convert_to_tensor(self._n, name="n")

  def _apply_cond(self, apply_fn, grad, var, *args, **kwargs):
    """Apply conditionally if counter is zero."""
    grad_acc = self.get_slot(var, "grad_acc")

    def apply_adam(grad_acc, apply_fn, grad, var, *args, **kwargs):
      total_grad = (grad_acc + grad) / tf.cast(self._n_t, grad.dtype)
      adam_op = apply_fn(total_grad, var, *args, **kwargs)
      with tf.control_dependencies([adam_op]):
        grad_acc_to_zero_op = grad_acc.assign(tf.zeros_like(grad_acc),
                                              use_locking=self._use_locking)
      return tf.group(adam_op, grad_acc_to_zero_op)

    def accumulate_gradient(grad_acc, grad):
      assign_op = tf.assign_add(grad_acc, grad, use_locking=self._use_locking)
      return tf.group(assign_op)  # Strip return value

    return tf.cond(
        tf.equal(self._get_iter_variable(), 0),
        lambda: apply_adam(grad_acc, apply_fn, grad, var, *args, **kwargs),
        lambda: accumulate_gradient(grad_acc, grad))

  def _apply_dense(self, grad, var):
    return self._apply_cond(
        super(MultistepAdamOptimizer, self)._apply_dense, grad, var)

  def _resource_apply_dense(self, grad, var):
    return self._apply_cond(
        super(MultistepAdamOptimizer, self)._resource_apply_dense, grad, var)

  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    return self._apply_cond(
        super(MultistepAdamOptimizer, self)._apply_sparse_shared, grad, var,
        indices, scatter_add)

  def _apply_sparse(self, grad, var):
    # TODO(fstahlberg): Implement a sparse version
    tf.logging.warning(
        "MultistepAdamOptimizer does not support sparse updates")
    dense_grad = tf.convert_to_tensor(grad)
    return self._apply_cond(
        super(MultistepAdamOptimizer, self)._apply_dense, dense_grad, var)

  def _finish(self, update_ops, name_scope):
    """Updates beta_power variables every n batches and incrs counter."""
    iter_ = self._get_iter_variable()
    self._n_t = tf.Print(self._n_t, [iter_], summarize=10000, message="iter")
    beta1_power, beta2_power = self._get_beta_accumulators()
    with tf.control_dependencies(update_ops):
      with tf.colocate_with(iter_):

        def update_beta_op():
          update_beta1 = beta1_power.assign(
              beta1_power * self._beta1_t,
              use_locking=self._use_locking)
          update_beta2 = beta2_power.assign(
              beta2_power * self._beta2_t,
              use_locking=self._use_locking)
          return tf.group(update_beta1, update_beta2)

        maybe_update_beta = tf.cond(
            tf.equal(iter_, 0), update_beta_op, tf.no_op)
        with tf.control_dependencies([maybe_update_beta]):
          update_iter = iter_.assign(tf.mod(iter_ + 1, self._n_t),
                                     use_locking=self._use_locking)
    return tf.group(
        *update_ops + [update_iter, maybe_update_beta], name=name_scope)
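
For reference, below is a minimal usage sketch of how this optimizer could be dropped into a TF 1.x graph-mode training loop. The toy quadratic loss, the variable name x, and the choice of n=4 are illustrative assumptions, not part of the original file; with n=4, gradients are accumulated across four calls to the train op and the variables only move on every fourth call.

# Minimal usage sketch (assumptions: TF 1.x graph mode, toy quadratic loss, n=4).
x = tf.get_variable("x", shape=[], initializer=tf.constant_initializer(5.0))
loss = tf.square(x)  # stand-in for a real model loss

# Accumulate gradients over n=4 batches, then apply a single Adam update.
opt = MultistepAdamOptimizer(learning_rate=0.01, n=4)
train_op = opt.minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(20):
    sess.run(train_op)  # parameters change only on every 4th run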