Last active
October 20, 2020 10:03
-
-
Save merishnaSuwal/20ddfce5438a3d005a66f235c6c239e9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_gan_model(epoch, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1, get_batches, data_shape, data_image_mode, alpha):
    """
    Train the GAN model, or sample from a previously saved checkpoint.

    The mode is selected by ``from_checkpoint``, which is NOT a parameter:
    it is read from the enclosing module's scope and must be defined there.
    When it is truthy, the weights in ``./models/model.ckpt`` are restored and
    a single generator sample is written; no training happens. Otherwise the
    model is trained for ``epoch`` epochs, a sample image is written every
    5 training steps, and a checkpoint is written after every 5th epoch.

    Arguments:
    ----------
    :param epoch: Number of epochs
    :param batch_size: Batch size; passed to ``get_batches`` and used as the
        number of noise vectors drawn per step
    :param z_dim: Z dimension (length of each noise vector)
    :param learning_rate_D: Learning rate fed to the discriminator optimizer
    :param learning_rate_G: Learning rate fed to the generator optimizer
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :param get_batches: Function to get batches; called as
        ``get_batches(batch_size)`` once per epoch and iterated over
    :param data_shape: Shape of the data; ``data_shape[1:]`` sizes the image
        placeholder and ``data_shape[3]`` is forwarded to the loss and to the
        generator sampler (presumably the channel count -- confirm)
    :param data_image_mode: The image mode to use for images ("RGB" or "L")
    :param alpha: Forwarded unchanged to ``gan_model_loss`` (presumably the
        leaky-ReLU slope -- confirm against that helper)
    ----------
    """
    checkpoint_path = "./models/model.ckpt"
    img_save_path = "./generated_images/"

    # Create our input placeholders
    input_images, input_z, lr_G, lr_D = gan_model_inputs(data_shape[1:], z_dim)
    # Discriminator and generator losses
    d_loss, g_loss = gan_model_loss(input_images, input_z, data_shape[3], alpha)
    # Optimizers
    d_opt, g_opt = gan_model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1)

    # Both modes write into the image folder, so create it once up front
    # instead of re-checking it inside the training loop.
    os.makedirs(img_save_path, exist_ok=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print("Starting the model training...")

        # Sampling-only mode: restore the saved weights and emit one image.
        if from_checkpoint:
            saver.restore(sess, checkpoint_path)
            image_path = "generated_images/generated_fromckpt.PNG"
            generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path)
            return

        # Saver.save fails if the checkpoint's parent directory is missing.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

        step = 0  # global batch counter across all epochs
        for epoch_i in range(epoch):
            print("Training model for epoch_", epoch_i)

            for batch_images in get_batches(batch_size):
                step += 1
                # Random noise
                batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))

                # Run optimizers: one discriminator update, then one
                # generator update, on the same image/noise batch.
                sess.run(d_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_D: learning_rate_D})
                sess.run(g_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_G: learning_rate_G})

                # Report progress every 5 training steps (batches, not epochs)
                if step % 5 != 0:
                    continue

                # Calculate the training loss
                train_loss_d = d_loss.eval({input_z: batch_z, input_images: batch_images})
                train_loss_g = g_loss.eval({input_z: batch_z})

                # Path to save the generated image
                image_name = "{}_epoch_{}.jpg".format(step, epoch_i)
                image_path = os.path.join(img_save_path, image_name)

                # Print the values of epoch and losses
                print("Epoch {}/{}...".format(epoch_i + 1, epoch),
                      "Discriminator Loss: {:.4f}...".format(train_loss_d),
                      "Generator Loss: {:.4f}".format(train_loss_g))

                # Save the generator output
                generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path)

            # Save the model after every 5th completed epoch, so the
            # checkpoint includes that epoch's updates.
            if (epoch_i + 1) % 5 == 0:
                saver.save(sess, checkpoint_path)
                print("Model has been saved.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment