def discriminator(x, is_reuse=False, alpha=0.2):
    ''' Build the discriminator network.
    Arguments
    ---------
    x : Input tensor for the discriminator
    is_reuse : Reuse the variables with tf.variable_scope
    alpha : Leak parameter for leaky ReLU
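The `alpha` argument is the leak of the leaky ReLU: positive inputs pass through unchanged and negative inputs are scaled by `alpha`. In TensorFlow 1.x this is often written as `tf.maximum(alpha * x, x)` for `0 < alpha < 1`. The sketch below is framework-free plain Python so it runs anywhere; the helper name `leaky_relu` is illustrative, not taken from the original code.

```python
def leaky_relu(x, alpha=0.2):
    """Leaky ReLU: identity for x >= 0, slope `alpha` for x < 0.

    For 0 < alpha < 1 this equals max(alpha * x, x).
    """
    return x if x >= 0 else alpha * x


# Apply element-wise to a small list of pre-activations.
pre_activations = [-2.0, -0.5, 0.0, 1.5]
activations = [leaky_relu(v, alpha=0.2) for v in pre_activations]
print(activations)
```

Keeping a small negative slope means the discriminator still passes gradient for negative pre-activations, which is why GAN discriminators commonly use it instead of a plain ReLU.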
def generator(z, output_channel_dim, is_train=True):
    ''' Build the generator network.
    Arguments
    ---------
    z : Input tensor for the generator
    output_channel_dim : Number of channels in the generator output
    is_train : Whether the generator is being built in training mode
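The generator has to grow the noise vector into an image with `output_channel_dim` channels. A common DCGAN-style design (assumed here, since the generator body is not shown) projects `z` to a small feature map and then doubles its spatial size with stride-2, `'same'`-padded transposed convolutions. The sketch below uses a hypothetical 4x4 starting map and counts how many such layers are needed to reach the 128x128 resolution used in preprocessing; `upsampling_plan` is an illustrative helper.

```python
def upsampling_plan(start_size, target_size, stride=2):
    """Spatial sizes produced by repeated stride-`stride` transposed
    convolutions with 'same' padding (each layer multiplies the size
    by `stride`). Raises if `target_size` is not reachable exactly."""
    if start_size <= 0 or stride < 2:
        raise ValueError("start_size must be positive and stride >= 2")
    sizes = [start_size]
    while sizes[-1] < target_size:
        sizes.append(sizes[-1] * stride)
    if sizes[-1] != target_size:
        raise ValueError(f"{target_size} is not {start_size} * {stride}**k")
    return sizes


# Hypothetical 4x4 seed, 128x128 target.
plan = upsampling_plan(4, 128)
n_layers = len(plan) - 1  # number of stride-2 transposed convolutions
print(plan, n_layers)
```

The last of these layers would output `output_channel_dim` channels (for example 3 for RGB); the intermediate channel counts are a free design choice.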
def gan_model_inputs(real_dim, z_dim):
    """
    Create the inputs for the model.
    Arguments:
    ----------
    :param real_dim: tuple containing width, height and channels
    :param z_dim: The dimension of Z
    ----------
    Returns:
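In TensorFlow 1.x such inputs are usually placeholders with a `None` batch dimension, for example `tf.placeholder(tf.float32, (None,) + tuple(real_dim))` for the real images and shape `(None, z_dim)` for Z; the function body is truncated, so this is an assumption. The plain-Python sketch below shows those shapes and how a batch of noise vectors might be drawn. Sampling uniformly from [-1, 1] is an assumed choice, and `input_shapes` and `sample_z` are hypothetical helpers.

```python
import random


def input_shapes(real_dim, z_dim):
    """Shapes of the real-image and noise inputs; None marks a
    variable batch dimension."""
    width, height, channels = real_dim
    return (None, width, height, channels), (None, z_dim)


def sample_z(batch_size, z_dim, rng=random):
    """Draw a batch of noise vectors uniformly from [-1, 1]."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(z_dim)]
            for _ in range(batch_size)]


real_shape, z_shape = input_shapes((128, 128, 3), z_dim=100)
batch_z = sample_z(batch_size=16, z_dim=100, rng=random.Random(0))
print(real_shape, z_shape, len(batch_z), len(batch_z[0]))
```

Passing a seeded `random.Random` makes the noise reproducible, which is handy when comparing generator samples across epochs.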
import os

# Define the directory with real image data
data_dir = './data/'  # real (training) images
resized_data_dir = "./resized_data"  # folder for saving resized data
# Resize images to 128x128
preprocess = True  # set to False to skip resizing
if preprocess:
    # Create the resized folder if it does not exist
    if not os.path.exists(resized_data_dir):
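The preprocessing snippet only shows the folder check. A minimal sketch of the rest, under stated assumptions: `os.makedirs(..., exist_ok=True)` replaces the separate existence check, and non-square images are center-cropped to a square before being resized to 128x128 (the original resizing method is not shown). With Pillow, the box returned by the hypothetical `center_crop_box` helper can be passed to `Image.crop`, followed by `Image.resize((128, 128))`.

```python
import os
import tempfile


def ensure_dir(path):
    """Create `path` (and any parents) if missing; a no-op if it exists."""
    os.makedirs(path, exist_ok=True)
    return path


def center_crop_box(width, height):
    """(left, upper, right, lower) box of the largest centered square,
    in the coordinate order Pillow's Image.crop expects."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)


with tempfile.TemporaryDirectory() as tmp:
    out_dir = ensure_dir(os.path.join(tmp, "resized_data"))
    ensure_dir(out_dir)  # calling it twice is safe
    print(os.path.isdir(out_dir))

print(center_crop_box(200, 100))  # a landscape-shaped image
```

Cropping before resizing keeps the aspect ratio of faces or objects intact; resizing a non-square image straight to 128x128 would stretch it instead.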
def gan_model_loss(input_real, input_z, output_channel_dim, alpha):
    """
    Get the loss for the discriminator and generator
    Arguments:
    ---------
    :param input_real: Images from the real dataset
    :param input_z: Z input
    :param output_channel_dim: The number of channels in the output image
    :param alpha: Leak parameter for leaky ReLU
    ---------
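The body of `gan_model_loss` is not shown. A common formulation, assumed here, applies sigmoid cross-entropy to the discriminator logits: the discriminator is trained to output 1 on real images and 0 on generated ones, and the generator is trained so that the discriminator outputs 1 on its samples (the non-saturating generator loss). The sketch uses the numerically stable expression that TensorFlow documents for `tf.nn.sigmoid_cross_entropy_with_logits`, max(x, 0) - x*z + log(1 + exp(-|x|)), on plain Python lists of logits; `gan_losses` is a hypothetical helper.

```python
import math


def sigmoid_cross_entropy(logit, label):
    """Stable binary cross-entropy between sigmoid(logit) and label."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean(values):
    return sum(values) / len(values)


def gan_losses(real_logits, fake_logits):
    """Discriminator loss and non-saturating generator loss."""
    d_loss_real = mean([sigmoid_cross_entropy(x, 1.0) for x in real_logits])
    d_loss_fake = mean([sigmoid_cross_entropy(x, 0.0) for x in fake_logits])
    g_loss = mean([sigmoid_cross_entropy(x, 1.0) for x in fake_logits])
    return d_loss_real + d_loss_fake, g_loss


# Logits of 0 mean the discriminator is undecided (probability 0.5).
d_loss, g_loss = gan_losses(real_logits=[0.0, 0.0], fake_logits=[0.0, 0.0])
print(d_loss, g_loss)
```

With undecided logits the discriminator loss is 2 log 2 and the generator loss is log 2. Because the exponent is always `-abs(logit)`, very large logits never overflow `math.exp`.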
def gan_model_optimizers(d_loss, g_loss, disc_lr, gen_lr, beta1):
    """
    Get optimization operations
    Arguments:
    ----------
    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param disc_lr: Placeholder for the discriminator learning rate
    :param gen_lr: Placeholder for the generator learning rate
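`beta1` is Adam's decay rate for the first-moment (momentum) estimate. In TensorFlow 1.x two optimizers are typically built, e.g. `tf.train.AdamOptimizer(disc_lr, beta1=beta1).minimize(d_loss, var_list=d_vars)` and the same for the generator, so each loss only updates its own network's variables; `d_vars` and `g_vars` are assumed names, since the body is truncated. The plain-Python sketch below shows the textbook Adam update for a single scalar parameter to make the role of `beta1` concrete. TensorFlow folds the bias correction into the step size, so its results differ slightly in the `epsilon` term. The defaults `lr=0.0002` and `beta1=0.5` are the values popularized by DCGAN and are assumptions, not the source's settings.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.0002, beta1=0.5, beta2=0.999, eps=1e-8):
    """One textbook Adam update for a scalar parameter.

    m and v are the running first and second moment estimates; t is
    the 1-based step count used for bias correction.
    """
    m = beta1 * m + (1.0 - beta1) * grad          # beta1 controls momentum
    v = beta2 * v + (1.0 - beta2) * grad * grad
    m_hat = m / (1.0 - beta1 ** t)                # bias-corrected moments
    v_hat = v / (1.0 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Minimize f(x) = x**2 (gradient 2x) for a few steps starting at x = 1.0.
x, m, v = 1.0, 0.0, 0.0
for t in range(1, 6):
    x, m, v = adam_step(x, 2.0 * x, m, v, t)
print(x)
```

A lower `beta1` (0.5 instead of Adam's usual 0.9) makes the momentum react faster to the constantly shifting GAN objective, which is the usual motivation for this setting.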
def generator_output(sess, n_images, input_z, output_channel_dim, image_mode, image_path):
    """
    Save output from the generator.
    Arguments:
    ----------
    :param sess: TensorFlow session
    :param n_images: Number of images to display
    :param input_z: Input Z Tensor (noise vector)
    :param output_channel_dim: The number of channels in the output image
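Saving generated samples usually means mapping generator outputs back to 8-bit pixel values and tiling the `n_images` samples into one picture. Assuming the generator ends in `tanh`, so its outputs lie in [-1, 1] (the generator body is not shown), the hypothetical helpers below sketch both steps in plain Python: `to_uint8` rescales one value to the range 0 to 255, and `grid_shape` picks a near-square rows-by-columns layout.

```python
import math


def to_uint8(value):
    """Map a tanh output in [-1, 1] to an integer pixel in [0, 255]."""
    value = min(max(value, -1.0), 1.0)          # clip out-of-range values
    return int((value + 1.0) * 127.5 + 0.5)     # scale, then round half up


def grid_shape(n_images):
    """Rows and columns of a near-square grid that fits n_images tiles."""
    if n_images < 1:
        raise ValueError("n_images must be at least 1")
    cols = math.ceil(math.sqrt(n_images))
    rows = math.ceil(n_images / cols)
    return rows, cols


pixels = [to_uint8(v) for v in (-1.0, 0.0, 1.0, 1.7)]
print(pixels, grid_shape(25), grid_shape(10))
```

Clipping first guards against values slightly outside [-1, 1], which would otherwise wrap around when stored as unsigned 8-bit pixels.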
def train_gan_model(epoch, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1, get_batches, data_shape, data_image_mode, alpha):
    """
    Train the GAN model.
    Arguments:
    ----------
    :param epoch: Number of epochs
    :param batch_size: Batch size
    :param z_dim: Z dimension
    :param learning_rate_D: Learning rate for the discriminator
    :param learning_rate_G: Learning rate for the generator
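The training body is truncated. A typical GAN loop, assumed here, iterates over epochs and over the mini-batches returned by `get_batches(batch_size)`, running one discriminator update and then one generator update per batch. The framework-free skeleton below shows only that control flow: `d_step` and `g_step` are hypothetical stand-ins for the session calls that would run the two optimizer operations, and `make_get_batches` is a hypothetical stand-in for the original batch generator.

```python
def make_get_batches(data):
    """Hypothetical stand-in for get_batches: yields full batches only."""
    def get_batches(batch_size):
        for start in range(0, len(data) - batch_size + 1, batch_size):
            yield data[start:start + batch_size]
    return get_batches


def train_loop(epochs, batch_size, get_batches, d_step, g_step, print_every=10):
    """Alternate one discriminator and one generator update per batch.

    Returns the total step count and the (epoch, step, d_loss, g_loss)
    records logged every `print_every` steps.
    """
    steps, log = 0, []
    for epoch in range(epochs):
        for batch in get_batches(batch_size):
            steps += 1
            d_loss = d_step(batch)   # update the discriminator first
            g_loss = g_step(batch)   # then the generator
            if steps % print_every == 0:
                log.append((epoch, steps, d_loss, g_loss))
    return steps, log


# Toy run: 100 samples, batches of 32 -> 3 full batches per epoch.
get_batches = make_get_batches(list(range(100)))
steps, log = train_loop(
    epochs=2, batch_size=32, get_batches=get_batches,
    d_step=lambda batch: float(len(batch)),
    g_step=lambda batch: 0.0,
    print_every=2,
)
print(steps, log)
```

In the TensorFlow version, each step would also draw a fresh batch of noise for `input_z` and feed the per-network learning rates to their placeholders.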