Skip to content

Instantly share code, notes, and snippets.

@ethanyanjiali
Created June 6, 2019 06:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ethanyanjiali/856c6b71f1d79bcb98dca377530e1315 to your computer and use it in GitHub Desktop.
Save ethanyanjiali/856c6b71f1d79bcb98dca377530e1315 to your computer and use it in GitHub Desktop.
cyclegan_train_generator
@tf.function
def train_generator(images_a, images_b, cycle_weight=10, identity_weight=5):
    """Run one optimization step for both CycleGAN generators.

    Every generator output is produced inside a single ``tf.GradientTape`` so
    that the identity, adversarial and cycle-consistency losses all
    back-propagate into ``generator_a2b`` and ``generator_b2a``. Only the two
    generators' trainable variables are handed to ``optimizer_gen``; the
    discriminators are evaluated but their weights are not passed to the
    optimizer in this step.

    Total loss::

        gan_a2b + gan_b2a
        + cycle_weight * (cycle_a2b2a + cycle_b2a2b)
        + identity_weight * (identity_a2b + identity_b2a)

    Args:
        images_a: Batch of real domain-A images (assumed to already be in the
            format the generators accept -- TODO confirm against the input
            pipeline).
        images_b: Batch of real domain-B images (same assumption).
        cycle_weight: Multiplier applied to the summed cycle-consistency
            losses. Defaults to 10.
        identity_weight: Multiplier applied to the summed identity losses.
            Defaults to 5.

    Note:
        ``cycle_weight`` and ``identity_weight`` are Python scalars, so each
        distinct value passed causes ``tf.function`` to retrace.
    """
    real_a = images_a
    real_b = images_b

    with tf.GradientTape() as tape:
        # Identity loss: a generator fed an image that is already in its
        # target domain should return it unchanged (A->B on real B,
        # B->A on real A).
        identity_a2b = generator_a2b(real_b, training=True)
        identity_b2a = generator_b2a(real_a, training=True)
        loss_identity_a2b = calc_identity_loss(identity_a2b, real_b)
        loss_identity_b2a = calc_identity_loss(identity_b2a, real_a)

        # Translate each real batch into the other domain. These must be
        # created inside the tape so the adversarial and cycle losses below
        # produce gradients for the generators.
        fake_a2b = generator_a2b(real_a, training=True)
        fake_b2a = generator_b2a(real_b, training=True)

        # Generator A2B tries to trick Discriminator B that the generated
        # image is B; generator B2A tries to trick Discriminator A that the
        # generated image is A. (Presumably the ``True`` argument selects the
        # "real" target label in calc_gan_loss -- confirm.)
        loss_gan_gen_a2b = calc_gan_loss(discriminator_b(fake_a2b, training=True), True)
        loss_gan_gen_b2a = calc_gan_loss(discriminator_a(fake_b2a, training=True), True)

        # Cycle consistency: A -> B -> A should reconstruct A, and
        # B -> A -> B should reconstruct B.
        recon_b2a = generator_b2a(fake_a2b, training=True)
        recon_a2b = generator_a2b(fake_b2a, training=True)
        loss_cycle_a2b2a = calc_cycle_loss(recon_b2a, real_a)
        loss_cycle_b2a2b = calc_cycle_loss(recon_a2b, real_b)

        # Total generator loss.
        loss_gen_total = (
            loss_gan_gen_a2b + loss_gan_gen_b2a
            + (loss_cycle_a2b2a + loss_cycle_b2a2b) * cycle_weight
            + (loss_identity_a2b + loss_identity_b2a) * identity_weight
        )

    # Gradients are taken outside the tape context so the gradient
    # computation itself is not recorded. Only generator weights are updated.
    trainable_variables = generator_a2b.trainable_variables + generator_b2a.trainable_variables
    gradient_gen = tape.gradient(loss_gen_total, trainable_variables)
    optimizer_gen.apply_gradients(zip(gradient_gen, trainable_variables))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment