import tensorflow as tf  # TensorFlow 1.x API (tf.train, tf.trainable_variables, tf.GraphKeys)


def gan_model_optimizers(d_loss, g_loss, disc_lr, gen_lr, beta1):
    """
    Get the optimization operations for the GAN.

    Arguments:
    ----------
    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param disc_lr: Placeholder for the discriminator learning rate
    :param gen_lr: Placeholder for the generator learning rate
    :param beta1: Exponential decay rate for the 1st moment in the Adam optimizer
    ----------
    Returns:
    A tuple of (discriminator training operation, generator training operation)
    """
    # Get the trainable variables and split them into generator and discriminator parts
    train_vars = tf.trainable_variables()
    gen_vars = [var for var in train_vars if var.name.startswith("generator")]
    disc_vars = [var for var in train_vars if var.name.startswith("discriminator")]

    # Collect the generator's update ops (e.g. batch-norm moving-average updates)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    gen_updates = [op for op in update_ops if op.name.startswith('generator')]

    # Run the generator's update ops before the training ops, then build one Adam
    # optimizer per network, each minimizing its own loss over its own variables
    with tf.control_dependencies(gen_updates):
        disc_train_opt = tf.train.AdamOptimizer(learning_rate=disc_lr, beta1=beta1).minimize(d_loss, var_list=disc_vars)
        gen_train_opt = tf.train.AdamOptimizer(learning_rate=gen_lr, beta1=beta1).minimize(g_loss, var_list=gen_vars)

    return disc_train_opt, gen_train_opt
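
A minimal usage sketch, assuming a surrounding TensorFlow 1.x training script: the names input_real, input_z, d_loss, g_loss, lr_d, lr_g, and get_batches below are hypothetical placeholders for whatever the rest of the model code defines; they are not part of this gist.

# Hypothetical usage sketch -- input_real, input_z, d_loss, g_loss, lr_d, lr_g,
# and get_batches are assumed to be defined elsewhere in the training script.
d_opt, g_opt = gan_model_optimizers(d_loss, g_loss, lr_d, lr_g, beta1=0.5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for batch_images, batch_z in get_batches():
        # Alternate one discriminator step and one generator step per minibatch,
        # feeding the learning-rate placeholders returned by the model setup
        sess.run(d_opt, feed_dict={input_real: batch_images, input_z: batch_z, lr_d: 0.0002})
        sess.run(g_opt, feed_dict={input_real: batch_images, input_z: batch_z, lr_g: 0.0002})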