@vivekpadia70
Last active March 30, 2020 10:25
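import tensorflow as tf

# create_generator(), create_discriminator(), G_loss, D_loss, BATCH_SIZE,
# noise_dim, generator_optimizer and discriminator_optimizer are not defined
# in this gist; they must be provided elsewhere.
generator = create_generator()
discriminator = create_discriminator()

@tf.function
def train_step(images):
    # Sample a batch of latent vectors for the generator.
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    # Record both forward passes so each network gets its own gradients.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        # Discriminator outputs for the real batch and the generated batch.
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = G_loss(fake_output)
        disc_loss = D_loss(real_output, fake_output)

    # Gradients of each loss with respect to its own network's weights.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Update both networks once per training step.
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))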
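The gist calls `G_loss` and `D_loss` without defining them. Below is a minimal sketch, assuming the standard non-saturating GAN losses computed as sigmoid cross-entropy on raw discriminator logits (the setup used in TensorFlow's DCGAN tutorial, which this `train_step` closely resembles). It is written in plain Python rather than TensorFlow so the arithmetic is visible; the names mirror the gist, but these bodies and the helper `bce_with_logits` are hypothetical, not the author's definitions.

```python
import math
from typing import Sequence


def bce_with_logits(logit: float, label: float) -> float:
    # Numerically stable sigmoid cross-entropy for one logit x and label z:
    # max(x, 0) - x * z + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean(values: Sequence[float]) -> float:
    # Assumes a non-empty batch.
    return sum(values) / len(values)


def D_loss(real_output: Sequence[float], fake_output: Sequence[float]) -> float:
    # Hypothetical: the discriminator should label real images 1 and fakes 0.
    real_loss = mean([bce_with_logits(x, 1.0) for x in real_output])
    fake_loss = mean([bce_with_logits(x, 0.0) for x in fake_output])
    return real_loss + fake_loss


def G_loss(fake_output: Sequence[float]) -> float:
    # Hypothetical non-saturating loss: the generator wants its fakes labeled 1.
    return mean([bce_with_logits(x, 1.0) for x in fake_output])
```

With every logit at 0 (a discriminator outputting probability 0.5), `G_loss` returns log 2 ≈ 0.693 and `D_loss` returns 2 log 2. In TensorFlow the same losses are usually expressed with `tf.keras.losses.BinaryCrossentropy(from_logits=True)` applied to `tf.ones_like` / `tf.zeros_like` targets.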