@zaccharieramzi
Last active June 13, 2020 18:09
MixedBatchGAN
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "MixedBatchGAN",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOyqoaz64wZO5KaB/POLxnl",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/zaccharieramzi/8b3cf272dc6f89ec0f45ba3bbed3b11a/mixedbatchgan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B2cPftqYbQNF",
"colab_type": "text"
},
"source": [
"# Mixed batches and symmetric discriminators for GAN training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E3U62YBKbYdP",
"colab_type": "text"
},
"source": [
"This Colaboratory notebook reproduces one part of the ICML 2018 oral \"Mixed batches and symmetric discriminators for GAN training\" by Lucas et al.\n",
"It is an unofficial implementation in TensorFlow.\n",
"\n",
"The basic idea builds on the fact that a GAN learns a distribution: rather than looking at a single sample of the learned distribution (i.e. one image), the discriminator has access to a whole batch of samples.\n",
"\n",
"The key tricks are:\n",
"- to mix this batch of generated samples with real images, and have the discriminator predict the proportion of real images in the batch;\n",
"- to build permutation invariance over the batch into the discriminator architecture, so that the network does not have to learn this property from data."
]
},
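{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"As a reading aid, the mixing performed in `train_step` can be sketched as follows (the notation is an assumption of this note, not necessarily the paper's). With real samples $x_i$, latent draws $z_i$, a generator $G$ and a batch of size $B$:\n",
"\n",
"$$\\tilde{x}_i = \\beta_i x_i + (1 - \\beta_i)\\, G(z_i), \\qquad \\beta_i \\in \\{0, 1\\}, \\qquad r = \\frac{1}{B} \\sum_{i=1}^{B} \\beta_i$$\n",
"\n",
"The discriminator sees the mixed batch $(\\tilde{x}_1, \\dots, \\tilde{x}_B)$ and is trained to predict the ratio $r$ of real samples."
]
},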
{
"cell_type": "code",
"metadata": {
"id": "0fS31XnZKQT0",
"colab_type": "code",
"cellView": "form",
"colab": {}
},
"source": [
"#@title\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow.keras.layers import Layer, Dense, Activation\n",
"from tensorflow.keras.models import Model, Sequential\n",
"from tensorflow.keras.optimizers import Adam"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ebzpf6u8KrDz",
"colab_type": "text"
},
"source": [
"# DATA"
]
},
{
"cell_type": "code",
"metadata": {
"id": "-Tjy5jorKtVh",
"colab_type": "code",
"colab": {}
},
"source": [
"def square_centers(square_length):\n",
"    # Grid of mode centers, centered on 0 and scaled by sqrt(square_length)\n",
"    square_side = np.arange(square_length)\n",
"    square_side = square_side - (square_length - 1) / 2\n",
"    square_side = square_side / np.sqrt(square_length)\n",
"    centers = np.meshgrid(square_side, square_side)\n",
"    return centers\n",
"\n",
"def generate_sample_from_square(square_length, std):\n",
"    # Infinite generator: pick a grid center at random, then add Gaussian noise\n",
"    square_side = np.arange(square_length)\n",
"    square_side = square_side - (square_length - 1) / 2\n",
"    square_side = square_side / np.sqrt(square_length)\n",
"    while True:\n",
"        position = np.random.choice(square_side, size=2)\n",
"        point = np.random.normal(loc=position, scale=std)\n",
"        yield point"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "A2Nl2irqL4zy",
"colab_type": "code",
"colab": {}
},
"source": [
"points_ds = tf.data.Dataset.from_generator(\n",
"    generate_sample_from_square,\n",
"    tf.float32,\n",
"    tf.TensorShape([2]),\n",
"    args=(5, 0.01),\n",
").batch(128)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "eISiQtTLL6hO",
"colab_type": "code",
"colab": {}
},
"source": [
"def generate_latent_variable(var_dim, batch_size=1, std=1.):\n",
"    return tf.random.normal((batch_size, var_dim), stddev=std)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "_wghRJ8UZVr3",
"colab_type": "text"
},
"source": [
"## Visualization"
]
},
{
"cell_type": "code",
"metadata": {
"id": "FV_u-SSJZlr7",
"colab_type": "code",
"colab": {}
},
"source": [
"data = next(points_ds.as_numpy_iterator())\n",
"centers = square_centers(5)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "URT6O3xKZYyJ",
"colab_type": "code",
"outputId": "1771f660-30a7-493d-a99c-a7d02443ebf4",
"cellView": "form",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 283
}
},
"source": [
"#@title\n",
"plt.figure()\n",
"plt.scatter(data[:, 0], data[:, 1])\n",
"plt.scatter(centers[0], centers[1])"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7f10c019a2e8>"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LvH6uSsXOI16",
"colab_type": "text"
},
"source": [
"# MODELS"
]
},
{
"cell_type": "code",
"metadata": {
"id": "nE7zf4VXXS-o",
"colab_type": "code",
"colab": {}
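},
"source": [
"def kl_loss(ratio_true, ratio_pred):\n",
"    # Sum of the two terms of the KL divergence between Bernoulli distributions\n",
"    # with parameters ratio_true and ratio_pred\n",
"    kl = tf.keras.losses.KLDivergence()\n",
"    kl1 = kl(ratio_true, ratio_pred)\n",
"    kl2 = kl(1 - ratio_true, 1 - ratio_pred)\n",
"    return kl1 + kl2"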
],
"execution_count": 0,
"outputs": []
},
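{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Read this way, `kl_loss` corresponds to the KL divergence between two Bernoulli distributions, one with the true ratio $r$ as parameter and one with the predicted ratio $\\hat{r}$. The symbols are introduced here for illustration, and the expression holds up to whatever numerical clipping Keras applies inside `KLDivergence`:\n",
"\n",
"$$d(r, \\hat{r}) = r \\log \\frac{r}{\\hat{r}} + (1 - r) \\log \\frac{1 - r}{1 - \\hat{r}}$$\n",
"\n",
"The first term is `kl1` and the second is `kl2`; the sum vanishes only when $\\hat{r} = r$."
]
},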
{
"cell_type": "code",
"metadata": {
"id": "NFH-7uM4Oy-X",
"colab_type": "code",
"colab": {}
},
"source": [
"class PermEquivLayer(Layer):\n",
"    def __init__(self, units=512, activation='relu', **kwargs):\n",
"        super(PermEquivLayer, self).__init__(**kwargs)\n",
"        self.activation = Activation(activation)\n",
"        self.parallel_dense = Dense(units)\n",
"        self.batch_dense = Dense(units, use_bias=False)\n",
"\n",
"    def call(self, inputs):\n",
"        # per-sample branch\n",
"        parallel_outputs = self.parallel_dense(inputs)\n",
"        # batch branch: mean over the batch axis, shared by all samples\n",
"        rho = tf.reduce_mean(inputs, axis=0, keepdims=True)\n",
"        batch_outputs = self.batch_dense(rho)\n",
"        outputs = self.activation(parallel_outputs + batch_outputs)\n",
"        return outputs\n",
"\n",
"class PermEquivDiscriminator(Model):\n",
"    def __init__(self, n_layers=3, units=512, activation='relu', **kwargs):\n",
"        super(PermEquivDiscriminator, self).__init__(**kwargs)\n",
"        self.perm_equiv_layers = [\n",
"            PermEquivLayer(units=units, activation=activation)\n",
"            for _ in range(n_layers)\n",
"        ]\n",
"        self.perm_equiv_layers.append(PermEquivLayer(units=1, activation='sigmoid'))\n",
"\n",
"    def call(self, inputs):\n",
"        outputs = inputs\n",
"        for perm_equiv_layer in self.perm_equiv_layers:\n",
"            outputs = perm_equiv_layer(outputs)\n",
"        # averaging over the batch makes the prediction permutation invariant\n",
"        outputs = tf.reduce_mean(outputs, axis=0)\n",
"        return outputs\n",
"\n",
"\n",
"class BGAN(Model):\n",
"    # inspired by https://keras.io/examples/generative/dcgan_overriding_train_step/\n",
"    def __init__(self, gamma=0.3, units=512, **kwargs):\n",
"        super(BGAN, self).__init__(**kwargs)\n",
"        self.discriminator = PermEquivDiscriminator()\n",
"        self.generator = Sequential([Dense(units, activation='relu') for _ in range(3)])\n",
"        self.generator.add(Dense(2))\n",
"        self.gamma = gamma\n",
"        self.i_iter = tf.Variable(0)\n",
"        self.gen_train_prop = 5\n",
"        self.fix_disc = False\n",
"\n",
"    def call(self, inputs):\n",
"        return self.generator(inputs)\n",
"\n",
"    def compile(self, d_optimizer, g_optimizer):\n",
"        super(BGAN, self).compile()\n",
"        self.d_optimizer = d_optimizer\n",
"        self.g_optimizer = g_optimizer\n",
"\n",
"    def train_step(self, data):\n",
"        # data is the real samples\n",
"        batch_size = tf.shape(data)[0]\n",
"        random_latent_vectors = generate_latent_variable(2, batch_size)\n",
"        # draw p and beta\n",
"        p_inf = tf.random.uniform((1,), maxval=self.gamma)\n",
"        p_sup = tf.random.uniform((1,), minval=1 - self.gamma, maxval=1.)\n",
"        # the predicate is drawn as a scalar, which is what tf.cond expects\n",
"        which_p = tf.greater(tf.random.uniform((), minval=0, maxval=1.), 0.5)\n",
"        p = tf.cond(which_p, lambda: p_sup, lambda: p_inf)\n",
"        # beta is 1 for a real sample (probability 1 - p) and 0 for a generated one\n",
"        beta = tf.cast(tf.greater(tf.random.uniform([batch_size], minval=0, maxval=1.), p), tf.float32)\n",
"        # ratio is the number of true samples over the number of samples\n",
"        ratio_true = tf.reduce_sum(beta) / tf.cast(batch_size, tf.float32)\n",
"        with tf.GradientTape(persistent=True) as tape:\n",
"            # Decode to fake samples\n",
"            generated_samples = self.generator(random_latent_vectors)\n",
"            # Combine with real samples\n",
"            combined_samples = beta[:, None] * data + (1 - beta[:, None]) * generated_samples\n",
"            ratio_pred = self.discriminator(combined_samples)\n",
"            tf.debugging.check_numerics(ratio_pred, 'ratio_pred')\n",
"            d_loss = kl_loss(ratio_true[None], ratio_pred)\n",
"            # if ratio_pred is higher, then the discriminator thinks\n",
"            # that there is a higher proportion of true samples.\n",
"            # Therefore the generator must maximize an increasing function of ratio_pred,\n",
"            # i.e. minimize a decreasing function of ratio_pred\n",
"            g_loss = - tf.math.log(ratio_pred)\n",
"        tf.debugging.check_numerics(d_loss, 'd_loss')\n",
"        tf.debugging.check_numerics(g_loss, 'g_loss')\n",
"        if not self.fix_disc:\n",
"            grads = tape.gradient(d_loss, self.discriminator.trainable_weights)\n",
"            self.d_optimizer.apply_gradients(\n",
"                zip(grads, self.discriminator.trainable_weights)\n",
"            )\n",
"        # we train the generator only one fifth of the time\n",
"        if self.i_iter % self.gen_train_prop == 0 or self.fix_disc:\n",
"            grads = tape.gradient(g_loss, self.generator.trainable_weights)\n",
"            self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))\n",
"        self.i_iter.assign_add(1)\n",
"        return {\"d_loss\": d_loss, \"g_loss\": g_loss}\n"
],
"execution_count": 0,
"outputs": []
},
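{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"For reference, one `PermEquivLayer` as coded above can be written in the following form, where $W$ and $b$ stand for the weights and bias of `parallel_dense`, $V$ for the weights of `batch_dense`, and $\\phi$ for the activation (these symbols are chosen here for illustration):\n",
"\n",
"$$y_i = \\phi\\left(W x_i + b + V \\bar{x}\\right), \\qquad \\bar{x} = \\frac{1}{B} \\sum_{j=1}^{B} x_j$$\n",
"\n",
"Since $\\bar{x}$ does not depend on the order of the samples, permuting the batch permutes the outputs $y_i$ in the same way (equivariance). The final mean over the batch in `PermEquivDiscriminator` then makes the predicted ratio invariant to that order."
]
},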
{
"cell_type": "markdown",
"metadata": {
"id": "K1ROWNwZV7_t",
"colab_type": "text"
},
"source": [
"# TRAINING"
]
},
{
"cell_type": "code",
"metadata": {
"id": "cH6YbAiMaMVN",
"colab_type": "code",
"colab": {}
},
"source": [
"model = BGAN()\n",
"model.compile(Adam(1e-3), Adam(1e-3))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9VkwWESvaZP9",
"colab_type": "code",
"colab": {}
},
"source": [
"history = model.fit(points_ds, epochs=10000, steps_per_epoch=5, verbose=0)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dwf1NrubcWGR",
"colab_type": "code",
"cellView": "form",
"outputId": "b5eede6b-7a6b-4f7f-cc82-63936d330193",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
}
},
"source": [
"#@title\n",
"fig, axs = plt.subplots(1, 2)\n",
"axs[0].plot(history.history['d_loss'], label='d')\n",
"axs[0].set_title('Discriminator loss')\n",
"axs[1].plot(history.history['g_loss'], label='g')\n",
"axs[1].set_title('Generator loss');"
],
"execution_count": 11,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MWhUMOPsiskX",
"colab_type": "text"
},
"source": [
"# EVALUATE"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kwgyusrliuo0",
"colab_type": "code",
"colab": {}
},
"source": [
"random_latent_vectors = generate_latent_variable(2, 1024)\n",
"generated_samples = model(random_latent_vectors)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Ysfrr6-Swjro",
"colab_type": "code",
"outputId": "8b57b2f1-29ce-4370-c4bc-a30f9fc3d850",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"generated_samples.shape"
],
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorShape([1024, 2])"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "4R8nANK6wf2F",
"colab_type": "code",
"cellView": "form",
"outputId": "dfb18dd2-ae66-4817-bb18-0621f8bcf107",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 299
}
},
"source": [
"#@title\n",
"plt.figure()\n",
"plt.scatter(generated_samples[:, 0], generated_samples[:, 1])\n",
"plt.scatter(centers[0], centers[1])\n",
"plt.title('Generated samples')"
],
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Generated samples')"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
}
]
}