@tobyyouup
Created July 9, 2018 04:17
My modification of tensor2tensor/utils/multistep_optimizer.py
import tensorflow as tf


class MultistepAdamOptimizer(tf.train.AdamOptimizer):
  """Adam with SGD updates every n steps with accumulated gradients."""

  def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
               use_locking=False, name="Adam", n=1):
    super(MultistepAdamOptimizer, self).__init__(
        learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon,
        use_locking=use_locking, name=name)
    self._n = n  # Call Adam optimizer every n batches with accumulated grads
    self._n_t = None  # n as tensor
    tf.logging.info("MultistepAdamOptimizer with step: %d", self._n)

  def _create_slots(self, var_list):
    """Create slot variables for Adam with accumulated gradients."""
    super(MultistepAdamOptimizer, self)._create_slots(var_list)
    first_var = min(var_list, key=lambda x: x.name)
    iter_var = tf.get_variable(
        name="iter", shape=[], dtype=tf.int32,
        initializer=tf.zeros_initializer, trainable=False)
    for v in var_list:
      self._zeros_slot(v, "grad_acc", self._name)

  def _get_iter_variable(self):
    tf.get_variable_scope().reuse_variables()
    return tf.get_variable(name="iter", shape=[], dtype=tf.int32)

  def _prepare(self):
    super(MultistepAdamOptimizer, self)._prepare()
    self._n_t = tf.convert_to_tensor(self._n, name="n")

  def _apply_cond(self, apply_fn, grad, var, *args, **kwargs):
    """Applies the Adam update if the counter is zero, else accumulates."""
    grad_acc = self.get_slot(var, "grad_acc")

    def apply_adam(grad_acc, apply_fn, grad, var, *args, **kwargs):
      # Average the accumulated gradients, run the real Adam update, then
      # reset the accumulator.
      total_grad = (grad_acc + grad) / tf.cast(self._n_t, grad.dtype)
      adam_op = apply_fn(total_grad, var, *args, **kwargs)
      with tf.control_dependencies([adam_op]):
        grad_acc_to_zero_op = grad_acc.assign(tf.zeros_like(grad_acc),
                                              use_locking=self._use_locking)
      return tf.group(adam_op, grad_acc_to_zero_op)

    def accumulate_gradient(grad_acc, grad):
      assign_op = tf.assign_add(grad_acc, grad, use_locking=self._use_locking)
      return tf.group(assign_op)  # Strip return value

    return tf.cond(
        tf.equal(self._get_iter_variable(), 0),
        lambda: apply_adam(grad_acc, apply_fn, grad, var, *args, **kwargs),
        lambda: accumulate_gradient(grad_acc, grad))

  def _apply_dense(self, grad, var):
    return self._apply_cond(
        super(MultistepAdamOptimizer, self)._apply_dense, grad, var)

  def _resource_apply_dense(self, grad, var):
    return self._apply_cond(
        super(MultistepAdamOptimizer, self)._resource_apply_dense, grad, var)

  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    return self._apply_cond(
        super(MultistepAdamOptimizer, self)._apply_sparse_shared, grad, var,
        indices, scatter_add)

  def _apply_sparse(self, grad, var):
    # TODO(fstahlberg): Implement a sparse version
    tf.logging.warning("MultistepAdamOptimizer does not support sparse updates")
    dense_grad = tf.convert_to_tensor(grad)
    return self._apply_cond(
        super(MultistepAdamOptimizer, self)._apply_dense, dense_grad, var)

  def _finish(self, update_ops, name_scope):
    """Updates beta_power variables every n batches and increments counter."""
    iter_ = self._get_iter_variable()
    # Print the accumulation counter whenever n_t is evaluated (debugging aid).
    self._n_t = tf.Print(self._n_t, [iter_], summarize=10000, message="iter")
    beta1_power, beta2_power = self._get_beta_accumulators()
    with tf.control_dependencies(update_ops):
      with tf.colocate_with(iter_):

        def update_beta_op():
          update_beta1 = beta1_power.assign(
              beta1_power * self._beta1_t,
              use_locking=self._use_locking)
          update_beta2 = beta2_power.assign(
              beta2_power * self._beta2_t,
              use_locking=self._use_locking)
          return tf.group(update_beta1, update_beta2)

        maybe_update_beta = tf.cond(
            tf.equal(iter_, 0), update_beta_op, tf.no_op)
        with tf.control_dependencies([maybe_update_beta]):
          update_iter = iter_.assign(tf.mod(iter_ + 1, self._n_t),
                                     use_locking=self._use_locking)
    return tf.group(
        *update_ops + [update_iter, maybe_update_beta], name=name_scope)
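
# Example usage: a minimal sketch assuming TF 1.x graph mode and the class
# defined above. With n=4, the gradients of four consecutive minibatches are
# accumulated and averaged before a single Adam update, emulating a 4x larger
# batch. The toy regression model and placeholder names below are assumptions
# for illustration only.
if __name__ == "__main__":
  x = tf.placeholder(tf.float32, shape=[None, 10])
  y = tf.placeholder(tf.float32, shape=[None, 1])
  w = tf.get_variable("w", shape=[10, 1])
  loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

  optimizer = MultistepAdamOptimizer(learning_rate=0.001, n=4)
  train_op = optimizer.minimize(loss)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Each run of train_op either accumulates the current gradient or, on
    # every n-th call, applies the averaged Adam update and resets the
    # accumulator; feed real batches in place of batch_x / batch_y.
    # sess.run(train_op, feed_dict={x: batch_x, y: batch_y})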