Skip to content

Instantly share code, notes, and snippets.

@ethanyanjiali
Created June 6, 2019 06:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ethanyanjiali/0638fe05afc5e3902f659a9a2ce3e803 to your computer and use it in GitHub Desktop.
Save ethanyanjiali/0638fe05afc5e3902f659a9a2ce3e803 to your computer and use it in GitHub Desktop.
cyclegan_discriminator_loss
@tf.function
def train_discriminator(images_a, images_b, fake_a2b, fake_b2a):
    """Run one optimization step for both CycleGAN discriminators.

    Each discriminator is evaluated on a real batch from its own domain and
    on a generated batch translated into that domain. The real/fake losses
    from ``calc_gan_loss`` are averaged per discriminator, the two averages
    are summed, and the gradients of that sum are applied through
    ``optimizer_dis`` to the trainable variables of both discriminators.

    Args:
        images_a: Real samples from domain A, fed to ``discriminator_a``.
        images_b: Real samples from domain B, fed to ``discriminator_b``.
        fake_a2b: Generated samples (A translated to B), fed to
            ``discriminator_b`` as the negative examples.
        fake_b2a: Generated samples (B translated to A), fed to
            ``discriminator_a`` as the negative examples.

    Returns:
        None. The function only updates the discriminator variables.
    """
    # NOTE(review): presumably the fake batches come from a separate generator
    # step and need no gradient back to the generators here -- confirm at the
    # call site.

    def _real_fake_loss(discriminator, real_batch, fake_batch):
        # The discriminator should accept the real batch (target True) ...
        on_real = calc_gan_loss(discriminator(real_batch, training=True), True)
        # ... and reject the generated batch (target False).
        on_fake = calc_gan_loss(discriminator(fake_batch, training=True), False)
        # Average the two terms so each discriminator contributes one loss.
        return (on_real + on_fake) * 0.5

    with tf.GradientTape() as tape:
        # Discriminator A: real A vs. fake produced by the B->A generator.
        loss_dis_a = _real_fake_loss(discriminator_a, images_a, fake_b2a)
        # Discriminator B: real B vs. fake produced by the A->B generator.
        loss_dis_b = _real_fake_loss(discriminator_b, images_b, fake_a2b)
        # A single scalar lets one tape/optimizer pass cover both networks.
        loss_dis_total = loss_dis_a + loss_dis_b

    # Concatenate both variable lists so they are updated in the same step.
    dis_variables = (
        discriminator_a.trainable_variables
        + discriminator_b.trainable_variables
    )
    dis_gradients = tape.gradient(loss_dis_total, dis_variables)
    optimizer_dis.apply_gradients(zip(dis_gradients, dis_variables))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment