Skip to content

Instantly share code, notes, and snippets.

@Shirataki2
Created June 11, 2018 14:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Shirataki2/86a068ea0a3eff36974ccf9b8ab5a8cf to your computer and use it in GitHub Desktop.
Save Shirataki2/86a068ea0a3eff36974ccf9b8ab5a8cf to your computer and use it in GitHub Desktop.
WGAN-GP-Mnist.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
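"source": "# Wasserstein GAN with Gradient Penalty (WGAN-GP) on MNIST\n\nTrains a convolutional generator against a convolutional critic on all MNIST digits (train and test splits pooled). The critic minimises the Wasserstein loss plus a gradient penalty evaluated on random interpolations between real and generated images, and the generator is updated once per `TRAIN_RATIO` critic updates."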
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:08.839501Z",
"end_time": "2018-06-10T00:30:16.684564Z"
},
"trusted": true
},
"cell_type": "code",
"source": "import keras\nfrom keras.layers import Input, Dense, Reshape, Flatten, Dropout\nfrom keras.layers import Conv2D, Deconv2D, UpSampling2D,Conv2DTranspose\nfrom keras.layers import BatchNormalization, Activation, ZeroPadding2D\nfrom keras.optimizers import RMSprop, Adam\nfrom keras.models import Model, Sequential\nfrom keras.layers.advanced_activations import LeakyReLU\nimport keras.backend as K\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom IPython import display\nimport tqdm\nfrom keras.layers.merge import _Merge\nfrom functools import partial",
"execution_count": 1,
"outputs": [
{
"name": "stderr",
"text": "C:\\Users\\tmy19\\Miniconda3\\envs\\tensorflow\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n from ._conv import register_converters as _register_converters\nUsing TensorFlow backend.\n",
"output_type": "stream"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:16.686494Z",
"end_time": "2018-06-10T00:30:16.691500Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Data: MNIST digits stored as (height, width, channels).\nIMG_SHAPE = (28, 28, 1)\n# Generator widths: units of the first Dense layer, then channels of the reshaped feature map.\nGENERATOR_FIRST_CHANNELS = 512\nGENERATOR_SECOND_CHANNELS = 128\n# Length of the noise vector fed to the generator.\nLATENT_DIM = 100\n# Optimisation settings shared by the generator and the critic.\nBATCH_SIZE = 32\nLR = 2.434e-4\n# Critic (discriminator) updates per generator update.\nTRAIN_RATIO = 5\n# Multiplier on the gradient-penalty term of the critic loss.\nGRADIENT_PENALTY_WEIGHT = 10",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:16.693469Z",
"end_time": "2018-06-10T00:30:16.711422Z"
},
"trusted": true
},
"cell_type": "code",
"source": "def make_generator():\n    \"\"\"Return a Sequential generator mapping LATENT_DIM noise to an IMG_SHAPE image.\n\n    Two Dense layers produce a (side, side, GENERATOR_SECOND_CHANNELS) feature map\n    with side = IMG_SHAPE[0] // 4; two stride-2 transposed convolutions upsample it\n    by 4x, and a final tanh convolution emits pixel values in [-1, 1].\n    \"\"\"\n    side = IMG_SHAPE[0] // 4\n    model = Sequential()\n    model.add(Dense(GENERATOR_FIRST_CHANNELS, input_dim=LATENT_DIM, activation='relu'))\n    model.add(Dense(GENERATOR_SECOND_CHANNELS * side * side))\n    model.add(BatchNormalization())\n    model.add(LeakyReLU())\n    model.add(Reshape((side, side, GENERATOR_SECOND_CHANNELS)))\n    # Every 5x5 conv below is followed by batch norm over channels and a LeakyReLU.\n    conv_specs = [\n        (Conv2DTranspose, {'strides': 2}),  # side -> 2 * side\n        (Conv2D, {}),\n        (Conv2DTranspose, {'strides': 2}),  # 2 * side -> 4 * side\n    ]\n    for conv_cls, extra_kwargs in conv_specs:\n        model.add(conv_cls(128, (5, 5), padding='same', **extra_kwargs))\n        model.add(BatchNormalization(axis=-1))\n        model.add(LeakyReLU())\n    model.add(Conv2D(IMG_SHAPE[2], (5, 5), padding='same', activation='tanh'))\n    return model",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:16.714416Z",
"end_time": "2018-06-10T00:30:18.570448Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Instantiate the generator and print its layer-by-layer summary.\ngenerator = make_generator()\ngenerator.summary()",
"execution_count": 4,
"outputs": [
{
"name": "stdout",
"text": "_________________________________________________________________\nLayer (type) Output Shape Param # \n=================================================================\ndense_1 (Dense) (None, 512) 51712 \n_________________________________________________________________\ndense_2 (Dense) (None, 6272) 3217536 \n_________________________________________________________________\nbatch_normalization_1 (Batch (None, 6272) 25088 \n_________________________________________________________________\nleaky_re_lu_1 (LeakyReLU) (None, 6272) 0 \n_________________________________________________________________\nreshape_1 (Reshape) (None, 7, 7, 128) 0 \n_________________________________________________________________\nconv2d_transpose_1 (Conv2DTr (None, 14, 14, 128) 409728 \n_________________________________________________________________\nbatch_normalization_2 (Batch (None, 14, 14, 128) 512 \n_________________________________________________________________\nleaky_re_lu_2 (LeakyReLU) (None, 14, 14, 128) 0 \n_________________________________________________________________\nconv2d_1 (Conv2D) (None, 14, 14, 128) 409728 \n_________________________________________________________________\nbatch_normalization_3 (Batch (None, 14, 14, 128) 512 \n_________________________________________________________________\nleaky_re_lu_3 (LeakyReLU) (None, 14, 14, 128) 0 \n_________________________________________________________________\nconv2d_transpose_2 (Conv2DTr (None, 28, 28, 128) 409728 \n_________________________________________________________________\nbatch_normalization_4 (Batch (None, 28, 28, 128) 512 \n_________________________________________________________________\nleaky_re_lu_4 (LeakyReLU) (None, 28, 28, 128) 0 \n_________________________________________________________________\nconv2d_2 (Conv2D) (None, 28, 28, 1) 3201 \n=================================================================\nTotal params: 4,528,257\nTrainable params: 4,514,945\nNon-trainable params: 
13,312\n_________________________________________________________________\n",
"output_type": "stream"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:18.572444Z",
"end_time": "2018-06-10T00:30:18.578429Z"
},
"trusted": true
},
"cell_type": "code",
"source": "def make_discriminator():\n    \"\"\"Return the critic: a Sequential model giving each IMG_SHAPE image one score.\n\n    The final Dense layer has no activation, so scores are unbounded rather than\n    squashed into probabilities.\n    \"\"\"\n    model = Sequential()\n    model.add(Conv2D(64, (5, 5), padding='same', input_shape=IMG_SHAPE))\n    model.add(LeakyReLU())\n    # Two stride-2 convs downsample; the first keeps Conv2D's default 'valid' padding.\n    for padding in ('valid', 'same'):\n        model.add(Conv2D(64, (5, 5), kernel_initializer='he_normal',\n                         padding=padding, strides=[2, 2]))\n        model.add(LeakyReLU())\n    model.add(Flatten())\n    model.add(Dense(1024, kernel_initializer='he_normal'))\n    model.add(LeakyReLU())\n    model.add(Dense(1, kernel_initializer='he_normal'))\n    return model",
"execution_count": 5,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:18.581419Z",
"end_time": "2018-06-10T00:30:18.738004Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Instantiate the critic and print its layer-by-layer summary.\ndiscriminator = make_discriminator()\ndiscriminator.summary()",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"text": "_________________________________________________________________\nLayer (type) Output Shape Param # \n=================================================================\nconv2d_3 (Conv2D) (None, 28, 28, 64) 1664 \n_________________________________________________________________\nleaky_re_lu_5 (LeakyReLU) (None, 28, 28, 64) 0 \n_________________________________________________________________\nconv2d_4 (Conv2D) (None, 12, 12, 64) 102464 \n_________________________________________________________________\nleaky_re_lu_6 (LeakyReLU) (None, 12, 12, 64) 0 \n_________________________________________________________________\nconv2d_5 (Conv2D) (None, 6, 6, 64) 102464 \n_________________________________________________________________\nleaky_re_lu_7 (LeakyReLU) (None, 6, 6, 64) 0 \n_________________________________________________________________\nflatten_1 (Flatten) (None, 2304) 0 \n_________________________________________________________________\ndense_3 (Dense) (None, 1024) 2360320 \n_________________________________________________________________\nleaky_re_lu_8 (LeakyReLU) (None, 1024) 0 \n_________________________________________________________________\ndense_4 (Dense) (None, 1) 1025 \n=================================================================\nTotal params: 2,567,937\nTrainable params: 2,567,937\nNon-trainable params: 0\n_________________________________________________________________\n",
"output_type": "stream"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:18.739997Z",
"end_time": "2018-06-10T00:30:19.118019Z"
},
"trusted": true
},
"cell_type": "code",
"source": "(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()\n# Labels and the test split are not used later, so pool all images for training\n# and add the trailing channel axis from IMG_SHAPE.\nX_train = np.concatenate([X_train, X_test]).reshape((-1,) + IMG_SHAPE)",
"execution_count": 7,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:19.119978Z",
"end_time": "2018-06-10T00:30:19.492044Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Rescale pixels from [0, 255] to [-1, 1], the range of the generator's tanh output.\nPIXEL_HALF_RANGE = 127.5\nX_train = (X_train.astype(np.float32) - PIXEL_HALF_RANGE) / PIXEL_HALF_RANGE",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:19.500961Z",
"end_time": "2018-06-10T00:30:19.551824Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Sanity check: the rescaled pixels should span exactly [-1, 1].\nprint(X_train.min(), X_train.max())",
"execution_count": 10,
"outputs": [
{
"name": "stdout",
"text": "-1.0 1.0\n",
"output_type": "stream"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:19.553844Z",
"end_time": "2018-06-10T00:30:19.557829Z"
},
"trusted": true
},
"cell_type": "code",
"source": "def wasserstein_loss(y_true, y_pred):\n    \"\"\"Mean of label * score over the batch.\n\n    With labels of +1 or -1 this is a signed mean critic score, so minimising it\n    pushes scores down for +1-labelled samples and up for -1-labelled ones.\n    \"\"\"\n    return K.mean(y_pred * y_true)",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:19.493978Z",
"end_time": "2018-06-10T00:30:19.498965Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Freeze the critic so generator_model, compiled next, only updates generator weights.\ndiscriminator.trainable = False\nfor layer in discriminator.layers:\n    layer.trainable = False",
"execution_count": 9,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:19.560800Z",
"end_time": "2018-06-10T00:30:19.984667Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Generator training graph: noise -> generator -> frozen critic -> score.\ngenerator_input = Input(shape=(LATENT_DIM,))\ngenerator_layers = generator(generator_input)\ndiscriminator_layers_for_generator = discriminator(generator_layers)\ngenerator_model = Model(inputs=[generator_input], outputs=[discriminator_layers_for_generator])\ngenerator_model.summary()\ngenerator_optimizer = Adam(LR, beta_1=0.5, beta_2=0.9)\ngenerator_model.compile(optimizer=generator_optimizer, loss=wasserstein_loss)",
"execution_count": 12,
"outputs": [
{
"name": "stdout",
"text": "_________________________________________________________________\nLayer (type) Output Shape Param # \n=================================================================\ninput_1 (InputLayer) (None, 100) 0 \n_________________________________________________________________\nsequential_1 (Sequential) (None, 28, 28, 1) 4528257 \n_________________________________________________________________\nsequential_2 (Sequential) (None, 1) 2567937 \n=================================================================\nTotal params: 7,096,194\nTrainable params: 4,514,945\nNon-trainable params: 2,581,249\n_________________________________________________________________\n",
"output_type": "stream"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:19.986662Z",
"end_time": "2018-06-10T00:30:19.991647Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Switch roles before building discriminator_model: critic trainable, generator frozen.\nfor network, is_trainable in ((discriminator, True), (generator, False)):\n    for layer in network.layers:\n        layer.trainable = is_trainable\n    network.trainable = is_trainable",
"execution_count": 13,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:19.993643Z",
"end_time": "2018-06-10T00:30:20.007604Z"
},
"trusted": true
},
"cell_type": "code",
"source": "def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):\n    \"\"\"Gradient penalty: mean over the batch of weight * (1 - ||dy_pred/dx||_2)^2.\n\n    Args:\n        y_true: unused; kept so the function fits the (y_true, y_pred) loss signature.\n        y_pred: critic scores for averaged_samples.\n        averaged_samples: tensor the gradients are taken with respect to.\n        gradient_penalty_weight: multiplier on the squared deviation of the norm from 1.\n    \"\"\"\n    grads = K.gradients(y_pred, averaged_samples)[0]\n    # Per-sample L2 norm: reduce over every axis except the leading batch axis.\n    non_batch_axes = np.arange(1, len(grads.shape))\n    grad_norms = K.sqrt(K.sum(K.square(grads), axis=non_batch_axes))\n    return K.mean(gradient_penalty_weight * K.square(1 - grad_norms))",
"execution_count": 14,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:20.009599Z",
"end_time": "2018-06-10T00:30:20.020596Z"
},
"trusted": true
},
"cell_type": "code",
"source": "class RandomWeightAverage(_Merge):\n    \"\"\"Merge layer returning a random per-sample interpolation of its two inputs.\n\n    For each sample a weight w ~ U(0, 1) is drawn and the output is\n    w * inputs[0] + (1 - w) * inputs[1].\n    \"\"\"\n\n    def _merge_function(self, inputs):\n        first, second = inputs[0], inputs[1]\n        # Read the batch size from the tensor instead of assuming BATCH_SIZE, so any\n        # batch size works; the weight broadcasts over all non-batch axes.\n        weight_shape = [K.shape(first)[0]] + [1] * (K.ndim(first) - 1)\n        weights = K.random_uniform(weight_shape)\n        return (weights * first) + ((1 - weights) * second)",
"execution_count": 15,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:31:40.693365Z",
"end_time": "2018-06-10T00:31:40.699348Z"
},
"trusted": true
},
"cell_type": "code",
"source": "def generate_images(generator_model, n, m, out_dir='wgan_images2'):\n    \"\"\"Sample a 4x5 grid of generator outputs, save it as a PNG and display it.\n\n    Args:\n        generator_model: model mapping (batch, LATENT_DIM) noise to images.\n        n: epoch number, used in the title and the file name.\n        m: batch number, used in the title and the file name.\n        out_dir: directory the PNG is written to; created if it does not exist.\n    \"\"\"\n    import os\n    nrows, ncols = 4, 5\n    fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(10, 8))\n    fig.suptitle('Epoch %4d, Batch %4d' % (n, m), fontsize=20)\n    # Uniform [0, 1) noise, the same distribution the training loop feeds the generator.\n    noise = np.random.rand(nrows * ncols, LATENT_DIM)\n    images = generator_model.predict(noise)\n    # Undo the [-1, 1] tanh scaling so imshow receives values in [0, 1].\n    images = np.clip(images * 0.5 + 0.5, 0.0, 1.0)\n    if images.shape[3] == 1:\n        # Drop the channel axis so single-channel images are drawn as 2-D arrays.\n        images = images.reshape((-1, IMG_SHAPE[0], IMG_SHAPE[1]))\n    for row in range(nrows):\n        for col in range(ncols):\n            ax[row, col].imshow(images[row * ncols + col], 'gray')\n    # savefig fails if the target directory is missing, so create it first.\n    os.makedirs(out_dir, exist_ok=True)\n    fig.savefig(os.path.join(out_dir, '%04d-%04d.png' % (n, m)))\n    plt.show()\n    # Release the figure; this function is called repeatedly during training.\n    plt.close(fig)",
"execution_count": 24,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:20.040519Z",
"end_time": "2018-06-10T00:30:20.447499Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Critic training graph, part 1: score a real batch and a freshly generated batch.\nreal_samples = Input(shape=X_train.shape[1:])\ngenerator_input_for_discriminator = Input(shape=(LATENT_DIM,))\ngenerated_samples_for_discriminator = generator(generator_input_for_discriminator)\ndiscriminator_output_from_generator = discriminator(generated_samples_for_discriminator)\ndiscriminator_output_from_real_samples = discriminator(real_samples)",
"execution_count": 17,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:20.449436Z",
"end_time": "2018-06-10T00:30:20.491376Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Critic training graph, part 2: score random interpolations between the real and\n# generated batches; the gradient penalty is computed on these scores.\naveraged_samples = RandomWeightAverage()([real_samples, generated_samples_for_discriminator])\naveraged_samples_out = discriminator(averaged_samples)",
"execution_count": 18,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:20.494308Z",
"end_time": "2018-06-10T00:30:20.497311Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Losses are called as f(y_true, y_pred), so bind the extra arguments up front.\npartial_gp_loss = partial(\n    gradient_penalty_loss,\n    averaged_samples=averaged_samples,\n    gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT\n)\n# functools.partial objects have no __name__ of their own; give this one a name,\n# presumably so Keras can label the loss.\npartial_gp_loss.__name__ = 'gradient_penalty'",
"execution_count": 19,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:20.499292Z",
"end_time": "2018-06-10T00:30:20.544170Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Critic training model: (real images, noise) -> scores for the real batch,\n# the generated batch and the interpolated batch, in that order.\ncritic_inputs = [real_samples, generator_input_for_discriminator]\ncritic_outputs = [discriminator_output_from_real_samples,\n                  discriminator_output_from_generator,\n                  averaged_samples_out]\ndiscriminator_model = Model(critic_inputs, critic_outputs)\ndiscriminator_model.summary()",
"execution_count": 20,
"outputs": [
{
"name": "stdout",
"text": "__________________________________________________________________________________________________\nLayer (type) Output Shape Param # Connected to \n==================================================================================================\ninput_3 (InputLayer) (None, 100) 0 \n__________________________________________________________________________________________________\ninput_2 (InputLayer) (None, 28, 28, 1) 0 \n__________________________________________________________________________________________________\nsequential_1 (Sequential) (None, 28, 28, 1) 4528257 input_3[0][0] \n__________________________________________________________________________________________________\nrandom_weight_average_1 (Random (None, 28, 28, 1) 0 input_2[0][0] \n sequential_1[2][0] \n__________________________________________________________________________________________________\nsequential_2 (Sequential) (None, 1) 2567937 sequential_1[2][0] \n input_2[0][0] \n random_weight_average_1[0][0] \n==================================================================================================\nTotal params: 7,096,194\nTrainable params: 2,567,937\nNon-trainable params: 4,528,257\n__________________________________________________________________________________________________\n",
"output_type": "stream"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:20.546179Z",
"end_time": "2018-06-10T00:30:20.757611Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# One loss per discriminator_model output: Wasserstein terms for the real and\n# generated scores, then the gradient penalty on the interpolated scores.\ncritic_losses = [wasserstein_loss, wasserstein_loss, partial_gp_loss]\ndiscriminator_model.compile(optimizer=Adam(LR, beta_1=0.5, beta_2=0.9),\n                            loss=critic_losses)",
"execution_count": 21,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:30:20.759611Z",
"end_time": "2018-06-10T00:30:20.764581Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Targets for wasserstein_loss (mean of label * score). In the critic update,\n# positive_y goes with real scores and negative_y with generated scores; the\n# generator update also uses positive_y. The gradient-penalty loss ignores its\n# target, so zeros suffice for dummy_y.\npositive_y = np.ones((BATCH_SIZE, 1), dtype=np.float32)\nnegative_y = np.full_like(positive_y, -1.0)\ndummy_y = np.zeros_like(positive_y)",
"execution_count": 22,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-06-10T00:31:46.194515Z",
"end_time": "2018-06-10T03:03:06.650184Z"
},
"trusted": true
},
"cell_type": "code",
"source": "num = 50  # epochs to run in this call\noffset = 0  # first epoch index (lets a later run continue the numbering)\nminibatch_size = BATCH_SIZE * TRAIN_RATIO  # images consumed per generator step\ni_len = X_train.shape[0] // minibatch_size  # full generator steps per epoch\nPREVIEW_EVERY = 36  # steps between output refreshes / saved sample grids\n\ngenerate_images(generator, 0, 0)\nfor epoch in range(offset, offset + num):\n    np.random.shuffle(X_train)\n    discriminator_loss = []\n    generator_loss = []\n\n    for i in range(i_len):\n        if i % PREVIEW_EVERY == 0:\n            display.clear_output(wait=True)\n            if i > 0:\n                generate_images(generator, epoch, i)\n            print(\"{0}/{1} {2:.2f} % \".format(epoch, num + offset, 100 * i / i_len), end='')\n        print('{0}'.format(i % PREVIEW_EVERY + 1), end='')\n        discriminator_minibatches = X_train[i * minibatch_size:(i + 1) * minibatch_size]\n        # TRAIN_RATIO critic updates, each on its own BATCH_SIZE slice of real images...\n        for j in range(TRAIN_RATIO):\n            print('.', end='')\n            image_batch = discriminator_minibatches[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]\n            noise = np.random.rand(BATCH_SIZE, LATENT_DIM).astype(np.float32)\n            discriminator_loss.append(discriminator_model.train_on_batch(\n                [image_batch, noise], [positive_y, negative_y, dummy_y]))\n        # ...then one generator update, with noise drawn the same way as for the critic.\n        generator_noise = np.random.rand(BATCH_SIZE, LATENT_DIM).astype(np.float32)\n        generator_loss.append(generator_model.train_on_batch(generator_noise, positive_y))\ndisplay.clear_output(wait=True)\ngenerate_images(generator, num + offset, i_len)\nprint(\"Finished {0}/{1} {2:.2f} % \".format(num + offset, num + offset, 100))",
"execution_count": 25,
"outputs": [
{
"metadata": {},
"data": {
"image/png": "\n",
"text/plain": "<Figure size 720x576 with 20 Axes>"
},
"output_type": "display_data"
},
{
"name": "stdout",
"text": "Finished 50/50 100.00 % \n",
"output_type": "stream"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"varInspector": {
"window_display": false,
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"library": "var_list.py",
"delete_cmd_prefix": "del ",
"delete_cmd_postfix": "",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"library": "var_list.r",
"delete_cmd_prefix": "rm(",
"delete_cmd_postfix": ") ",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
]
},
"kernelspec": {
"name": "conda-env-tensorflow-py",
"display_name": "Python [conda env:tensorflow]",
"language": "python"
},
"language_info": {
"pygments_lexer": "ipython3",
"name": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"nbconvert_exporter": "python",
"version": "3.5.5",
"file_extension": ".py",
"mimetype": "text/x-python"
},
"gist": {
"id": "",
"data": {
"description": "WGAN-GP-Mnist.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment