Last active
October 20, 2020 10:06
-
-
Save merishnaSuwal/90bcb302d9179a862d24f64eb903c4bd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def gan_model_loss(input_real, input_z, output_channel_dim, alpha):
    """
    Get the loss for the discriminator and generator.

    Arguments:
    ---------
    :param input_real: Images from the real dataset
    :param input_z: Z input passed to ``build_generator``
    :param output_channel_dim: The number of channels in the output image
        (passed to ``build_generator``)
    :param alpha: Forwarded as ``alpha`` to both ``build_discriminator`` calls
        (presumably the leaky-ReLU slope -- confirm against its definition)
    ---------
    Returns:
    A tuple of (discriminator loss, generator loss)
    """

    def _mean_bce(logits, labels):
        # Mean sigmoid cross-entropy computed directly from logits.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

    # Build the generator network from the z input.
    g_model_output = build_generator(input_z, output_channel_dim)

    # Build the discriminator for real inputs, then again for the generated
    # images; is_reuse=True on the second call presumably shares the same
    # discriminator variables (confirm in build_discriminator).
    _, real_d_logits = build_discriminator(input_real, alpha=alpha)
    _, fake_d_logits = build_discriminator(g_model_output, is_reuse=True, alpha=alpha)

    # Labels are derived from the logits themselves:
    # sigmoid_cross_entropy_with_logits requires labels with the same shape
    # and dtype as the logits.
    d_loss_real = _mean_bce(real_d_logits, tf.ones_like(real_d_logits))
    d_loss_fake = _mean_bce(fake_d_logits, tf.zeros_like(fake_d_logits))

    # Discriminator loss is the sum of the real and fake losses.
    d_loss = d_loss_real + d_loss_fake

    # Generator loss: the generator wants fakes classified as real (label 1).
    g_loss = _mean_bce(fake_d_logits, tf.ones_like(fake_d_logits))

    return d_loss, g_loss
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment