Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
def gan_model_optimizers(d_loss, g_loss, disc_lr, gen_lr, beta1):
    """
    Get optimization operations

    Builds one Adam training operation per network. Each operation only
    updates the trainable variables whose name starts with that network's
    prefix ("discriminator" or "generator").

    Ops found in the tf.GraphKeys.UPDATE_OPS collection (for example the
    moving-statistics updates that batch normalization layers register
    there) are attached as control dependencies, so they run whenever the
    corresponding training operation runs:
      * the generator step triggers the "generator*" update ops;
      * the discriminator step triggers the "generator*" update ops and
        the "discriminator*" update ops.

    Arguments:
    ----------
    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param disc_lr: Placeholder for Learning Rate for discriminator
    :param gen_lr: Placeholder for Learning Rate for generator
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    ----------
    Returns:
    A tuple of (discriminator training operation, generator training operation)
    """
    # Get the trainable_variables, split into G and D parts by name prefix
    train_vars = tf.trainable_variables()
    gen_vars = [var for var in train_vars if var.name.startswith("generator")]
    disc_vars = [var for var in train_vars if var.name.startswith("discriminator")]

    # Split the pending update ops the same way. Previously only the
    # generator's ops were kept, so any update op registered under the
    # discriminator prefix was silently dropped and never executed.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    gen_updates = [op for op in update_ops if op.name.startswith("generator")]
    disc_updates = [op for op in update_ops if op.name.startswith("discriminator")]

    # Discriminator step: keeps its original dependency on the generator
    # updates and additionally runs the discriminator's own update ops.
    with tf.control_dependencies(gen_updates + disc_updates):
        disc_train_opt = tf.train.AdamOptimizer(
            learning_rate=disc_lr, beta1=beta1
        ).minimize(d_loss, var_list=disc_vars)

    # Generator step: runs the generator's update ops before applying gradients.
    with tf.control_dependencies(gen_updates):
        gen_train_opt = tf.train.AdamOptimizer(
            learning_rate=gen_lr, beta1=beta1
        ).minimize(g_loss, var_list=gen_vars)

    return disc_train_opt, gen_train_opt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment