Created September 25, 2019
Generative Adversarial Network (GAN).ipynb
"cell_type": "markdown",
"metadata": {
"id": "eQMUmtmRBC7h",
"colab_type": "text"
"source": [
"# Generative Adversarial Network (GAN)"
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/"
"outputId": "2237eb95-9a33-4d89-aee7-fb1003b7852f"
"source": [
"# importing the necessary libraries and the MNIST dataset \n",
"import tensorflow as tf \n",
"import numpy as np \n",
"import matplotlib.pyplot as plt \n",
"from tensorflow.examples.tutorials.mnist import input_data \n",
"mnist = input_data.read_data_sets(\"MNIST_data\") \n",
"# defining functions for the two networks. \n",
"# Both the networks have two hidden layers \n",
"# and an output layer which are densely or \n",
"# fully connected layers defining the \n",
"# Generator network function \n",
"def generator(z, reuse = None): \n",
"\twith tf.variable_scope('gen', reuse = reuse): \n",
"\t\thidden1 = tf.layers.dense(inputs = z, units = 128, \n",
"\t\t\t\t\t\t\tactivation = tf.nn.leaky_relu) \n",
"\t\thidden2 = tf.layers.dense(inputs = hidden1, \n",
"\t\tunits = 128, activation = tf.nn.leaky_relu) \n",
"\t\toutput = tf.layers.dense(inputs = hidden2, \n",
"\t\t\tunits = 784, activation = tf.nn.tanh) \n",
"\t\treturn output \n",
"# defining the Discriminator network function \n",
"def discriminator(X, reuse = None): \n",
"\twith tf.variable_scope('dis', reuse = reuse): \n",
"\t\thidden1 = tf.layers.dense(inputs = X, units = 128, \n",
"\t\t\t\t\t\t\tactivation = tf.nn.leaky_relu) \n",
"\t\thidden2 = tf.layers.dense(inputs = hidden1, \n",
"\t\t\tunits = 128, activation = tf.nn.leaky_relu) \n",
"\t\tlogits = tf.layers.dense(hidden2, units = 1) \n",
"\t\toutput = tf.sigmoid(logits) \n",
"\t\treturn output, logits \n",
"# creating placeholders for the outputs \n",
"tf.reset_default_graph() \n",
"real_images = tf.placeholder(tf.float32, shape =[None, 784]) \n",
"z = tf.placeholder(tf.float32, shape =[None, 100]) \n",
"G = generator(z) \n",
"D_output_real, D_logits_real = discriminator(real_images) \n",
"D_output_fake, D_logits_fake = discriminator(G, reuse = True) \n",
"# defining the loss function \n",
"def loss_func(logits_in, labels_in): \n",
"\treturn tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( \n",
"\t\t\t\t\t\tlogits = logits_in, labels = labels_in)) \n",
"# Smoothing for generalization \n",
"D_real_loss = loss_func(D_logits_real, tf.ones_like(D_logits_real)*0.9) \n",
"D_fake_loss = loss_func(D_logits_fake, tf.zeros_like(D_logits_real)) \n",
"D_loss = D_real_loss + D_fake_loss \n",
"G_loss = loss_func(D_logits_fake, tf.ones_like(D_logits_fake)) \n",
"# defining the learning rate, batch size, \n",
"# number of epochs and using the Adam optimizer \n",
"lr = 0.001 # learning rate \n",
"# Do this when multiple networks \n",
"# interact with each other \n",
"# returns all variables created(the two \n",
"# variable scopes) and makes trainable true \n",
"tvars = tf.trainable_variables() \n",
"d_vars =[var for var in tvars if 'dis' in] \n",
"g_vars =[var for var in tvars if 'gen' in] \n",
"D_trainer = tf.train.AdamOptimizer(lr).minimize(D_loss, var_list = d_vars) \n",
"G_trainer = tf.train.AdamOptimizer(lr).minimize(G_loss, var_list = g_vars) \n",
"batch_size = 100 # batch size \n",
"epochs = 500 # number of epochs. The higher the better the result \n",
"init = tf.global_variables_initializer() \n",
"# creating a session to train the networks \n",
"samples =[] # generator examples \n",
"with tf.Session() as sess: \n",
"\ \n",
"\tfor epoch in range(epochs): \n",
"\t\tnum_batches = mnist.train.num_examples//batch_size \n",
"\t\tfor i in range(num_batches): \n",
"\t\t\tbatch = mnist.train.next_batch(batch_size) \n",
"\t\t\tbatch_images = batch[0].reshape((batch_size, 784)) \n",
"\t\t\tbatch_images = batch_images * 2-1\n",
"\t\t\tbatch_z = np.random.uniform(-1, 1, size =(batch_size, 100)) \n",
"\t\t\t_=, feed_dict ={real_images:batch_images, z:batch_z}) \n",
"\t\t\t_=, feed_dict ={z:batch_z}) \n",
"\t\tprint(\"on epoch{}\".format(epoch)) \n",
"\t\tsample_z = np.random.uniform(-1, 1, size =(1, 100)) \n",
"\t\tgen_sample =, reuse = True), \n",
"\t\t\t\t\t\t\t\tfeed_dict ={z:sample_z}) \n",
"\t\tsamples.append(gen_sample) \n",
"# result after 0th epoch \n",
"plt.imshow(samples[0].reshape(28, 28)) \n",
"# result after 499th epoch \n",
"plt.imshow(samples[49].reshape(28, 28)) \n"
"execution_count": 0,
"outputs": [
"output_type": "stream",
"text": [
