Skip to content

Instantly share code, notes, and snippets.

@obeshor
Last active November 13, 2020 13:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save obeshor/737f16abaadc2bec91292d2a64867d66 to your computer and use it in GitHub Desktop.
Save obeshor/737f16abaadc2bec91292d2a64867d66 to your computer and use it in GitHub Desktop.
@tf.function
def train_step(images, generator, discriminator, generator_optimizer, discriminator_optimizer):
    """Run one adversarial update of both the generator and the discriminator.

    Draws one noise vector per real image and generates fake images from them.
    Scores the real and fake images with the discriminator, then applies one
    optimizer step to each network using gradients from its own tape.

    Args:
        images: Batch of real training images; the leading dimension is
            treated as the batch size.
        generator: Model called as ``generator(noise, training=True)``.
        discriminator: Model called as ``discriminator(x, training=True)``.
        generator_optimizer: Optimizer applied to
            ``generator.trainable_variables``.
        discriminator_optimizer: Optimizer applied to
            ``discriminator.trainable_variables``.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` with the losses computed for this step.
    """
    # Size the noise batch from the real batch actually passed in. The global
    # BATCH_SIZE ignored train()'s ``batch_size`` argument and any smaller final
    # batch of an epoch, so real and fake batches could differ in size.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, NOISE_DIM])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        # generator_loss / discriminator_loss are defined elsewhere in this file.
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Each network is updated only from its own loss, w.r.t. its own weights.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss
def train(dataset, epochs, batch_size, print_every=10, show_every=100, figsize=(5,5)):
    """Train the GAN, periodically reporting losses and showing generated samples.

    Relies on the module-level ``generator``, ``discriminator``,
    ``generator_optimizer``, ``discriminator_optimizer``, ``train_step``,
    ``view_samples`` and ``z_size``.

    Args:
        dataset: Object whose ``batches(batch_size)`` yields
            ``(images, labels)`` pairs; the labels are ignored.
        epochs: Number of passes over ``dataset``.
        batch_size: Batch size passed to ``dataset.batches``.
        print_every: Every ``print_every`` global steps, print the mean losses
            of the current epoch so far and record them in ``losses``.
        show_every: Every ``show_every`` global steps, generate images from a
            fixed set of latent vectors and display them with ``view_samples``.
        figsize: Figure size forwarded to ``view_samples``.

    Returns:
        ``(losses, samples)``:
        - ``losses`` is a list of ``(discriminator_loss, generator_loss)``
          float pairs, one per report.
        - ``samples`` is a list of generator outputs, one per display.
    """
    # Fixed latent vectors, so successive sample grids are comparable. They are
    # drawn from a standard normal to match the noise train_step feeds the
    # generator (tf.random.normal). The previous uniform(-1, 1) draw evaluated
    # the generator on a different input distribution than it was trained on.
    # 72 presumably fills the 6 x 12 grid passed to view_samples below.
    sample_z = np.random.normal(size=(72, z_size)).astype(np.float32)
    samples, losses = [], []
    steps = 0
    for epoch in range(epochs):
        start = time.time()
        gen_loss_list = []
        disc_loss_list = []
        for image_batch, _ in dataset.batches(batch_size):
            steps += 1
            gen_loss, disc_loss = train_step(image_batch, generator, discriminator,
                                             generator_optimizer, discriminator_optimizer)
            gen_loss_list.append(gen_loss)
            disc_loss_list.append(disc_loss)

            if steps % print_every == 0:
                # Running mean over the batches seen so far in this epoch (not
                # an end-of-epoch figure). Converted to Python floats so the
                # saved history holds plain numbers rather than tensors.
                train_loss_g = float(sum(gen_loss_list) / len(gen_loss_list))
                train_loss_d = float(sum(disc_loss_list) / len(disc_loss_list))
                print("Epoch {}/{}...".format(epoch+1, epochs),
                      "Discriminator Loss: {:.4f}...".format(train_loss_d),
                      "Generator Loss: {:.4f}".format(train_loss_g),
                      "Time: {}".format(time.time()-start))
                # Save losses to view after training
                losses.append((train_loss_d, train_loss_g))

            if steps % show_every == 0:
                # Model.predict already runs the model in inference mode, so no
                # flag needs to be toggled on the generator.
                gen_samples = generator.predict(sample_z)
                samples.append(gen_samples)
                _ = view_samples(-1, samples, 6, 12, figsize=figsize)
                plt.show()
    return losses, samples
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment