Skip to content

Instantly share code, notes, and snippets.

@merishnaSuwal
Last active October 20, 2020 10:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save merishnaSuwal/60773068fd9a1b8785c985c6148b21d2 to your computer and use it in GitHub Desktop.
Save merishnaSuwal/60773068fd9a1b8785c985c6148b21d2 to your computer and use it in GitHub Desktop.
def gan_model_optimizers(d_loss, g_loss, disc_lr, gen_lr, beta1):
    """
    Build the Adam training operations for the discriminator and the generator.

    Each optimizer only updates the trainable variables of its own network,
    selected by name prefix ("discriminator" / "generator"). Each training op
    also depends on the ops registered in ``tf.GraphKeys.UPDATE_OPS`` whose
    names start with that same network's prefix. This ensures those auxiliary
    updates run whenever the corresponding network is trained. Such updates
    are, e.g., batch-normalization moving statistics, if the networks use
    them.

    Arguments:
    ----------
    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param disc_lr: Placeholder for Learning Rate for discriminator
    :param gen_lr: Placeholder for Learning Rate for generator
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    ----------
    Returns:
    A tuple of (discriminator training operation, generator training operation)
    """
    # Get the trainable variables and split them into G and D parts by scope prefix.
    # NOTE(review): assumes the networks were built under variable scopes named
    # "generator" / "discriminator" -- confirm against the model-building code.
    train_vars = tf.trainable_variables()
    gen_vars = [var for var in train_vars if var.name.startswith("generator")]
    disc_vars = [var for var in train_vars if var.name.startswith("discriminator")]

    # minimize() does not run the ops in UPDATE_OPS by itself. Split them per
    # network so each train op triggers its own network's updates. The previous
    # version attached only the generator's update ops, so discriminator update
    # ops were collected but never executed.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    gen_updates = [op for op in update_ops if op.name.startswith("generator")]
    disc_updates = [op for op in update_ops if op.name.startswith("discriminator")]

    # Discriminator optimizer: runs the discriminator's update ops first.
    with tf.control_dependencies(disc_updates):
        disc_train_opt = tf.train.AdamOptimizer(
            learning_rate=disc_lr, beta1=beta1
        ).minimize(d_loss, var_list=disc_vars)

    # Generator optimizer: runs the generator's update ops first.
    with tf.control_dependencies(gen_updates):
        gen_train_opt = tf.train.AdamOptimizer(
            learning_rate=gen_lr, beta1=beta1
        ).minimize(g_loss, var_list=gen_vars)

    return disc_train_opt, gen_train_opt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment