Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@merishnaSuwal
Last active October 20, 2020 10:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save merishnaSuwal/90bcb302d9179a862d24f64eb903c4bd to your computer and use it in GitHub Desktop.
Save merishnaSuwal/90bcb302d9179a862d24f64eb903c4bd to your computer and use it in GitHub Desktop.
def gan_model_loss(input_real, input_z, output_channel_dim, alpha):
    """
    Get the loss for the discriminator and generator.

    Arguments:
    ---------
    :param input_real: Images from the real dataset
    :param input_z: Z input (the noise tensor fed to the generator)
    :param output_channel_dim: The number of channels in the output image
    :param alpha: Forwarded unchanged to ``build_discriminator``;
        presumably the leaky-ReLU slope there — confirm against its definition.
    ---------
    Returns:
    A tuple of (discriminator loss, generator loss)
    """
    # Build the generator network
    g_model_output = build_generator(input_z, output_channel_dim)

    # Build the discriminator network twice: once on real inputs and once on
    # the generator's output. The second call passes is_reuse=True,
    # presumably so both calls share the same discriminator weights —
    # confirm against build_discriminator.
    # Only the logits (second return value) are needed for the losses.
    _, real_d_logits = build_discriminator(input_real, alpha=alpha)
    _, fake_d_logits = build_discriminator(
        g_model_output, is_reuse=True, alpha=alpha
    )

    # sigmoid_cross_entropy_with_logits requires labels with the same shape
    # and dtype as the logits, so the target tensors are derived from the
    # logits themselves rather than from the discriminator's other output.

    # Discriminator: real images should be classified as 1 ...
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=real_d_logits, labels=tf.ones_like(real_d_logits)
        )
    )
    # ... and generated images as 0.
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=fake_d_logits, labels=tf.zeros_like(fake_d_logits)
        )
    )
    # Discriminator loss is the sum of the real and fake losses
    d_loss = d_loss_real + d_loss_fake

    # Generator loss: the generator is rewarded when the discriminator
    # classifies its output as real (target 1), i.e. it minimises
    # -log(D(G(z))) — the non-saturating GAN generator loss.
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=fake_d_logits, labels=tf.ones_like(fake_d_logits)
        )
    )
    return d_loss, g_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment