\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import holoviews as hv\n", "%load_ext holoviews.ipython" ] }, { "cell_type": "code", "execution_count": 188, "metadata": { "collapsed": true }, "outputs": [], "source": [ "%output size=200" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def sample_gumbel(shape, eps=1e-20): \n", " \"\"\"Sample from Gumbel(0, 1)\"\"\"\n", " U = tf.random_uniform(shape,minval=0,maxval=1)\n", " return -tf.log(-tf.log(U + eps) + eps)\n", "\n", "def gumbel_softmax_sample(logits, temperature): \n", " \"\"\" Draw a sample from the Gumbel-Softmax distribution\"\"\"\n", " y = logits + sample_gumbel(tf.shape(logits))\n", " return tf.nn.softmax( y / temperature)\n", "\n", "def gumbel_softmax(logits, temperature, hard=False):\n", " \"\"\"Sample from the Gumbel-Softmax distribution and optionally discretize.\n", " Args:\n", " logits: [batch_size, n_class] unnormalized log-probs\n", " temperature: positive scalar (> 0; the sample is divided by it)\n", " hard: if True, take argmax, but differentiate w.r.t. soft sample y\n", " Returns:\n", " [batch_size, n_class] sample from the Gumbel-Softmax distribution.\n", " If hard=True, then the returned sample will be one-hot, otherwise it will\n", " be a probability distribution that sums to 1 across classes\n", " \"\"\"\n", " y = gumbel_softmax_sample(logits, temperature)\n", " if hard:\n", " k = tf.shape(logits)[-1]\n", " #y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype)\n", " y_hard = tf.cast(tf.equal(y,tf.reduce_max(y,1,keep_dims=True)),y.dtype)\n", " y = tf.stop_gradient(y_hard - y) + y\n", " return y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. 
Build Model" ] }, { "cell_type": "code", "execution_count": 318, "metadata": { "collapsed": true }, "outputs": [], "source": [ "tf.reset_default_graph()" ] }, { "cell_type": "code", "execution_count": 319, "metadata": { "collapsed": true }, "outputs": [], "source": [ "K=10 # number of classes\n", "N=30 # number of categorical distributions" ] }, { "cell_type": "code", "execution_count": 320, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE=500\n", "M = BATCH_SIZE\n", "# input image x (shape=(batch_size,784))\n", "x = tf.placeholder(tf.float32,[None,784])\n", "\n", "# variational posterior q(y|x), i.e. the encoder (shape=(batch_size,200))\n", "net = slim.stack(x,slim.fully_connected,[512,256])\n", "\n", "# unnormalized logits for N separate K-categorical distributions (shape=(batch_size*N,K))\n", "logits_y = tf.reshape(slim.fully_connected(net,K*N,activation_fn=None),[-1,K])\n", "q_y = tf.nn.softmax(logits_y)\n", "log_q_y = tf.log(q_y+1e-20)\n", "\n", "# temperature\n", "tau = tf.Variable(5.0,name=\"temperature\")\n", "\n", "# at the same time, we need to parameterise our nested categorical sample\n", "rho = slim.fully_connected(net, 1, activation_fn=tf.nn.sigmoid)\n", "rho = (1-1e-3)*rho\n", "k = tf.range(N, dtype=tf.float32)\n", "nested_up = (tf.pow(rho,k)*(1.0-rho))\n", "nested_p = nested_up/tf.reduce_sum(nested_up, axis=1, keep_dims=True)\n", "nested_logp = tf.reshape(tf.log(nested_p+1e-8), [M,N])\n", "z = gumbel_softmax(nested_logp, tau, hard=False)\n", "z = (1.0-tf.cumsum(z, axis=1))\n", "\n", "# sample and reshape back (shape=(batch_size,N,K))\n", "# set hard=True for ST Gumbel-Softmax\n", "y = tf.reshape(gumbel_softmax(logits_y,tau,hard=False),[-1,N,K])\n", "# apply the soft nested dropout mask\n", "y_masked = y*tf.reshape(z, [M,N,1])\n", "\n", "# generative model p(x|y), i.e. 
the decoder (shape=(batch_size,200))\n", "net = slim.stack(slim.flatten(y_masked),slim.fully_connected,[256,512])\n", "logits_x = slim.fully_connected(net,784,activation_fn=None)\n", "\n", "# (shape=(batch_size,784))\n", "p_x = Bernoulli(logits=logits_x)" ] }, { "cell_type": "code", "execution_count": 321, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From :2: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.\n", "Instructions for updating:\n", "Use tf.global_variables_initializer instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From :2: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.\n", "Instructions for updating:\n", "Use tf.global_variables_initializer instead.\n" ] } ], "source": [ "sess=tf.InteractiveSession()\n", "init_op=tf.initialize_all_variables()\n", "sess.run(init_op)" ] }, { "cell_type": "code", "execution_count": 322, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/train-images-idx3-ubyte.gz\n", "Extracting /tmp/train-labels-idx1-ubyte.gz\n", "Extracting /tmp/t10k-images-idx3-ubyte.gz\n", "Extracting /tmp/t10k-labels-idx1-ubyte.gz\n" ] } ], "source": [ "# get data\n", "data = input_data.read_data_sets('/tmp/', one_hot=True).train " ] }, { "cell_type": "code", "execution_count": 323, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ ":AdjointLayout\n", " :Image [x,y] (z)\n", " :Histogram [z] (Frequency)" ] }, "execution_count": 323, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%output size=200\n", "np_x,np_y=data.next_batch(BATCH_SIZE)\n", "hv.Image(sess.run(z,{x:np_x, rho: np.ones((BATCH_SIZE,1))*0.9}).squeeze()).hist()" ] }, { "cell_type": "code", "execution_count": 324, "metadata": {}, "outputs": 
[], "source": [ "# loss and train ops\n", "kl_tmp = tf.reshape(q_y*(log_q_y-tf.log(1.0/K)),[-1,N,K])\n", "KL = tf.reduce_sum(kl_tmp,[1,2])\n", "# nested kl to arbitrarily chosen geometric\n", "prior_rho = 0.8\n", "ref_ugeo = tf.pow(prior_rho,k)*(1.0-prior_rho)\n", "ref_geo = ref_ugeo/tf.reduce_sum(ref_ugeo)\n", "log_ref_geo = tf.reshape(tf.log(ref_geo+1e-8), [1, N])\n", "nested_kl = nested_p*(nested_logp - log_ref_geo)\n", "nKL = tf.reduce_sum(nested_kl, [1])\n", "elbo=tf.reduce_sum(p_x.log_prob(x),1) - KL - nKL" ] }, { "cell_type": "code", "execution_count": 325, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From :4: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.\n", "Instructions for updating:\n", "Use tf.global_variables_initializer instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From :4: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.\n", "Instructions for updating:\n", "Use tf.global_variables_initializer instead.\n" ] } ], "source": [ "loss=tf.reduce_mean(-elbo)\n", "lr=tf.constant(0.001)\n", "train_op=tf.train.AdamOptimizer(learning_rate=lr).minimize(loss,var_list=slim.get_model_variables())\n", "# initialize_all_variables was deprecated (see the recorded warning above)\n", "init_op=tf.global_variables_initializer()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "# 3. 
Train" ] }, { "cell_type": "code", "execution_count": 326, "metadata": { "collapsed": true }, "outputs": [], "source": [ "NUM_ITERS=50000\n", "tau0=1.0 # initial temperature\n", "np_temp=tau0\n", "np_lr=0.001\n", "ANNEAL_RATE=0.00003\n", "MIN_TEMP=0.5" ] }, { "cell_type": "code", "execution_count": 327, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from tqdm import tqdm_notebook as tqdm" ] }, { "cell_type": "code", "execution_count": 328, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9757a8c471bc4647b5ddd3504b49ade8" } }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Step 5000, ELBO: -108.560\n", "Step 10000, ELBO: -103.935\n", "Step 15000, ELBO: -102.148\n", "Step 20000, ELBO: -104.645\n", "Step 25000, ELBO: -105.826\n", "Step 30000, ELBO: -105.118\n", "Step 35000, ELBO: -106.743\n", "Step 40000, ELBO: -99.907\n", "Step 45000, ELBO: -104.593\n", "\n" ] } ], "source": [ "dat=[]\n", "sess.run(init_op)\n", "for i in tqdm(range(1,NUM_ITERS)):\n", " np_x,np_y=data.next_batch(BATCH_SIZE)\n", " _,np_loss=sess.run([train_op,loss],{\n", " x:np_x,\n", " tau:np_temp,\n", " lr:np_lr\n", " })\n", " if i % 100 == 1:\n", " dat.append([i,np_temp,np_loss])\n", " if i % 1000 == 1:\n", " np_temp=np.maximum(tau0*np.exp(-ANNEAL_RATE*i),MIN_TEMP)\n", " np_lr*=0.9\n", " if (i+1) % 5000 == 1:\n", " print('Step %d, ELBO: %0.3f' % (i,-np_loss))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## save to animation\n", "\n", "Saving the output of the learnt model to an animation and then loading it back into the notebook for an exciting viewing experience. If you're looking at this on GitHub, it may or may not work! Maybe you should just run it yourself?" 
] }, { "cell_type": "code", "execution_count": 331, "metadata": { "scrolled": false }, "outputs": [], "source": [ "np_x1,_=data.next_batch(BATCH_SIZE)\n", "np_x2,np_y1 = sess.run([p_x.mean(),y],{x:np_x1})" ] }, { "cell_type": "code", "execution_count": 332, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import matplotlib.animation as animation" ] }, { "cell_type": "code", "execution_count": 333, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def save_anim(data,figsize,filename):\n", " fig=plt.figure(figsize=(figsize[1]/10.0,figsize[0]/10.0))\n", " im = plt.imshow(data[0].reshape(figsize),cmap=plt.cm.gray,interpolation='none')\n", " plt.gca().set_axis_off()\n", " #fig.tight_layout()\n", " fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)\n", " def updatefig(t):\n", " im.set_array(data[t].reshape(figsize))\n", " return im,\n", " anim=animation.FuncAnimation(fig, updatefig, frames=100, interval=50, blit=True, repeat=True)\n", " Writer = animation.writers['imagemagick']\n", " writer = Writer(fps=1, metadata=dict(artist='Me'), bitrate=1800)\n", " anim.save(filename, writer=writer)\n", " return" ] }, { "cell_type": "code", "execution_count": 334, "metadata": {}, "outputs": [], "source": [ "save_anim(np_x1,(28,28),'x0.gif')\n", "save_anim(np_y1,(N,K),'y.gif')\n", "save_anim(np_x2,(28,28),'x1.gif')" ] }, { "cell_type": "code", "execution_count": 335, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from IPython.display import Image" ] }, { "cell_type": "code", "execution_count": 336, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 336, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Image(url='x0.gif')" ] }, { "cell_type": "code", "execution_count": 337, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 337, "metadata": {}, "output_type": "execute_result" } ], "source": 
[ "Image(url='y.gif')" ] }, { "cell_type": "code", "execution_count": 338, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 338, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Image(url='x1.gif')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 4. Plot Training Curves" ] }, { "cell_type": "code", "execution_count": 339, "metadata": { "collapsed": true }, "outputs": [], "source": [ "dat=np.array(dat).T" ] }, { "cell_type": "code", "execution_count": 340, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ ":Layout\n", " .Curve.I :Curve [Iteration] (Temperature)\n", " .Curve.II :Curve [Iteration] (-ELBO)" ] }, "execution_count": 340, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%output size=200\n", "hv.Curve(zip(dat[0],dat[1]), kdims=['Iteration'], vdims=['Temperature'])+\\\n", "hv.Curve(zip(dat[0],dat[2]), kdims=['Iteration'], vdims=['-ELBO'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 5. Unconditional Generation\n", "\n", "This consists of sampling from the prior $p_\\theta(y)$ and passing it through the generative model.\n", "\n", "In the case of this nested model, we want to vary the first dimension of the latent representation first, to see if it encodes more of the variation in the image than other layers. We can do this by substituting a deterministic mask for the first dimension of the latent representation." 
] }, { "cell_type": "code", "execution_count": 377, "metadata": { "collapsed": true }, "outputs": [], "source": [ "np_y = np.zeros((M*N,K))\n", "np_y[range(M*N),np.random.choice(K,M*N)] = 1\n", "np_y = np.reshape(np_y,[M,N,K])" ] }, { "cell_type": "code", "execution_count": 378, "metadata": { "collapsed": true }, "outputs": [], "source": [ "np_z = np.zeros((1,N))\n", "np_z[:,0] = 1.0\n", "np_z = np_z.repeat(M,axis=0)" ] }, { "cell_type": "code", "execution_count": 379, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ ":Layout\n", " .Image.I :Image [x,y] (z)\n", " .Image.II :Image [x,y] (z)" ] }, "execution_count": 379, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hv.Image(np_y[:,0,:])+hv.Image(np_z.squeeze())" ] }, { "cell_type": "code", "execution_count": 380, "metadata": {}, "outputs": [], "source": [ "x_p=p_x.mean()\n", "np_x= sess.run(x_p,{y:np_y, z:np_z})" ] }, { "cell_type": "code", "execution_count": 381, "metadata": {}, "outputs": [], "source": [ "sort_idxs = np.argsort(np.argmax(np_y[:100,0,:],axis=1))" ] }, { "cell_type": "code", "execution_count": 382, "metadata": {}, "outputs": [], "source": [ "np_x = np_x[sort_idxs,:]\n", "np_x = np_x.reshape((10,10,28,28))\n", "# split into 10 (1,10,28,28) images, concat along columns -> 1,10,28,280\n", "np_x = np.concatenate(np.split(np_x,10,axis=0),axis=3)\n", "# split into 10 (1,1,28,280) images, concat along rows -> 1,1,280,280\n", "np_x = np.concatenate(np.split(np_x,10,axis=1),axis=2)\n", "x_img = np.squeeze(np_x)" ] }, { "cell_type": "code", "execution_count": 383, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ ":Layout\n", " .Image.I :Image [x,y] (z)\n", " .Image.II :Image [x,y] (z)" ] }, "execution_count": 383, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%output size=200\n", "hv.Image(np_y[np.argsort(np.argmax(np_y[:,0,:],axis=1)),0,:])+hv.Image(x_img)" ] }, { "cell_type": "code", "execution_count": 384, 
"metadata": {}, "outputs": [], "source": [ "def gen_with(indexes_included):\n", " np_z = np.zeros((1,N))\n", " for k in indexes_included:\n", " np_z[:,k] = 1.0\n", " # sort by the first index in the list\n", " sort_idxs = np.argsort(np.argmax(np_y[:100,indexes_included[0],:],axis=1))\n", " np_z = np_z.repeat(M,axis=0)\n", " x_p=p_x.mean()\n", " np_x = sess.run(x_p,{y:np_y, z:np_z})\n", " np_x = np_x[sort_idxs,:]\n", " np_x = np_x.reshape((10,10,28,28))\n", " # split into 10 (1,10,28,28) images, concat along columns -> 1,10,28,280\n", " np_x = np.concatenate(np.split(np_x,10,axis=0),axis=3)\n", " # split into 10 (1,1,28,280) images, concat along rows -> 1,1,280,280\n", " np_x = np.concatenate(np.split(np_x,10,axis=1),axis=2)\n", " return np.squeeze(np_x)" ] }, { "cell_type": "code", "execution_count": 385, "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:root:LayoutPlot08763: None is empty, skipping subplot.\n" ] }, { "data": { "text/html": [ "