Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
def train_gan_model(epoch, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1, get_batches, data_shape, data_image_mode, alpha, from_checkpoint=None):
    """
    Train the GAN model, or restore it from a checkpoint and save one generated sample.

    Arguments:
    ----------
    :param epoch: Number of epochs to train for
    :param batch_size: Batch size passed to ``get_batches``; also the number of
        noise vectors drawn per training step
    :param z_dim: Dimension of the generator's input noise vector z
    :param learning_rate_D: Learning rate fed to the discriminator optimizer
    :param learning_rate_G: Learning rate fed to the generator optimizer
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :param get_batches: Function called as ``get_batches(batch_size)`` that yields
        batches of real images
    :param data_shape: Shape of the data. ``data_shape[1:]`` is used as the
        per-image shape and ``data_shape[3]`` as the channel count (assumes
        (N, H, W, C) -- confirm)
    :param data_image_mode: The image mode to use for images ("RGB" or "L")
    :param alpha: Passed through to ``gan_model_loss`` (presumably the leaky-ReLU
        slope -- confirm)
    :param from_checkpoint: If truthy, restore "./models/model.ckpt" and write one
        generator sample instead of training. ``None`` (the default) falls back
        to a module-level ``from_checkpoint`` global if one exists, else False.
    ----------
    """
    if from_checkpoint is None:
        # Earlier versions read this flag from a module-level global; keep honoring it.
        from_checkpoint = globals().get("from_checkpoint", False)

    checkpoint_path = "./models/model.ckpt"
    img_save_path = "./generated_images/"

    # Create our input placeholders
    input_images, input_z, lr_G, lr_D = gan_model_inputs(data_shape[1:], z_dim)
    # Discriminator and generator losses
    d_loss, g_loss = gan_model_loss(input_images, input_z, data_shape[3], alpha)
    # Optimizers
    d_opt, g_opt = gan_model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1)

    # Create output folders once, up front. tf.train.Saver.save raises if the
    # checkpoint's parent directory does not exist.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    os.makedirs(img_save_path, exist_ok=True)

    step = 0  # global batch counter across all epochs
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print("Starting the model training...")

        if from_checkpoint:
            # Restore trained weights and save a single generator sample; no training.
            saver.restore(sess, checkpoint_path)
            image_path = "generated_images/generated_fromckpt.PNG"
            generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path)
            return

        for epoch_i in range(epoch):
            print("Training model for epoch_", epoch_i)
            for batch_images in get_batches(batch_size):
                # Uniform random noise in [-1, 1) as generator input
                batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))
                step += 1
                # One discriminator update, then one generator update, on the same batch and noise
                sess.run(d_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_D: learning_rate_D})
                sess.run(g_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_G: learning_rate_G})

                # Every 5 batches (not epochs): report losses and save a generator sample
                if step % 5 == 0:
                    train_loss_d = d_loss.eval({input_z: batch_z, input_images: batch_images})
                    train_loss_g = g_loss.eval({input_z: batch_z})
                    print("Epoch {}/{}...".format(epoch_i + 1, epoch),
                          "Discriminator Loss: {:.4f}...".format(train_loss_d),
                          "Generator Loss: {:.4f}".format(train_loss_g))
                    image_name = str(step) + "_epoch_" + str(epoch_i) + ".jpg"
                    image_path = os.path.join(img_save_path, image_name)
                    generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path)

            # Checkpoint AFTER every 5th completed epoch, and after the final epoch,
            # so the saved model always reflects finished training.
            if (epoch_i + 1) % 5 == 0 or epoch_i + 1 == epoch:
                saver.save(sess, checkpoint_path)
                print("Model has been saved.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment