@ceceshao1
Last active January 12, 2021 10:04
Generative Adversarial Networks using Keras and MNIST
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using GANs and Keras to generate handwritten digits like MNIST"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook is adapted from Wouter Bulten's excellent tutorial (see the original blog post here: https://www.wouterbulten.nl/blog/tech/getting-started-with-generative-adversarial-networks/)\n",
"\n",
"\n",
"### The notebook covers:\n",
"1. Key imports (defining your Comet experiment)\n",
"2. Defining the discriminator and generator models\n",
"3. Creating the discriminator and generator models (and specific establishing parameters)\n",
"4. Training the GAN (and how to check the training progress)\n",
"5. Checking final results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Key imports"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we import the comet_ml library, followed by the keras library, and others if needed. The only requirement here is that comet_ml be imported first. If you forget, just restart the kernel, and import them in the proper order"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import Comet before your other imports \n",
"from comet_ml import Experiment "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order for comet.ml to log your experiment and results, you need to create an Experiment instance. To do this, you'll need two items:\n",
"\n",
"- a Comet api_key\n",
"- a project_name\n",
"You can find your Comet api_key when you log in to https://comet.ml and click on your project."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Click on the API key button to copy the key to your clipboard.\n",
"\n",
"It is recommended that you put your COMET_API_KEY in a .env key in the current directory. You can do that using the following code. Put it in a cell, replace the ... with your key, and then delete the cell. That way your key stays private.\n",
"\n",
"```\n",
"ipython\n",
"%%writefile .env\n",
"```\n",
"\n",
"COMET_API_KEY=...\n",
"It is also recommended that you use your project_name in the cell, so you can match the results with this code. You can make up a new name, or add this experiment to a project that already exists."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"COMET INFO: Experiment is live on comet.ml https://www.comet.ml/ceceshao1/mnist-gan/cf310adacd724bf280323e2eef92d1cd\n",
"\n"
]
}
],
"source": [
"# Establish this notebook as your Comet Experiment - set your API Key and which Comet Project and Workspace you'd like the experiment data to report to\n",
"experiment = Experiment(api_key=\"INSERT API KEY HERE\",\n",
" project_name=\"INSERT PROJECT NAME HERE\", workspace=\"INSERT WORKSPACE NAME HERE\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/anaconda/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n",
"Using TensorFlow backend.\n"
]
}
],
"source": [
"# Keras imports \n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Activation, Flatten, Reshape\n",
"from keras.layers import Conv2D, UpSampling2D\n",
"from keras.layers import LeakyReLU, Dropout\n",
"from keras.layers import BatchNormalization\n",
"from keras.optimizers import Adam, SGD, RMSprop\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import clear_output, Image\n",
"\n",
"# We'll be downloading and reading in the MNIST data from Tensorflow\n",
"from tensorflow.examples.tutorials.mnist import input_data\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import keras.backend.tensorflow_backend as ktf\n",
"import tensorflow as tf\n",
"import os\n",
"\n",
"def get_session(gpu_fraction=0.45):\n",
" '''Assume that you have 6GB of GPU memory and want to allocate ~2GB'''\n",
"\n",
" num_threads = os.environ.get('OMP_NUM_THREADS')\n",
" gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)\n",
"\n",
" if num_threads:\n",
" return tf.Session(config=tf.ConfigProto(\n",
" gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))\n",
" else:\n",
" return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))\n",
"\n",
"ktf.set_session(get_session())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2(a). Defining the discriminator\n",
"\n",
"In our two-player game the discriminator takes the role of the police: given an image it has to find out whether the image is fake or not. Given this requirement, the input of our discriminator network is a (28x28x1) input patch, equal to the dimensions of an MNIST image. The output is a single node. The setup of the networks is roughly based on the [DCGAN paper](https://arxiv.org/abs/1511.06434) and one of its [implementations](https://github.com/carpedm20/DCGAN-tensorflow).\n",
"\n",
"We use `LeakyReLU` in between the convolution layers to improve the gradients. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def discriminator():\n",
" \n",
" net = Sequential()\n",
" input_shape = (28, 28, 1)\n",
" dropout_prob = 0.4\n",
" experiment.log_parameter('dis_dropout_prob', dropout_prob)\n",
"\n",
" net.add(Conv2D(64, 5, strides=2, input_shape=input_shape, padding='same'))\n",
" net.add(LeakyReLU())\n",
" \n",
" net.add(Conv2D(128, 5, strides=2, padding='same'))\n",
" net.add(LeakyReLU())\n",
" net.add(Dropout(dropout_prob))\n",
" \n",
" net.add(Conv2D(256, 5, strides=2, padding='same'))\n",
" net.add(LeakyReLU())\n",
" net.add(Dropout(dropout_prob))\n",
" \n",
" net.add(Conv2D(512, 5, strides=1, padding='same'))\n",
" net.add(LeakyReLU())\n",
" net.add(Dropout(dropout_prob))\n",
" \n",
" net.add(Flatten())\n",
" net.add(Dense(1))\n",
" net.add(Activation('sigmoid'))\n",
" \n",
" return net"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The full network structure is as follows:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d_1 (Conv2D) (None, 14, 14, 64) 1664 \n",
"_________________________________________________________________\n",
"leaky_re_lu_1 (LeakyReLU) (None, 14, 14, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_2 (Conv2D) (None, 7, 7, 128) 204928 \n",
"_________________________________________________________________\n",
"leaky_re_lu_2 (LeakyReLU) (None, 7, 7, 128) 0 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 7, 7, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_3 (Conv2D) (None, 4, 4, 256) 819456 \n",
"_________________________________________________________________\n",
"leaky_re_lu_3 (LeakyReLU) (None, 4, 4, 256) 0 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 4, 4, 256) 0 \n",
"_________________________________________________________________\n",
"conv2d_4 (Conv2D) (None, 4, 4, 512) 3277312 \n",
"_________________________________________________________________\n",
"leaky_re_lu_4 (LeakyReLU) (None, 4, 4, 512) 0 \n",
"_________________________________________________________________\n",
"dropout_3 (Dropout) (None, 4, 4, 512) 0 \n",
"_________________________________________________________________\n",
"flatten_1 (Flatten) (None, 8192) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 1) 8193 \n",
"_________________________________________________________________\n",
"activation_1 (Activation) (None, 1) 0 \n",
"=================================================================\n",
"Total params: 4,311,553\n",
"Trainable params: 4,311,553\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"net_discriminator = discriminator()\n",
"net_discriminator.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2(b). Defining the generator\n",
"\n",
"The task of the generator, also known as \"the counterfeiter\", is to fool the discriminator by producing real-looking fake images. These images should eventually resemble the data distribution of the MNIST dataset.\n",
"\n",
"The structure of the generator is comparable to the discrminiator but in reverse. We start with a random vector of noise (length=100) and gradually upsample. To improve the output of the generator we use `UpSampling2D` and normal convolutions instead of transposed convolutions (see also [this article](https://distill.pub/2016/deconv-checkerboard/)). The sizes of the layers are adjusted to match the size of our data (28x28 as opposed to the 64x64 of the DCGAN paper)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def generator():\n",
" \n",
" net = Sequential()\n",
" dropout_prob = 0.4\n",
" experiment.log_parameter('adv_dropout_prob', dropout_prob)\n",
"\n",
" net.add(Dense(7*7*256, input_dim=100))\n",
" net.add(BatchNormalization(momentum=0.9))\n",
" net.add(LeakyReLU())\n",
" net.add(Reshape((7,7,256)))\n",
" net.add(Dropout(dropout_prob))\n",
" \n",
" net.add(UpSampling2D())\n",
" net.add(Conv2D(128, 5, padding='same'))\n",
" net.add(BatchNormalization(momentum=0.9))\n",
" net.add(LeakyReLU())\n",
" \n",
" net.add(UpSampling2D())\n",
" net.add(Conv2D(64, 5, padding='same'))\n",
" net.add(BatchNormalization(momentum=0.9))\n",
" net.add(LeakyReLU())\n",
" \n",
" net.add(Conv2D(32, 5, padding='same'))\n",
" net.add(BatchNormalization(momentum=0.9))\n",
" net.add(LeakyReLU())\n",
" \n",
" net.add(Conv2D(1, 5, padding='same'))\n",
" net.add(Activation('sigmoid'))\n",
" \n",
" return net"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The full network of the generator looks as follows:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense_2 (Dense) (None, 12544) 1266944 \n",
"_________________________________________________________________\n",
"batch_normalization_1 (Batch (None, 12544) 50176 \n",
"_________________________________________________________________\n",
"leaky_re_lu_5 (LeakyReLU) (None, 12544) 0 \n",
"_________________________________________________________________\n",
"reshape_1 (Reshape) (None, 7, 7, 256) 0 \n",
"_________________________________________________________________\n",
"dropout_4 (Dropout) (None, 7, 7, 256) 0 \n",
"_________________________________________________________________\n",
"up_sampling2d_1 (UpSampling2 (None, 14, 14, 256) 0 \n",
"_________________________________________________________________\n",
"conv2d_5 (Conv2D) (None, 14, 14, 128) 819328 \n",
"_________________________________________________________________\n",
"batch_normalization_2 (Batch (None, 14, 14, 128) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_6 (LeakyReLU) (None, 14, 14, 128) 0 \n",
"_________________________________________________________________\n",
"up_sampling2d_2 (UpSampling2 (None, 28, 28, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_6 (Conv2D) (None, 28, 28, 64) 204864 \n",
"_________________________________________________________________\n",
"batch_normalization_3 (Batch (None, 28, 28, 64) 256 \n",
"_________________________________________________________________\n",
"leaky_re_lu_7 (LeakyReLU) (None, 28, 28, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_7 (Conv2D) (None, 28, 28, 32) 51232 \n",
"_________________________________________________________________\n",
"batch_normalization_4 (Batch (None, 28, 28, 32) 128 \n",
"_________________________________________________________________\n",
"leaky_re_lu_8 (LeakyReLU) (None, 28, 28, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_8 (Conv2D) (None, 28, 28, 1) 801 \n",
"_________________________________________________________________\n",
"activation_2 (Activation) (None, 28, 28, 1) 0 \n",
"=================================================================\n",
"Total params: 2,394,241\n",
"Trainable params: 2,368,705\n",
"Non-trainable params: 25,536\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"net_generator = generator()\n",
"net_generator.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Creating the models\n",
"\n",
"We now defined the two separate networks but these still need to be combined in to two trainable models: one to train the discrmininator and one to train the generator. We first start with the most simple one which is the discriminator model.\n",
"\n",
"For the discriminator model we only have to define the optimizer, all the other parts of the model are already defined. We use `SGD` as the optimizer with a low learning rate and clip the values between -1 and 1. A small decay in the learning rate can help with stabilizing. Besides the loss we also tell Keras to gives us the accuracy as a metric."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"sequential_1 (Sequential) (None, 1) 4311553 \n",
"=================================================================\n",
"Total params: 4,311,553\n",
"Trainable params: 4,311,553\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"dis_lr=0.0008\n",
"dis_clipvalue=1.0\n",
"dis_decay=1e-10\n",
"dis_optimizer = RMSprop\n",
"\n",
"optim_discriminator = RMSprop(lr=dis_lr, clipvalue=dis_clipvalue, decay=dis_decay)\n",
"model_discriminator = Sequential()\n",
"model_discriminator.add(net_discriminator)\n",
"model_discriminator.compile(loss='binary_crossentropy', optimizer=optim_discriminator, metrics=['accuracy'])\n",
"\n",
"model_discriminator.summary()\n",
"\n",
"experiment.log_parameter('dis_lr',dis_lr)\n",
"experiment.log_parameter('dis_clipvalue',dis_clipvalue)\n",
"experiment.log_parameter('dis_decay',dis_decay)\n",
"experiment.log_parameter('dis_optimizer', dis_optimizer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model for the generator is a bit more complex. The generator needs to fool the discriminator by generating images. To train the generator we need to assess its performance on the output of the discriminator. For this we add both networks to a combined model: *the adversarial model*. Our adverserial model uses random noise as its input and outputs the eventual prediction of the discriminator on the generated images. \n",
"\n",
"The generator performs well if the adverserial model outputs 'real' on all inputs. In other words, for any input of the adversial network aim to get an output classifying the generated image as real. This means, however, that the discriminator failed (which is a good thing for the generator). If we would use normal back propagation here on the full adversarial model we would update slowly push the discriminator to update itself and start classifying fake images as real. To prevent this we must freeze the part of the model that belongs to the discriminator.\n",
"\n",
"In Keras freezing a model is easily done by freezing all the layers of the model. By setting the `trainable` parameter to `False` we prevent the layer of updating within this particular model (it is still trainable in the discriminator model).\n",
"\n",
"The adversarial model uses `Adam` as the optimizer with the default values for the momentum."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"sequential_2 (Sequential) (None, 28, 28, 1) 2394241 \n",
"_________________________________________________________________\n",
"sequential_1 (Sequential) (None, 1) 4311553 \n",
"=================================================================\n",
"Total params: 6,705,794\n",
"Trainable params: 2,368,705\n",
"Non-trainable params: 4,337,089\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"adv_lr=0.0004\n",
"adv_clipvalue=1.0\n",
"adv_decay=1e-10\n",
"adv_optimizer = Adam\n",
"\n",
"optim_adversarial = Adam(lr=adv_lr, clipvalue=adv_clipvalue, decay=adv_decay)\n",
"model_adversarial = Sequential()\n",
"model_adversarial.add(net_generator)\n",
"\n",
"# Disable layers in discriminator\n",
"for layer in net_discriminator.layers:\n",
" layer.trainable = False\n",
"\n",
"model_adversarial.add(net_discriminator)\n",
"model_adversarial.compile(loss='binary_crossentropy', optimizer=optim_adversarial, metrics=['accuracy'])\n",
"model_adversarial.summary()\n",
"\n",
"experiment.log_parameter('adv_lr',adv_lr)\n",
"experiment.log_parameter('adv_clipvalue',adv_clipvalue)\n",
"experiment.log_parameter('adv_decay',adv_decay)\n",
"experiment.log_parameter('adv_optimizer',adv_optimizer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the number of non-trainable parameters is very high. This is exactly what we want! \n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reading MNIST data\n",
"\n",
"We can now read our training data. For this I use a small utility function from Tensorflow."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting mnist/train-images-idx3-ubyte.gz\n",
"Extracting mnist/train-labels-idx1-ubyte.gz\n",
"Extracting mnist/t10k-images-idx3-ubyte.gz\n",
"Extracting mnist/t10k-labels-idx1-ubyte.gz\n"
]
}
],
"source": [
"# Read MNIST data\n",
"x_train = input_data.read_data_sets(\"mnist\", one_hot=True).train.images\n",
"x_train = x_train.reshape(-1, 28, 28, 1).astype(np.float32)\n",
"experiment.log_dataset_info(x_train)\n",
"\n",
"# Map the images to a new range [-1, 1]\n",
"#x_train = x_train / 0.5 - 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Training the GAN\n",
"\n",
"With our models defined and the data loaded we can start training our GAN. The models are trained one after another, starting with the discriminator. The discriminator is trained on a data set of both fake and real images and tries to classify them correctly. The adversarial model is trained on noise vectors as explained above.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# make the directory for your generator's outputs\n",
"import os\n",
"os.makedirs(\"output/mnist-normal\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/anaconda/envs/py35/lib/python3.5/site-packages/keras/engine/training.py:479: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n",
" 'Discrepancy between trainable weights and collected trainable'\n"
]
}
],
"source": [
"batch_size = 256\n",
"experiment.log_parameter('batch_size', batch_size)\n",
"\n",
"vis_noise = np.random.uniform(-1.0, 1.0, size=[16, 100])\n",
"\n",
"loss_adv = []\n",
"loss_dis = []\n",
"acc_adv = []\n",
"acc_dis = []\n",
"plot_iteration = []\n",
"\n",
"for i in range(10001):\n",
" \n",
" # Select a random set of training images from the mnist dataset\n",
" images_train = x_train[np.random.randint(0, x_train.shape[0], size=batch_size), :, :, :]\n",
" # Generate a random noise vector\n",
" noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])\n",
" # Use the generator to create fake images from the noise vector\n",
" images_fake = net_generator.predict(noise)\n",
" \n",
" # Create a dataset with fake and real images\n",
" x = np.concatenate((images_train, images_fake))\n",
" y = np.ones([2*batch_size, 1])\n",
" y[batch_size:, :] = 0 \n",
"\n",
" # Train discriminator for one batch\n",
" d_stats = model_discriminator.train_on_batch(x, y)\n",
" \n",
" # Train the generator\n",
" # The input of the adversarial model is a list of noise vectors. The generator is 'good' if the discriminator classifies\n",
" # all the generated images as real. Therefore, the desired output is a list of all ones.\n",
" y = np.ones([batch_size, 1])\n",
" noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])\n",
" a_stats = model_adversarial.train_on_batch(noise, y)\n",
" \n",
" if i % 50 == 0:\n",
" experiment.log_metrics({\"loss_adv\":a_stats[0], \"loss_dis\":d_stats[0], \"acc_adv\":a_stats[1],\"acc_dis\":d_stats[1]},step=i)\n",
" \n",
" if i % 500 == 0:\n",
" # Visualize the performance of the generator by producing images from the test vector\n",
" images = net_generator.predict(vis_noise)\n",
" # Map back to original range\n",
" #images = (images + 1 ) * 0.5\n",
" plt.figure(figsize=(10,10))\n",
" \n",
" for im in range(images.shape[0]):\n",
" plt.subplot(4, 4, im+1)\n",
" image = images[im, :, :, :]\n",
" image = np.reshape(image, [28, 28])\n",
" \n",
" plt.imshow(image, cmap='gray')\n",
" plt.axis('off')\n",
" \n",
" plt.tight_layout()\n",
" # plt.savefig('/home/ubuntu/cecelia/deeplearning-resources/output/mnist-normal/{}.png'.format(i))\n",
" plt.savefig(r'output/mnist-normal/{}.png'.format(i))\n",
" experiment.log_image(r'output/mnist-normal/{}.png'.format(i))\n",
" plt.close('all')\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'api': 'https://www.comet.ml/api/rest/v1/image/get-image?imageId=05fd9cec8aa04872946ca59023a8bda4&experimentKey=cf310adacd724bf280323e2eef92d1cd',\n",
" 'web': 'https://www.comet.ml/api/image/download?imageId=05fd9cec8aa04872946ca59023a8bda4&experimentKey=cf310adacd724bf280323e2eef92d1cd'}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# create a gif from the generator's saved output (in that output/mnist-normal directory)\n",
"\n",
"import imageio\n",
"\n",
"filenames = [r'output/mnist-normal/{}.png'.format(i * 500) for i in range(20)]\n",
"images = []\n",
"for filename in filenames:\n",
" images.append(imageio.imread(filename))\n",
"imageio.mimsave(r'output/mnist-normal/learning.gif', images, duration=0.5)\n",
"\n",
"Image(url='output/mnist-normal/learning.gif') \n",
"experiment.log_image('output/mnist-normal/learning.gif')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Morphing instances \n",
"Tuning the noise vector to see a digit morph into another. The process of tuning slowly convert a noise vector filled with zeros to one filled with ones"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x288 with 10 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(15,4))\n",
"\n",
"for i in range(10):\n",
" noise = np.zeros([1,100]) - 1 + (i * 0.2)\n",
" images = net_generator.predict(noise)\n",
" \n",
" image = images[0, :, :, :]\n",
" image = np.reshape(image, [28, 28])\n",
" \n",
" plt.subplot(1, 10, i+1)\n",
" plt.imshow(image, cmap='gray')\n",
" plt.axis('off')\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(r'output/mnist-normal/morph_example.png'.format(i))\n",
"experiment.log_image(r'output/mnist-normal/morph_example.png'.format(i))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Checking Final Results "
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x432 with 40 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x432 with 40 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# The red outlined digits are generated by the adversarial network.\n",
"\n",
"import matplotlib.patches as plot_patch\n",
"\n",
"plt.figure(figsize=(15,6))\n",
"noise = np.random.uniform(-1.0, 1.0, size=[40, 100])\n",
"images_fake = net_generator.predict(noise)\n",
"images_real = x_train[np.random.randint(0, x_train.shape[0], size=40), :, :, :]\n",
"choice_vector = np.random.uniform(0, 1, size=40)\n",
"\n",
"for i in range(40):\n",
" \n",
" if choice_vector[i] > 0.5:\n",
" image = images_fake[i, :, :, :]\n",
" else:\n",
" image = images_real[i]\n",
" image = np.reshape(image, [28, 28])\n",
"\n",
" plt.subplot(4, 10, i+1)\n",
" plt.imshow(image, cmap='gray')\n",
" plt.axis('off')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"plt.figure(figsize=(15,6))\n",
"\n",
"border = np.zeros((28,28,3))\n",
"border[0,:] = [255,0,0]\n",
"border[:,0] = [255,0,0]\n",
"\n",
"for i in range(40):\n",
" \n",
" if choice_vector[i] > 0.5:\n",
" image = images_fake[i, :, :, :]\n",
" else:\n",
" image = images_real[i]\n",
" image = np.reshape(image, [28, 28])\n",
" \n",
" ax = plt.subplot(4, 10, i+1)\n",
" plt.imshow(image, cmap='gray')\n",
" if choice_vector[i] > 0.5:\n",
" ax.add_patch(plot_patch.Rectangle((0,0), 27, 27, edgecolor=\"red\", linewidth=2, fill=False)) \n",
" plt.axis('off')\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(r'output/mnist-normal/final_results_sample.png'.format(i))\n",
"experiment.log_image(r'output/mnist-normal/final_results_sample.png'.format(i))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"COMET INFO: ----------------------------\n",
"COMET INFO: Comet.ml Experiment Summary:\n",
"COMET INFO: Data:\n",
"COMET INFO: url: https://www.comet.ml/ceceshao1/mnist-gan/cf310adacd724bf280323e2eef92d1cd\n",
"COMET INFO: Metrics:\n",
"COMET INFO: acc_adv: 0.35546875\n",
"COMET INFO: acc_dis: 0.7246094\n",
"COMET INFO: loss_adv: 1.024724\n",
"COMET INFO: loss_dis: 0.5381024\n",
"COMET INFO: sys.gpu.0.free_memory: 6392709120\n",
"COMET INFO: sys.gpu.0.gpu_utilization: 0\n",
"COMET INFO: sys.gpu.0.total_memory: 11996954624\n",
"COMET INFO: sys.gpu.0.used_memory: 5604245504\n",
"COMET INFO: Other:\n",
"COMET INFO: dataset_info: [[[[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" ...\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]]\n",
"\n",
"\n",
" [[[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" ...\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]]\n",
"\n",
"\n",
" [[[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" ...\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]]\n",
"\n",
"\n",
" ...\n",
"\n",
"\n",
" [[[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" ...\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]]\n",
"\n",
"\n",
" [[[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" ...\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]]\n",
"\n",
"\n",
" [[[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" ...\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n",
"\n",
" [[0.]\n",
" [0.]\n",
" [0.]\n",
" ...\n",
" [0.]\n",
" [0.]\n",
" [0.]]]]\n",
"COMET INFO: Uploads:\n",
"COMET INFO: assets: 0\n",
"COMET INFO: figures: 0\n",
"COMET INFO: images: 24\n",
"COMET INFO: ----------------------------\n",
"COMET INFO: Uploading stats to Comet before program termination (may take several seconds)\n"
]
}
],
"source": [
"# at the end of your training, end the Comet experiment\n",
"experiment.end()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
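
A quick sanity check on the discriminator summary printed in the notebook above. This sketch is not part of the original notebook. It assumes the usual rule for `padding='same'` (output size is the input size divided by the stride, rounded up) and the standard parameter count for a convolution (kernel height times kernel width times input channels times filters, plus one bias per filter). The helper names `conv_same_size` and `conv_params` are made up for illustration.

```python
import math


def conv_same_size(size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(size / stride)


def conv_params(kernel, in_channels, filters):
    """Weights plus one bias per filter for a square-kernel convolution."""
    return kernel * kernel * in_channels * filters + filters


# (filters, stride) for the four Conv2D layers in discriminator()
conv_layers = [(64, 2), (128, 2), (256, 2), (512, 1)]

size, channels = 28, 1  # one MNIST image: 28x28x1
shapes = []
total = 0
for filters, stride in conv_layers:
    total += conv_params(5, channels, filters)
    size = conv_same_size(size, stride)
    channels = filters
    shapes.append((size, size, channels))

# Flatten followed by Dense(1): one weight per flattened unit plus one bias
flattened = size * size * channels
total += flattened + 1

print(shapes)     # [(14, 14, 64), (7, 7, 128), (4, 4, 256), (4, 4, 512)]
print(flattened)  # 8192
print(total)      # 4311553
```

The shapes and the total agree with the `net_discriminator.summary()` output, including the 7 to 4 step, where the odd input size is rounded up.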

christianayala10 commented Apr 17, 2019

How can the discriminator be trained (in step 4) if its layers have been frozen (in step 3)?
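
A likely explanation, assuming the Keras 2.x behaviour this notebook appears to have run on: a model records which weights are trainable at the moment `compile` is called. `model_discriminator` is compiled before the layers are frozen and `model_adversarial` after, so each model keeps its own view of the shared layers. The `UserWarning` about a "Discrepancy between trainable weights and collected trainable weights" in the training cell's output is consistent with this. The toy classes below are not Keras code; `ToyLayer` and `ToyModel` are hypothetical stand-ins that mimic only this snapshot-at-compile behaviour.

```python
class ToyLayer:
    """Stand-in for a layer: just a name and a trainable flag."""

    def __init__(self, name):
        self.name = name
        self.trainable = True


class ToyModel:
    """Stand-in for a model that snapshots trainable layers at compile time."""

    def __init__(self, layers):
        self.layers = layers
        self.collected = []

    def compile(self):
        # Record the layers that are trainable right now; later changes
        # to layer.trainable do not alter this snapshot.
        self.collected = [layer.name for layer in self.layers if layer.trainable]

    def train_on_batch(self):
        # Return the names of the layers this model would update.
        return list(self.collected)


gen_layers = [ToyLayer("g_dense"), ToyLayer("g_conv")]
dis_layers = [ToyLayer("d_conv"), ToyLayer("d_dense")]

# Step 3, first half: compile the discriminator while its layers are trainable
model_discriminator = ToyModel(dis_layers)
model_discriminator.compile()

# Step 3, second half: freeze the shared layers, then compile the adversarial model
for layer in dis_layers:
    layer.trainable = False
model_adversarial = ToyModel(gen_layers + dis_layers)
model_adversarial.compile()

# Step 4: each model updates only what it collected when it was compiled
updated_by_discriminator = model_discriminator.train_on_batch()
updated_by_adversarial = model_adversarial.train_on_batch()
print(updated_by_discriminator)  # ['d_conv', 'd_dense']
print(updated_by_adversarial)    # ['g_dense', 'g_conv']
```

Under that assumption, the discriminator still learns through `model_discriminator`, while `model_adversarial` updates only the generator.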
