Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
def gan_model_inputs(real_dim, z_dim):
    """
    Create the placeholder inputs for the GAN model.

    Arguments:
    ----------
    :param real_dim: iterable of per-image dimensions (e.g. height, width,
        channels). The values are unpacked, in the order given, after a
        leading ``None`` batch dimension to form the image placeholder shape.
    :param z_dim: The dimension of Z (length of each noise vector).
    ----------
    Returns:
    Tuple of (tensor of real input images, tensor of z (noise) data,
    Generator learning rate, Discriminator learning rate)
    """
    # TF 1.x exposes ``tf.placeholder`` directly; TF 2.x only provides it as
    # ``tf.compat.v1.placeholder`` (which requires eager execution to be
    # disabled). Prefer the TF 1.x name so existing behavior is unchanged.
    placeholder = getattr(tf, "placeholder", None) or tf.compat.v1.placeholder

    # The leading ``None`` leaves the batch size open so any batch can be fed.
    real_inputs = placeholder(tf.float32, (None, *real_dim), name='real_inputs')
    z_inputs = placeholder(tf.float32, (None, z_dim), name="z_inputs")

    # No shape is given, so these accept a fed value of any shape
    # (a scalar learning rate is presumably intended — confirm with callers).
    generator_learning_rate = placeholder(tf.float32, name="generator_learning_rate")
    discriminator_learning_rate = placeholder(tf.float32, name="discriminator_learning_rate")

    return real_inputs, z_inputs, generator_learning_rate, discriminator_learning_rate
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment