Skip to content

Instantly share code, notes, and snippets.

@tims457
Created June 29, 2021 01:49
Show Gist options
  • Save tims457/33fe0286d923bc1915c24faead9bd742 to your computer and use it in GitHub Desktop.
Basic GAN training loop in TensorFlow
# Basic GAN training loop (TensorFlow 2, eager mode).
#
# For every minibatch of `train_data`, runs `kd` discriminator updates and
# then `kg` generator updates. Names used here but defined elsewhere (not
# visible in this snippet): epochs, train_data, batch_size, image_size,
# kd, kg, generator, discriminator, disc_optimizer, gen_optimizer.
#
# NOTE(review): the generator output is rescaled with (x + 1) * 127.5, which
# maps [-1, 1] onto [0, 255]. This presumably assumes a tanh-style generator
# output and real images on a [0, 255] scale -- confirm against the model and
# dataset definitions.
#
# NOTE(review): BinaryCrossentropy() defaults to from_logits=False, so the
# discriminator is expected to output probabilities (e.g. a sigmoid). If it
# outputs raw logits, construct the loss with from_logits=True instead.

# One loss object for the whole run instead of constructing one per step.
bce = tf.keras.losses.BinaryCrossentropy()

for epoch in range(epochs):
    for batch, (real_x, _) in enumerate(train_data.batch(batch_size)):
        # Convert once per minibatch; the same real batch is reused by all
        # kd discriminator steps below.
        real_x = tf.convert_to_tensor(real_x)
        # Dataset.batch() keeps a smaller final batch by default
        # (drop_remainder=False). Size the fake batch to the real one so the
        # discriminator sees equal numbers of real and fake samples.
        n_real = tf.shape(real_x)[0]

        # ---- train the discriminator ----
        for disc_steps in range(kd):
            # Image-shaped noise fed to the generator. This forward pass is
            # kept outside the tape: only discriminator gradients are taken
            # here, so recording the generator's ops would only waste memory.
            noise = tf.random.normal((n_real, image_size, image_size))
            # training=True so layers such as BatchNorm/Dropout behave in
            # training mode (Keras otherwise defaults to inference mode).
            fake_x = generator(noise, training=True)
            with tf.GradientTape() as disc_tape:
                # Rescale generator output before scoring (see NOTE above).
                disc_pred_fake = discriminator(
                    (fake_x + 1) * 127.5, training=True)
                disc_pred_real = discriminator(real_x, training=True)
                # Real images are labelled 1, generated images 0.
                real_loss = bce(tf.ones_like(disc_pred_real), disc_pred_real)
                fake_loss = bce(tf.zeros_like(disc_pred_fake), disc_pred_fake)
                disc_loss = 0.5 * (real_loss + fake_loss)
            # update the discriminator
            disc_grads = disc_tape.gradient(
                disc_loss, discriminator.trainable_weights)
            disc_optimizer.apply_gradients(
                zip(disc_grads, discriminator.trainable_weights))

        # ---- train the generator ----
        for gen_steps in range(kg):
            with tf.GradientTape() as gen_tape:
                noise = tf.random.normal((batch_size, image_size, image_size))
                fake_x = generator(noise, training=True)
                disc_pred_fake = discriminator(
                    (fake_x + 1) * 127.5, training=True)
                # Non-saturating generator loss: reward the generator when
                # the discriminator labels its images as real (1).
                gen_loss = bce(tf.ones_like(disc_pred_fake), disc_pred_fake)
            gen_grads = gen_tape.gradient(
                gen_loss, generator.trainable_weights)
            gen_optimizer.apply_gradients(
                zip(gen_grads, generator.trainable_weights))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment