@merishnaSuwal
Last active October 20, 2020 10:06
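# Written against the TensorFlow 1.x graph API. On TensorFlow 2.x,
# tf.placeholder is only available as tf.compat.v1.placeholder.
import tensorflow as tf


def gan_model_inputs(real_dim, z_dim):
    """
    Creates the inputs for the model.
    Arguments:
    ----------
    :param real_dim: tuple containing width, height and channels
    :param z_dim: The dimension of Z
    ----------
    Returns:
    Tuple of (tensor of real input images, tensor of z (noise) data, Generator learning rate, Discriminator learning rate)
    """
    # Real images; the leading None leaves the batch size unspecified.
    real_inputs = tf.placeholder(tf.float32, (None, *real_dim), name="real_inputs")
    # Noise vectors fed to the generator, one row of length z_dim per sample.
    z_inputs = tf.placeholder(tf.float32, (None, z_dim), name="z_inputs")
    # Learning rates are placeholders so they can be set at run time via feed_dict.
    generator_learning_rate = tf.placeholder(tf.float32, name="generator_learning_rate")
    discriminator_learning_rate = tf.placeholder(tf.float32, name="discriminator_learning_rate")
    return real_inputs, z_inputs, generator_learning_rate, discriminator_learning_rate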
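
A minimal usage sketch, not part of the original gist: it shows how the four placeholders might be created and fed. The image size `(28, 28, 1)`, `z_dim=100`, the batch size, the uniform noise range and the learning rate `2e-4` are all illustrative assumptions. The function is repeated here through `tf.compat.v1` so the snippet can run on either TensorFlow 1.14+ or 2.x.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # placeholders require graph mode


def gan_model_inputs(real_dim, z_dim):
    # Same placeholders as in the gist, created through the compat.v1 API.
    real_inputs = tf1.placeholder(tf.float32, (None, *real_dim), name="real_inputs")
    z_inputs = tf1.placeholder(tf.float32, (None, z_dim), name="z_inputs")
    g_lr = tf1.placeholder(tf.float32, name="generator_learning_rate")
    d_lr = tf1.placeholder(tf.float32, name="discriminator_learning_rate")
    return real_inputs, z_inputs, g_lr, d_lr


tf1.reset_default_graph()

# Assumed example dimensions: 28x28 single-channel images, 100-dim noise.
real_inputs, z_inputs, g_lr, d_lr = gan_model_inputs((28, 28, 1), 100)

# Stand-in data; a real training loop would feed actual image batches.
batch_size = 16
real_batch = np.random.uniform(-1, 1, size=(batch_size, 28, 28, 1)).astype(np.float32)
z_batch = np.random.uniform(-1, 1, size=(batch_size, 100)).astype(np.float32)

with tf1.Session() as sess:
    # Only placeholders that the fetched tensors depend on need to be fed.
    real_shape, z_shape, lr_value = sess.run(
        [tf.shape(real_inputs), tf.shape(z_inputs), g_lr],
        feed_dict={real_inputs: real_batch, z_inputs: z_batch, g_lr: 2e-4},
    )
```

Because the batch dimension is declared as `None`, the same graph accepts batches of any size; the concrete size only appears in the dynamic shape once data is fed.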