@cjauvin — Created September 17, 2018
{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"tf.reset_default_graph()\n",
"\n",
"M = 10 # minibatch size\n",
"K = 10 # rows: number of z components\n",
"N = 10 # columns: number of possible discrete states for a z component \n",
"\n",
"X_a = tf.placeholder(tf.float32, [M, 14 * 28]) # half image\n",
"hidden1_a = tf.contrib.layers.fully_connected(X_a, 300, activation_fn=tf.nn.relu)\n",
"hidden2_a = tf.contrib.layers.fully_connected(hidden1_a, 100, activation_fn=tf.nn.relu)\n",
"logits_a = tf.reshape(hidden2_a, [M, K, N])\n",
"\n",
"X_b = tf.placeholder(tf.float32, [M, 14 * 28]) # half image\n",
"hidden1_b = tf.contrib.layers.fully_connected(X_b, 300, activation_fn=tf.nn.relu)\n",
"hidden2_b = tf.contrib.layers.fully_connected(hidden1_b, 100, activation_fn=tf.nn.relu)\n",
"logits_b = tf.reshape(hidden2_b, [M, K, N])\n",
"\n",
"#logits_a = tf.Variable(tf.random_uniform(shape=(M, K, N)))\n",
"#logits_b = tf.Variable(tf.random_uniform(shape=(M, K, N)))\n",
"logits_p = tf.Variable(tf.random_uniform(shape=(K, N)))\n",
"\n",
"# These can be precomputed\n",
"lse_a = tf.nn.log_softmax(logits_a, axis=1)\n",
"lse_b = tf.nn.log_softmax(logits_b, axis=1)\n",
"lse_p = tf.nn.log_softmax(logits_p, axis=1)\n",
"\n",
"def agreement(i, j):\n",
" a = logits_a[i] - lse_a[i] + logits_b[j] - lse_b[j] - logits_p + lse_p # (K, N) matrix\n",
" # LSE-reduce rows first, and then reduce the resulting k-size vector to a scalar\n",
" return tf.reduce_sum(tf.reduce_logsumexp(a, axis=1))\n",
"\n",
"# First term of the Holy Loss: \\sum_n A(La, Lb), for all diag pairs (there are M)\n",
"\n",
"c = lambda i, a: i < M\n",
"b = lambda i, a: [i + 1, a + agreement(i, i)] # accumulate agreement values\n",
"_, holy_loss_first_term = tf.while_loop(c, b, [0, 0.])\n",
"holy_loss_first_term = -holy_loss_first_term / M\n",
"\n",
"# Second term of the Holy Loss: LSE A(La, Lb) over all non-diag pairs (there are M / (M - 1)) \n",
"\n",
"def inner_while_loop(i, r):\n",
" c = lambda j, ea: j < M\n",
" # accumulate exp(agreement) values (if non-diag)\n",
" b = lambda j, ea: [j + 1, ea + tf.cond(tf.not_equal(i, j), lambda: tf.exp(agreement(i, j)), lambda: 0.)]\n",
" _, t = tf.while_loop(c, b, [0, r])\n",
" return i + 1, t\n",
"\n",
"c = lambda i, ea: i < M\n",
"_, holy_loss_second_term = tf.while_loop(c, inner_while_loop, [0, 0.])\n",
"holy_loss_second_term = tf.log(holy_loss_second_term) - tf.log(M / (M - 1))\n",
"\n",
"# The full Holy Loss\n",
"holy_loss = holy_loss_first_term + holy_loss_second_term\n",
"\n",
"#train_op = tf.train.AdamOptimizer().minimize(holy_loss)\n",
"train_op = tf.train.GradientDescentOptimizer(0.01).minimize(holy_loss)"
]
},
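{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the loss implemented above, as read off the code (here $p_a^{(i)}$, $p_b^{(j)}$ and $p_{\\text{prior}}$ are the softmax distributions over the $N$ states of each of the $K$ components):\n",
"\n",
"$$A(i, j) = \\sum_{k=1}^{K} \\log \\sum_{n=1}^{N} \\frac{p_a^{(i)}(k, n)\\, p_b^{(j)}(k, n)}{p_{\\text{prior}}(k, n)}$$\n",
"\n",
"$$\\mathcal{L} = -\\frac{1}{M} \\sum_{i=1}^{M} A(i, i) \\;+\\; \\log \\frac{1}{M(M-1)} \\sum_{i \\neq j} e^{A(i, j)}$$"
]
},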
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0xb39d43780>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUgAAAD8CAYAAAAVOD3kAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADbxJREFUeJzt3W+IXYWZx/HfL9pASFWMbqaDDRrWuG4oMYUxrAQ0JTakUoh9oVSwBFYyReqLwr4w6IsWZEHB7dIXZTW1w6SY2hbiYF7UtiEvlMVNyEyRmpjESEjbMSFpsLJZFUv02RdzxoxxnnNu7r9z753vB+Tee55z7nnQx1/OOTn3XkeEAACft6juBgCgVxGQAJAgIAEgQUACQIKABIAEAQkACQISABIEJAAkCEgASFzZzZ3Z5mM7/edcRPxD3U30Oma7/0SEq9Zp6QjS9mbbx2y/bXt7K++FnvWnuhuoA7MNqYWAtH2FpJ9I+oak1ZIesL26XY0BdWG2MauVI8h1kt6OiBMR8XdJv5S0pT1tAbVitiGptYC8QdJf5ryeLpZ9hu1R25O2J1vYF9BNzDYktfaXNPNd4PzcheqI2CFph8SFbPQNZhuSWjuCnJa0Ys7rL0s61Vo7QE9gtiGptYA8KGmV7ZW2F0v6tqQ97WkLqBWzDUktnGJHxAXbj0j6naQrJI1FxOG2dQbUhNnGLHfzJxe4TtOXpiJipO4meh2z3X86fqM4AAwyAhIAEgQkACQISABIEJAAkCAgASBBQAJAgoAEgAQBCQAJAhIAEgQkACQISABIEJAAkCAgASBBQAJAgoAEgAQBCQAJAhIAEgQkACQISABIEJAAkCAgASBBQAJAgoAEgAQBCQAJAhIAEgQkACQISABIEJAAkLiylY1tn5R0XtLHki5ExEg7mgLqtpBme+PGjWlt165dpdveddddpfVjx4411VOvaCkgC1+LiHNteB+g1zDbCxyn2ACQaDUgQ9LvbU/ZHm1HQ0CPYLbR8in2+og4ZXu5pL22j0bEq3NXKIaLAUO/YbbR2hFkRJwqHs9KmpC0bp51dkTEyCBf5MbgYbYhtRCQtpfavmr2uaRNkg61qzGgLsw2ZrVyij0kacL27Pv8IiJ+25augHox25DUQkBGxAlJt7Wxl4668847S+vXXXddaX1iYqKd7aCH9dtst+r2229PawcPHuxiJ72H23wAIEFAAkCCgASABAEJAAkCEgASBCQAJNrxbT59YcOGDaX1VatWlda5zQf9atGi8uOglStXprUbb7yxdNviXtGBxREkACQISABIEJAAkCAgASBBQAJAgoAEgAQBCQCJBXMf5NatW0vrr732Wpc6AbpreHi4tL5t27a09vzzz5due/To0aZ66hccQQJAgoAEgAQBCQAJAhIAEgQkACQISABIEJAAkFgw90FWfSceMKiee+65prc9fvx4GzvpP6QGACQISABIEJAAkCAgASBBQAJAgoAEgAQBCQCJyvsgbY9J+qaksxHxlWLZMkm/knSTpJOS7o+Iv3WuzWpr1qwprQ8NDXWpE/SLfpntVl1zzTVNb7t37942dtJ/GjmCHJe0+ZJl2yXti4hVkvYVr4F+My5mGyUqAzIiXpX07iWLt0jaWTzfKeneNvcFdByzjSrNXoMciojTklQ8Lm9fS0CtmG18quOfxbY9Kmm00/sBuo3ZHnzNHkGesT0sScXj2WzFiNgRESMRMdLkvoBuYrbxqWYDco+k2Z8J3Crppfa0A9SO2canKgPS9guS/kfSP9metv2QpCclfd32cUlfL14DfYXZRpXKa5AR8UBS2tjmXlpyzz33lNaXLFnSpU7QL/pltqtU3eO7cuXKpt/7nXfeaXrbQcAnaQAgQUACQIKABIAEAQkACQISABIEJAAkBuZnX2+99daWtj98+HCbOgG66+mnny6tV90G9NZbb6W18+fPN9XToOAIEgASBCQAJAhIAEgQkACQICABIEFAAkCCgASAxMDcB9mqgwcP1t0CBtjVV19dWt+8+dIfV7zowQcfLN1206ZNTfU064knnkhr7733Xkvv3e84ggSABAEJAAkCEgASBCQAJAhIAEgQkACQICABIMF9kIVly5bVtu/bbruttG67tH733XentRUrVpRuu3jx4tL6ww8/XFrHjLVr1+qVV15J64sWlR+LfPjhh2ntwIEDpdt+9NFHpfUrryz/33xqaqq0vpBxBAkACQISABIEJAAkCEgASBCQAJAgIAEgQUACQKLyPkjbY5K+KelsRHylWPZDSdsk/bVY7bGI+E2nmmxE2X1kkhQRpfVnn322tP74449fdk+NWrNmTWm96j7ICxcupLUPPvigdNs333yztD7I2jnb586d09jYWFqfnJws3b7sHsozZ86Ubjs9PV1aX7JkSWn96NGjpfWFrJEjyHFJ832b539GxNrin1rDEWjSuJhtlKgMyIh4VdK7XegF6CpmG1VauQb5iO0/2h6zfW3bOgLqx2xDUvMB+V+S/lHSWkmnJf1HtqLtUduTtssvwgC9oanZfv/997vVH7qoqYCMiDMR8XFEfCLpp5LWlay7IyJGImKk2SaBbml2tpcuXdq9JtE1TQWk7eE5L78l6VB72gHqxWxjrkZu83lB0gZJ19uelvQDSRtsr5UUkk5K+m4HewQ6gtlGFVfdH9jWndnd29klHn300dL6+vXru9TJ5ZuYmCitHzlyJK3t37+/1d1PcXmkWidne3R0tLT+zDPPlNZPnDhRWr/55psvu6dBEBHlNxiLT9IAQIqABIAEAQkACQISABIEJAAkCEgASCyYn3196qmn6m4BaMrGjRtb2n737t1t6mTh4QgSABIEJAAkCEgASBCQAJAgIAEgQUACQIKABIDEgrkPElioqr4uDzmOIAEgQUACQIKABIAEAQkACQISABIEJAAkCEgASBCQAJAgIAEgQUACQIKABIAEAQkACQISABIEJAAkCEgASFR+H6TtFZJ+LulLkj6RtCMifmx7maRfSbpJ0klJ90fE3zrXKtBegzLbtkvrt9xyS2l9//797WxnoDRyBHlB0r9FxD9L+hdJ37O9WtJ2SfsiYpWkfcVroJ8w2yhVGZARcToi/lA8Py/piKQbJG2RtLNYbaekezvVJNAJzDaqXNY1SNs3SfqqpAOShiLitDQzaJKWt7s5oFuYbcyn4d+ksf1FSbslfT8i/rfqusec7UYljTbXHtB5zDYyDR1B2v6CZgZoV0S8WCw+Y3u4qA9LOjvfthGxIyJGImKkHQ0D7cRso0xlQHrmj9OfSToSET+aU9ojaWvxfKukl9rfHtA5zDaqNHKKvV7SdyS9Yfv1Ytljkp6U9GvbD0n6s6T7OtMi0DEDMdsRUVpftIjbnZtVGZAR8d+SsosyG9vbDtA9zDaq8EcLACQISABIEJAAkCAgASBBQAJAgoAEgETDHzUE0J/uuOOO0vr4+Hh3GulDHEECQIKABIAEAQkACQISABIEJAAkCEgASBCQAJDgPkigzzX6ExG4fBxBA
kCCgASABAEJAAkCEgASBCQAJAhIAEgQkACQ4D5IoMe9/PLLpfX77uvpn+3uaxxBAkCCgASABAEJAAkCEgASBCQAJAhIAEgQkACQcESUr2CvkPRzSV+S9ImkHRHxY9s/lLRN0l+LVR+LiN9UvFf5ztCLpiJipO4mOoHZXtgiovKLNBsJyGFJwxHxB9tXSZqSdK+k+yX9X0Q83WhDDFFfGuSAZLYXsEYCsvKTNBFxWtLp4vl520ck3dB6e0C9mG1UuaxrkLZvkvRVSQeKRY/Y/qPtMdvXJtuM2p60PdlSp0AHMduYT+Up9qcr2l+U9Iqkf4+IF20PSTonKSQ9oZlTlX+teA9OQ/rPwJ5iz2K2F6ZGTrEbOoK0/QVJuyXtiogXizc/ExEfR8Qnkn4qaV0rzQJ1YLZRpjIgPfOTaT+TdCQifjRn+fCc1b4l6VD72wM6h9lGlUa+7my9pO9IesP268WyxyQ9YHutZk5DTkr6bkc6BDqH2Uaphq9BtmVnXKfpRwN/DbIdmO3+07ZrkACwEBGQAJAgIAEgQUACQIKABIAEAQkACQISABIEJAAkCEgASBCQAJAgIAEgQUACQIKABIAEAQkAiUa+D7Kdzkn605zX1xfLelGv9tbtvm7s4r762dzZ7tXZkehtVkNz3dXvg/zczu3JXv2uwV7trVf7wkW9/N+I3i4Pp9gAkCAgASBRd0DuqHn/ZXq1t17tCxf18n8jersMtV6DBIBeVvcRJAD0rFoC0vZm28dsv217ex09ZGyftP2G7ddtT9bcy5jts7YPzVm2zPZe28eLx2vr7BGfxWw33EtfzHbXA9L2FZJ+IukbklZr5jeIV3e7jwpfi4i1PXDLwbikzZcs2y5pX0SskrSveI0ewGxflnH1wWzXcQS5TtLbEXEiIv4u6ZeSttTQR8+LiFclvXvJ4i2SdhbPd0q6t6tNoQyz3aB+me06AvIGSX+Z83q6WNYrQtLvbU/ZHq27mXkMRcRpSSoel9fcDy5itlvTc7Pd7Y8aSpLnWdZLf5W+PiJO2V4uaa/to8WfdkAVZnvA1HEEOS1pxZzXX5Z0qoY+5hURp4rHs5ImNHPa1EvO2B6WpOLxbM394CJmuzU9N9t1BORBSatsr7S9WNK3Je2poY/Psb3U9lWzzyVtknSofKuu2yNpa/F8q6SXauwFn8Vst6bnZrvrp9gRccH2I5J+J+kKSWMRcbjbfSSGJE3Ylmb+3fwiIn5bVzO2X5C0QdL1tqcl/UDSk5J+bfshSX+WdF9d/eGzmO3G9cts80kaAEjwSRoASBCQAJAgIAEgQUACQIKABIAEAQkACQISABIEJAAk/h/QV9BsToS1tgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from keras.datasets import mnist\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
"X_train = X_train / 255.\n",
"\n",
"X_train_a = X_train[:, :, :14].reshape((60000, 392))\n",
"X_train_b = X_train[:, :, 14:].reshape((60000, 392))\n",
"\n",
"plt.subplot(1, 2, 1)\n",
"plt.imshow(X_train_a[2].reshape((28, 14)), cmap='gray')\n",
"plt.subplot(1, 2, 2)\n",
"plt.imshow(X_train_b[2].reshape((28, 14)), cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 4.27146275583903\n"
]
}
],
"source": [
"n_epochs = 10\n",
"n_train_batches = len(X_train) // M\n",
"\n",
"with tf.Session(graph=tf.get_default_graph()) as sess:\n",
" sess.run(tf.global_variables_initializer())\n",
"\n",
" for epoch in range(n_epochs): \n",
" epoch_loss = 0\n",
" for i in range(n_train_batches):\n",
" j = i * M\n",
" X_train_batch_a = X_train_a[j:(j + M)]\n",
" X_train_batch_b = X_train_b[j:(j + M)]\n",
" res = sess.run([train_op, holy_loss], feed_dict={X_a: X_train_batch_a, X_b: X_train_batch_b})\n",
" epoch_loss += res[1]\n",
" print(epoch, epoch_loss / n_train_batches)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}