Generative Adversarial Network (GAN).ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Generative Adversarial Network (GAN).ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/vedraiyani/c2645254f56d80a11fb1454e8c45c80e/generative-adversarial-network-gan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eQMUmtmRBC7h",
"colab_type": "text"
},
"source": [
"[Generative Adversarial Network (GAN)](https://www.geeksforgeeks.org/generative-adversarial-network-gan/)"
]
},
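{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reading aid for the code below: the discriminator $D$ and generator $G$ play the standard GAN minimax game\n",
"\n",
"$$\\min_G \\max_D V(D, G) = \\mathbb{E}_{x \\sim p_{data}}[\\log D(x)] + \\mathbb{E}_{z \\sim p_z}[\\log(1 - D(G(z)))]$$\n",
"\n",
"The implementation uses two common variants of this objective: real labels are smoothed to 0.9 (one-sided label smoothing), and the generator minimizes the non-saturating loss $-\\log D(G(z))$ rather than $\\log(1 - D(G(z)))$."
]
},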
{
"cell_type": "code",
"metadata": {
"id": "RAECJJhHBBNV",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2237eb95-9a33-4d89-aee7-fb1003b7852f"
},
"source": [
"# importing the necessary libraries and the MNIST dataset \n",
"import tensorflow as tf \n",
"import numpy as np \n",
"import matplotlib.pyplot as plt \n",
"from tensorflow.examples.tutorials.mnist import input_data \n",
"\n",
"mnist = input_data.read_data_sets(\"MNIST_data\") \n",
"\n",
"# defining functions for the two networks. \n",
"# Both the networks have two hidden layers \n",
"# and an output layer which are densely or \n",
"# fully connected layers defining the \n",
"# Generator network function \n",
"def generator(z, reuse = None): \n",
"\twith tf.variable_scope('gen', reuse = reuse): \n",
"\t\thidden1 = tf.layers.dense(inputs = z, units = 128, \n",
"\t\t\t\t\t\t\tactivation = tf.nn.leaky_relu) \n",
"\t\t\t\t\t\t\t\n",
"\t\thidden2 = tf.layers.dense(inputs = hidden1, \n",
"\t\tunits = 128, activation = tf.nn.leaky_relu) \n",
"\t\t\t\n",
"\t\toutput = tf.layers.dense(inputs = hidden2, \n",
"\t\t\tunits = 784, activation = tf.nn.tanh) \n",
"\t\t\n",
"\t\treturn output \n",
"\n",
"# defining the Discriminator network function \n",
"def discriminator(X, reuse = None): \n",
"\twith tf.variable_scope('dis', reuse = reuse): \n",
"\t\thidden1 = tf.layers.dense(inputs = X, units = 128, \n",
"\t\t\t\t\t\t\tactivation = tf.nn.leaky_relu) \n",
"\t\t\t\t\t\t\t\n",
"\t\thidden2 = tf.layers.dense(inputs = hidden1, \n",
"\t\t\tunits = 128, activation = tf.nn.leaky_relu) \n",
"\t\t\t\t\n",
"\t\tlogits = tf.layers.dense(hidden2, units = 1) \n",
"\t\toutput = tf.sigmoid(logits) \n",
"\t\t\n",
"\t\treturn output, logits \n",
"\n",
"# creating placeholders for the outputs \n",
"tf.reset_default_graph() \n",
"\n",
"real_images = tf.placeholder(tf.float32, shape =[None, 784]) \n",
"z = tf.placeholder(tf.float32, shape =[None, 100]) \n",
"\n",
"G = generator(z) \n",
"D_output_real, D_logits_real = discriminator(real_images) \n",
"D_output_fake, D_logits_fake = discriminator(G, reuse = True) \n",
"\n",
"# defining the loss function \n",
"def loss_func(logits_in, labels_in): \n",
"\treturn tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( \n",
"\t\t\t\t\t\tlogits = logits_in, labels = labels_in)) \n",
"\n",
"# Smoothing for generalization \n",
"D_real_loss = loss_func(D_logits_real, tf.ones_like(D_logits_real)*0.9) \n",
"D_fake_loss = loss_func(D_logits_fake, tf.zeros_like(D_logits_real)) \n",
"D_loss = D_real_loss + D_fake_loss \n",
"\n",
"G_loss = loss_func(D_logits_fake, tf.ones_like(D_logits_fake)) \n",
"\n",
"# defining the learning rate, batch size, \n",
"# number of epochs and using the Adam optimizer \n",
"lr = 0.001 # learning rate \n",
"\n",
"# Do this when multiple networks \n",
"# interact with each other \n",
"\n",
"# returns all variables created(the two \n",
"# variable scopes) and makes trainable true \n",
"tvars = tf.trainable_variables() \n",
"d_vars =[var for var in tvars if 'dis' in var.name] \n",
"g_vars =[var for var in tvars if 'gen' in var.name] \n",
"\n",
"D_trainer = tf.train.AdamOptimizer(lr).minimize(D_loss, var_list = d_vars) \n",
"G_trainer = tf.train.AdamOptimizer(lr).minimize(G_loss, var_list = g_vars) \n",
"\n",
"batch_size = 100 # batch size \n",
"epochs = 500 # number of epochs. The higher the better the result \n",
"init = tf.global_variables_initializer() \n",
"\n",
"# creating a session to train the networks \n",
"samples =[] # generator examples \n",
"\n",
"with tf.Session() as sess: \n",
"\tsess.run(init) \n",
"\tfor epoch in range(epochs): \n",
"\t\tnum_batches = mnist.train.num_examples//batch_size \n",
"\t\t\n",
"\t\tfor i in range(num_batches): \n",
"\t\t\tbatch = mnist.train.next_batch(batch_size) \n",
"\t\t\tbatch_images = batch[0].reshape((batch_size, 784)) \n",
"\t\t\tbatch_images = batch_images * 2-1\n",
"\t\t\tbatch_z = np.random.uniform(-1, 1, size =(batch_size, 100)) \n",
"\t\t\t_= sess.run(D_trainer, feed_dict ={real_images:batch_images, z:batch_z}) \n",
"\t\t\t_= sess.run(G_trainer, feed_dict ={z:batch_z}) \n",
"\t\t\t\n",
"\t\tprint(\"on epoch{}\".format(epoch)) \n",
"\t\t\n",
"\t\tsample_z = np.random.uniform(-1, 1, size =(1, 100)) \n",
"\t\tgen_sample = sess.run(generator(z, reuse = True), \n",
"\t\t\t\t\t\t\t\tfeed_dict ={z:sample_z}) \n",
"\t\t\n",
"\t\tsamples.append(gen_sample) \n",
"\n",
"# result after 0th epoch \n",
"plt.imshow(samples[0].reshape(28, 28)) \n",
"\n",
"# result after 499th epoch \n",
"plt.imshow(samples[49].reshape(28, 28)) \n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:From <ipython-input-1-32ddd7b92272>:6: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please write your own downloading logic.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use urllib or similar directly.\n",
"Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.data to implement this functionality.\n",
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
"Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.data to implement this functionality.\n",
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
"Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n",
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
"Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n",
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
"WARNING:tensorflow:From <ipython-input-1-32ddd7b92272>:16: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use keras.layers.dense instead.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Call initializer instance with the dtype argument instead of passing it to the constructor\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
"on epoch0\n",
"on epoch1\n",
"on epoch2\n",
"on epoch3\n",
"on epoch4\n",
"on epoch5\n",
"on epoch6\n",
"on epoch7\n",
"on epoch8\n",
"on epoch9\n",
"on epoch10\n",
"on epoch11\n",
"on epoch12\n",
"on epoch13\n",
"on epoch14\n",
"on epoch15\n",
"on epoch16\n",
"on epoch17\n",
"on epoch18\n",
"on epoch19\n",
"on epoch20\n",
"on epoch21\n",
"on epoch22\n",
"on epoch23\n",
"on epoch24\n",
"on epoch25\n",
"on epoch26\n",
"on epoch27\n",
"on epoch28\n",
"on epoch29\n",
"on epoch30\n",
"on epoch31\n",
"on epoch32\n",
"on epoch33\n",
"on epoch34\n",
"on epoch35\n",
"on epoch36\n",
"on epoch37\n",
"on epoch38\n",
"on epoch39\n",
"on epoch40\n",
"on epoch41\n",
"on epoch42\n",
"on epoch43\n",
"on epoch44\n",
"on epoch45\n",
"on epoch46\n",
"on epoch47\n",
"on epoch48\n",
"on epoch49\n",
"on epoch50\n",
"on epoch51\n",
"on epoch52\n",
"on epoch53\n",
"on epoch54\n",
"on epoch55\n",
"on epoch56\n",
"on epoch57\n",
"on epoch58\n",
"on epoch59\n",
"on epoch60\n",
"on epoch61\n",
"on epoch62\n",
"on epoch63\n",
"on epoch64\n",
"on epoch65\n",
"on epoch66\n",
"on epoch67\n",
"on epoch68\n",
"on epoch69\n",
"on epoch70\n",
"on epoch71\n",
"on epoch72\n",
"on epoch73\n",
"on epoch74\n",
"on epoch75\n",
"on epoch76\n",
"on epoch77\n",
"on epoch78\n",
"on epoch79\n",
"on epoch80\n",
"on epoch81\n",
"on epoch82\n",
"on epoch83\n",
"on epoch84\n",
"on epoch85\n",
"on epoch86\n",
"on epoch87\n",
"on epoch88\n",
"on epoch89\n",
"on epoch90\n",
"on epoch91\n",
"on epoch92\n",
"on epoch93\n",
"on epoch94\n",
"on epoch95\n",
"on epoch96\n",
"on epoch97\n",
"on epoch98\n",
"on epoch99\n",
"on epoch100\n",
"on epoch101\n",
"on epoch102\n",
"on epoch103\n",
"on epoch104\n",
"on epoch105\n",
"on epoch106\n",
"on epoch107\n",
"on epoch108\n",
"on epoch109\n",
"on epoch110\n",
"on epoch111\n",
"on epoch112\n",
"on epoch113\n",
"on epoch114\n",
"on epoch115\n",
"on epoch116\n",
"on epoch117\n",
"on epoch118\n",
"on epoch119\n",
"on epoch120\n",
"on epoch121\n",
"on epoch122\n",
"on epoch123\n",
"on epoch124\n",
"on epoch125\n",
"on epoch126\n",
"on epoch127\n",
"on epoch128\n",
"on epoch129\n",
"on epoch130\n",
"on epoch131\n",
"on epoch132\n",
"on epoch133\n",
"on epoch134\n",
"on epoch135\n",
"on epoch136\n",
"on epoch137\n",
"on epoch138\n",
"on epoch139\n",
"on epoch140\n",
"on epoch141\n",
"on epoch142\n",
"on epoch143\n",
"on epoch144\n",
"on epoch145\n",
"on epoch146\n",
"on epoch147\n",
"on epoch148\n",
"on epoch149\n",
"on epoch150\n",
"on epoch151\n",
"on epoch152\n",
"on epoch153\n",
"on epoch154\n",
"on epoch155\n",
"on epoch156\n",
"on epoch157\n",
"on epoch158\n",
"on epoch159\n",
"on epoch160\n",
"on epoch161\n",
"on epoch162\n",
"on epoch163\n",
"on epoch164\n",
"on epoch165\n",
"on epoch166\n",
"on epoch167\n",
"on epoch168\n",
"on epoch169\n",
"on epoch170\n",
"on epoch171\n",
"on epoch172\n",
"on epoch173\n",
"on epoch174\n",
"on epoch175\n",
"on epoch176\n",
"on epoch177\n",
"on epoch178\n",
"on epoch179\n",
"on epoch180\n",
"on epoch181\n",
"on epoch182\n",
"on epoch183\n",
"on epoch184\n",
"on epoch185\n",
"on epoch186\n",
"on epoch187\n",
"on epoch188\n",
"on epoch189\n",
"on epoch190\n",
"on epoch191\n",
"on epoch192\n",
"on epoch193\n",
"on epoch194\n",
"on epoch195\n",
"on epoch196\n",
"on epoch197\n",
"on epoch198\n",
"on epoch199\n",
"on epoch200\n",
"on epoch201\n",
"on epoch202\n",
"on epoch203\n",
"on epoch204\n",
"on epoch205\n",
"on epoch206\n",
"on epoch207\n",
"on epoch208\n",
"on epoch209\n",
"on epoch210\n",
"on epoch211\n",
"on epoch212\n",
"on epoch213\n",
"on epoch214\n",
"on epoch215\n",
"on epoch216\n",
"on epoch217\n",
"on epoch218\n",
"on epoch219\n",
"on epoch220\n",
"on epoch221\n",
"on epoch222\n",
"on epoch223\n",
"on epoch224\n",
"on epoch225\n",
"on epoch226\n",
"on epoch227\n",
"on epoch228\n",
"on epoch229\n",
"on epoch230\n",
"on epoch231\n",
"on epoch232\n",
"on epoch233\n",
"on epoch234\n",
"on epoch235\n",
"on epoch236\n",
"on epoch237\n",
"on epoch238\n",
"on epoch239\n",
"on epoch240\n",
"on epoch241\n",
"on epoch242\n",
"on epoch243\n",
"on epoch244\n",
"on epoch245\n",
"on epoch246\n",
"on epoch247\n",
"on epoch248\n",
"on epoch249\n",
"on epoch250\n",
"on epoch251\n",
"on epoch252\n",
"on epoch253\n",
"on epoch254\n",
"on epoch255\n",
"on epoch256\n",
"on epoch257\n",
"on epoch258\n",
"on epoch259\n",
"on epoch260\n",
"on epoch261\n",
"on epoch262\n",
"on epoch263\n",
"on epoch264\n",
"on epoch265\n",
"on epoch266\n",
"on epoch267\n",
"on epoch268\n",
"on epoch269\n",
"on epoch270\n",
"on epoch271\n",
"on epoch272\n",
"on epoch273\n",
"on epoch274\n",
"on epoch275\n",
"on epoch276\n",
"on epoch277\n",
"on epoch278\n",
"on epoch279\n",
"on epoch280\n",
"on epoch281\n",
"on epoch282\n",
"on epoch283\n",
"on epoch284\n",
"on epoch285\n",
"on epoch286\n",
"on epoch287\n",
"on epoch288\n",
"on epoch289\n",
"on epoch290\n",
"on epoch291\n",
"on epoch292\n",
"on epoch293\n",
"on epoch294\n",
"on epoch295\n",
"on epoch296\n",
"on epoch297\n",
"on epoch298\n",
"on epoch299\n",
"on epoch300\n",
"on epoch301\n",
"on epoch302\n",
"on epoch303\n",
"on epoch304\n",
"on epoch305\n",
"on epoch306\n",
"on epoch307\n"
],
"name": "stdout"
}
]
}
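,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal follow-up sketch: the training loop above stores one generated sample per epoch in `samples`, so the generator's progress can be visualized as a small grid. The epoch indices chosen below are arbitrary checkpoints."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Plot the generator's output at a few checkpoints, using the\n",
"# per-epoch samples collected by the training cell above.\n",
"import matplotlib.pyplot as plt\n",
"\n",
"epochs_to_show = [0, 99, 299, 499]  # arbitrary checkpoints\n",
"fig, axes = plt.subplots(1, len(epochs_to_show), figsize=(12, 3))\n",
"for ax, e in zip(axes, epochs_to_show):\n",
"    ax.imshow(samples[e].reshape(28, 28), cmap='gray')\n",
"    ax.set_title('epoch {}'.format(e))\n",
"    ax.axis('off')\n",
"plt.show()\n"
],
"execution_count": 0,
"outputs": []
}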
]
}