@eladshabi
Last active February 25, 2019 12:39
Mixed precision loss scaling
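Loss scaling keeps small fp16 gradients from underflowing to zero during mixed precision training. The recipe has three steps: (1) multiply the loss by a constant scale factor, (2) backpropagate as usual, and (3) divide the gradients by the same factor before applying them. Below is a minimal sketch of those steps done by hand; the toy variable and loss are illustrative only, not part of the gist:

# Minimal hand-rolled loss scaling (illustrative sketch, TF 1.x graph mode).
import tensorflow as tf

w = tf.Variable([1.0, 2.0])
loss = tf.reduce_sum(tf.square(w))                    # toy loss, assumption
opt = tf.train.AdamOptimizer(1e-4)
scale = 128.0

scaled_loss = loss * scale                            # Step 1: scale the loss up
grads_vars = opt.compute_gradients(scaled_loss, [w])  # Step 2: backprop as usual
unscaled = [(g / scale if g is not None else None, v)
            for g, v in grads_vars]                   # Step 3: scale gradients back down
train_op = opt.apply_gradients(unscaled)

The gist itself delegates steps 2 and 3 to tf.contrib's LossScaleOptimizer wrapper, applied to the three optimizers of the ACGAN: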
# source: https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/ACGAN.py
import tensorflow as tf
from tensorflow.contrib.mixed_precision import FixedLossScaleManager, LossScaleOptimizer

# Partition the trainable variables by network, using the ACGAN name prefixes:
# discriminator ('d_'), generator ('g_'), and classifier ('c_').
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'd_' in var.name]
g_vars = [var for var in t_vars if 'g_' in var.name]
q_vars = [var for var in t_vars if ('d_' in var.name) or ('c_' in var.name) or ('g_' in var.name)]
# Build the training ops after any pending update ops (e.g. batch-norm moving averages).
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1)
    self.g_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1)
    self.q_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1)

    # Step 1: wrap each optimizer with a fixed loss-scale manager.
    scale = 128
    self.loss_scale_manager_D = FixedLossScaleManager(scale)
    self.loss_scale_manager_G = FixedLossScaleManager(scale)
    self.loss_scale_manager_Q = FixedLossScaleManager(scale)
    self.loss_scale_optimizer_D = LossScaleOptimizer(self.d_optim, self.loss_scale_manager_D)
    self.loss_scale_optimizer_G = LossScaleOptimizer(self.g_optim, self.loss_scale_manager_G)
    self.loss_scale_optimizer_Q = LossScaleOptimizer(self.q_optim, self.loss_scale_manager_Q)

    # Steps 2 + 3 (handled by TensorFlow): compute_gradients() multiplies the loss
    # by the scale before backprop and divides the gradients back down afterwards.
    self.grads_variables_D = self.loss_scale_optimizer_D.compute_gradients(self.d_loss, d_vars)
    self.grads_variables_G = self.loss_scale_optimizer_G.compute_gradients(self.g_loss, g_vars)
    self.grads_variables_Q = self.loss_scale_optimizer_Q.compute_gradients(self.q_loss, q_vars)

    # Gradient processing: drop variables that received no gradient from the Q loss.
    self.q_grads = [(g, v) for (g, v) in self.grads_variables_Q if g is not None]
    self.training_step_op_D = self.loss_scale_optimizer_D.apply_gradients(self.grads_variables_D)
    self.training_step_op_G = self.loss_scale_optimizer_G.apply_gradients(self.grads_variables_G)
    self.training_step_op_Q = self.loss_scale_optimizer_Q.apply_gradients(self.q_grads)
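A fixed scale of 128 works here, but if the chosen constant ever causes gradient overflow, the same tf.contrib module also offers dynamic loss scaling. A sketch of that alternative (not what the gist uses): the manager doubles the scale after a run of overflow-free steps and shrinks it when gradients hit Inf/NaN, so no hand tuning of `scale` is needed.

# Dynamic loss scaling sketch (alternative to FixedLossScaleManager above).
import tensorflow as tf
from tensorflow.contrib.mixed_precision import ExponentialUpdateLossScaleManager, LossScaleOptimizer

manager = ExponentialUpdateLossScaleManager(init_loss_scale=2 ** 15,
                                            incr_every_n_steps=1000)
optimizer = LossScaleOptimizer(tf.train.AdamOptimizer(1e-4), manager)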