@gngdb
Last active October 9, 2018 12:50
Few-shot relational reasoning (the second notebook shows better results, but both have batchnorm problems).
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$\n",
"\\newcommand{\\Dcal}{\\mathcal{D}}\n",
"\\newcommand{\\aB}{\\mathbf{a}}\n",
"\\newcommand{\\bB}{\\mathbf{b}}\n",
"\\newcommand{\\cB}{\\mathbf{c}}\n",
"\\newcommand{\\dB}{\\mathbf{d}}\n",
"\\newcommand{\\eB}{\\mathbf{e}}\n",
"\\newcommand{\\fB}{\\mathbf{f}}\n",
"\\newcommand{\\gB}{\\mathbf{g}}\n",
"\\newcommand{\\hB}{\\mathbf{h}}\n",
"\\newcommand{\\iB}{\\mathbf{i}}\n",
"\\newcommand{\\jB}{\\mathbf{j}}\n",
"\\newcommand{\\kB}{\\mathbf{k}}\n",
"\\newcommand{\\lB}{\\mathbf{l}}\n",
"\\newcommand{\\mB}{\\mathbf{m}}\n",
"\\newcommand{\\nB}{\\mathbf{n}}\n",
"\\newcommand{\\oB}{\\mathbf{o}}\n",
"\\newcommand{\\pB}{\\mathbf{p}}\n",
"\\newcommand{\\qB}{\\mathbf{q}}\n",
"\\newcommand{\\rB}{\\mathbf{r}}\n",
"\\newcommand{\\sB}{\\mathbf{s}}\n",
"\\newcommand{\\tB}{\\mathbf{t}}\n",
"\\newcommand{\\uB}{\\mathbf{u}}\n",
"\\newcommand{\\vB}{\\mathbf{v}}\n",
"\\newcommand{\\wB}{\\mathbf{w}}\n",
"\\newcommand{\\xB}{\\mathbf{x}}\n",
"\\newcommand{\\yB}{\\mathbf{y}}\n",
"\\newcommand{\\zB}{\\mathbf{z}}\n",
"$\n",
"$\n",
"\\newcommand{\\AB}{\\mathbf{A}}\n",
"\\newcommand{\\BB}{\\mathbf{B}}\n",
"\\newcommand{\\CB}{\\mathbf{C}}\n",
"\\newcommand{\\DB}{\\mathbf{D}}\n",
"\\newcommand{\\EB}{\\mathbf{E}}\n",
"\\newcommand{\\FB}{\\mathbf{F}}\n",
"\\newcommand{\\GB}{\\mathbf{G}}\n",
"\\newcommand{\\HB}{\\mathbf{H}}\n",
"\\newcommand{\\IB}{\\mathbf{I}}\n",
"\\newcommand{\\JB}{\\mathbf{J}}\n",
"\\newcommand{\\KB}{\\mathbf{K}}\n",
"\\newcommand{\\LB}{\\mathbf{L}}\n",
"\\newcommand{\\MB}{\\mathbf{M}}\n",
"\\newcommand{\\NB}{\\mathbf{N}}\n",
"\\newcommand{\\OB}{\\mathbf{O}}\n",
"\\newcommand{\\PB}{\\mathbf{P}}\n",
"\\newcommand{\\QB}{\\mathbf{Q}}\n",
"\\newcommand{\\RB}{\\mathbf{R}}\n",
"\\newcommand{\\SB}{\\mathbf{S}}\n",
"\\newcommand{\\TB}{\\mathbf{T}}\n",
"\\newcommand{\\UB}{\\mathbf{U}}\n",
"\\newcommand{\\VB}{\\mathbf{V}}\n",
"\\newcommand{\\WB}{\\mathbf{W}}\n",
"\\newcommand{\\XB}{\\mathbf{X}}\n",
"\\newcommand{\\YB}{\\mathbf{Y}}\n",
"\\newcommand{\\ZB}{\\mathbf{Z}}\n",
"$\n",
"$\n",
"\\newcommand{\\alphaB}{\\boldsymbol{\\alpha}}\n",
"\\newcommand{\\betaB}{\\boldsymbol{\\beta}}\n",
"\\newcommand{\\gammaB}{\\boldsymbol{\\gamma}}\n",
"\\newcommand{\\deltaB}{\\boldsymbol{\\delta}}\n",
"\\newcommand{\\epsilonB}{\\boldsymbol{\\epsilon}}\n",
"\\newcommand{\\varepsilonB}{\\boldsymbol{\\varepsilon}}\n",
"\\newcommand{\\zetaB}{\\boldsymbol{\\zeta}}\n",
"\\newcommand{\\etaB}{\\boldsymbol{\\eta}}\n",
"\\newcommand{\\thetaB}{\\boldsymbol{\\theta}}\n",
"\\newcommand{\\varthetaB}{\\boldsymbol{\\vartheta}}\n",
"\\newcommand{\\iotaB}{\\boldsymbol{\\iota}}\n",
"\\newcommand{\\kappaB}{\\boldsymbol{\\kappa}}\n",
"\\newcommand{\\lambdaB}{\\boldsymbol{\\lambda}}\n",
"\\newcommand{\\muB}{\\boldsymbol{\\mu}}\n",
"\\newcommand{\\nuB}{\\boldsymbol{\\nu}}\n",
"\\newcommand{\\xiB}{\\boldsymbol{\\xi}}\n",
"\\newcommand{\\piB}{\\boldsymbol{\\pi}}\n",
"\\newcommand{\\varpiB}{\\boldsymbol{\\varpi}}\n",
"\\newcommand{\\rhoB}{\\boldsymbol{\\rho}}\n",
"\\newcommand{\\varrhoB}{\\boldsymbol{\\varrho}}\n",
"\\newcommand{\\sigmaB}{\\boldsymbol{\\sigma}}\n",
"\\newcommand{\\varsigmaB}{\\boldsymbol{\\varsigma}}\n",
"\\newcommand{\\tauB}{\\boldsymbol{\\tau}}\n",
"\\newcommand{\\upsilonB}{\\boldsymbol{\\upsilon}}\n",
"\\newcommand{\\phiB}{\\boldsymbol{\\phi}}\n",
"\\newcommand{\\varphiB}{\\boldsymbol{\\varphi}}\n",
"\\newcommand{\\chiB}{\\boldsymbol{\\chi}}\n",
"\\newcommand{\\psiB}{\\boldsymbol{\\psi}}\n",
"\\newcommand{\\omegaB}{\\boldsymbol{\\omega}}\n",
"$\n",
"$\n",
"\\newcommand{\\GammaB}{\\boldsymbol{\\Gamma}}\n",
"\\newcommand{\\DeltaB}{\\boldsymbol{\\Delta}}\n",
"\\newcommand{\\ThetaB}{\\boldsymbol{\\Theta}}\n",
"\\newcommand{\\LambdaB}{\\boldsymbol{\\Lambda}}\n",
"\\newcommand{\\XiB}{\\boldsymbol{\\Xi}}\n",
"\\newcommand{\\PiB}{\\boldsymbol{\\Pi}}\n",
"\\newcommand{\\SigmaB}{\\boldsymbol{\\Sigma}}\n",
"\\newcommand{\\UpsilonB}{\\boldsymbol{\\Upsilon}}\n",
"\\newcommand{\\PhiB}{\\boldsymbol{\\Phi}}\n",
"\\newcommand{\\PsiB}{\\boldsymbol{\\Psi}}\n",
"\\newcommand{\\OmegaB}{\\boldsymbol{\\Omega}}\n",
"$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Recent work in one-shot and few-shot learning (matching and prototypical networks) has focused on learning an embedding in which datapoints can be compared with some kind of distance measure. Prototypical networks do this in the simplest possible way: they learn an embedding in which the class means are separated in Euclidean distance. Then, to classify a query, they just look at its relative distances from the class means in that minibatch.\n",
"\n",
"We necessarily throw away information when doing this, because we take the mean of a set of representations instead of reasoning individually about the relationships between the query point and every point in the support set. So, instead of the classifier used in the prototypical networks paper:\n",
"\n",
"$$\n",
"p_{\\thetaB}(y=k | \\xB) = \\frac{\\exp\\left(-\\| f_{\\thetaB}(\\xB) - \\cB_k \\|^2\\right)}{\\sum_{k'} \\exp\\left(-\\| f_{\\thetaB}(\\xB) - \\cB_{k'} \\|^2\\right)} \\\\\n",
"\\cB_k = \\frac{1}{|S(k)|} \\sum_{(\\xB_i, y_i) \\in S(k)} f_{\\thetaB} (\\xB_i)\n",
"$$\n",
"\n",
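"As a worked example with made-up numbers, take a one-dimensional embedding with $f_{\\thetaB}(\\xB) = 0$ and two prototypes $\\cB_1 = 1$ and $\\cB_2 = 2$. The squared distances are then $1$ and $4$, and the softmax over negative squared distances gives:\n",
"\n",
"$$\n",
"p_{\\thetaB}(y=1 | \\xB) = \\frac{e^{-1}}{e^{-1} + e^{-4}} = \\frac{1}{1 + e^{-3}} \\approx 0.953\n",
"$$\n",
"\n",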
"We could have a function that explicitly reasons about the relationships, using a relational network:\n",
"\n",
"$$\n",
"RN (O) = f_{\\phiB} \\left( \\sum_{i,j} g_{\\thetaB}(o_i, o_j) \\right)\n",
"$$\n",
"\n",
"One easy way to make this output a logit for each class would be to split the support set by class and sum within each class, producing one logit for each of the $K$ classes:\n",
"\n",
"$$\n",
"p(y=k | \\xB) \\propto f_{\\phiB} \\left( \\sum_{i \\in S(k), j \\in Q(k)} g_{\\thetaB}(\\xB_i^s, \\xB_j^q) \\right)\n",
"$$\n",
"\n",
"But the degree to which you believe that a datapoint belongs to a certain class may depend on your knowing that another datapoint certainly belongs to that class. To be more concrete: given a support set of images, in order to decide what class a query image belongs to, I would compare:\n",
"\n",
"1. Images in the support set with each other arbitrarily.\n",
"2. The query set image with images in each class in the support set (I would separate the images by class, and look at the differences).\n",
"3. The query set image with images in the support set arbitrarily.\n",
"\n",
"The query image should act as the \"question\" does in the relational network paper, conditioning every comparison before the sum. We should also use a cubic nested sum (using $\\SB$ to mean the entire support set):\n",
"\n",
"$$\n",
"p(y=k | \\xB) \\propto f_{\\phiB} \\left( \\sum_{l \\in S(k)} \\sum_{i,j \\in \\SB} g_{\\thetaB}(\\xB_i^s, \\xB_j^s, \\xB_l^s, \\xB_q) \\right)\n",
"$$\n",
"\n",
"We get a cubic memory explosion, but typically the latent activations are not high-dimensional, and by its nature few-shot learning should have a fairly small number of items in the support set.\n",
"\n",
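"To put a number on this, use the settings from later in the notebook: 5-way, with two support examples per class, so $|\\SB| = 10$. Since $\\sum_k |S(k)| = |\\SB|$, each query needs\n",
"\n",
"$$\n",
"\\sum_{k=1}^{5} |S(k)| \\, |\\SB|^2 = |\\SB|^3 = 10^3 = 1000\n",
"$$\n",
"\n",
"evaluations of $g_{\\thetaB}$. With ten query points, that is $10^4$ per episode.\n",
"\n",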
"**Why should this perform better than other few-shot learning methods?** It ought to perform better because it can reason about the relationships between individual datapoints, rather than only about class means. We can imagine that a prototypical network could be confounded by an outlier, because the outlier moves the mean $\\cB_k$ away from the \"correct\" location for that class, making that class difficult to classify. Our network could learn to ignore an outlier if it sees a close correspondence between the query and an example in the support set.\n",
"\n",
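"Here is a toy illustration of the outlier problem, using hypothetical one-dimensional embeddings. Let class 1 have support embeddings $\\{0, 0, 9\\}$, where $9$ is an outlier, and let class 2 have $\\{2, 2, 2\\}$. The prototypes are $\\cB_1 = 3$ and $\\cB_2 = 2$. A query embedded at $0$ exactly matches two class-1 examples, yet its squared distances to the prototypes are $9$ and $4$, so it is assigned to class 2:\n",
"\n",
"$$\n",
"p_{\\thetaB}(y=1 | \\xB) = \\frac{e^{-9}}{e^{-9} + e^{-4}} = \\frac{1}{1 + e^{5}} \\approx 0.0067\n",
"$$\n",
"\n",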
"I'm going to train this with the same hyperparameters that train a prototypical network to 99% accuracy on Omniglot 5-way one-shot classification, although I may run alternative experiments. For now, I would just like to see that it is possible to optimise a network like the one described above.\n",
"\n",
"## What Makes This Simpler?\n",
"\n",
"We're going to avoid arbitrary comparisons between images in the support set (the sum over $j$), because they are expensive. So, our function becomes:\n",
"\n",
"$$\n",
"p(y=k | \\xB) \\propto f_{\\phiB} \\left( \\sum_{l \\in S(k)} \\sum_{i \\in \\SB} g_{\\thetaB}(\\xB_i^s, \\xB_l^s, \\xB_q) \\right)\n",
"$$"
]
},
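{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dropping the sum over $j$ makes the cost quadratic rather than cubic in the support size. Take the settings used below: 5-way, with two support examples per class, so $|\\SB| = 10$, and ten query points per episode. Each query now needs\n",
"\n",
"$$\n",
"\\sum_{k=1}^{5} |S(k)| \\, |\\SB| = |\\SB|^2 = 100\n",
"$$\n",
"\n",
"evaluations of $g_{\\thetaB}$. That is $1000$ per episode, or $2000$ per minibatch of `mb_dim = 2` episodes. This matches the code in `get_losses` below, which evaluates `g_phi` on all $n \\times n \\times n_q = 10 \\times 10 \\times 10$ triples and applies the class restriction afterwards by masking."
]
},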
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports\n",
"\n",
"Re-using imports from another experiment; not all of them may be required."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"try:\n",
"    import dumb_tf_loadbalance\n",
"    dumb_tf_loadbalance.choose_gpu()\n",
"except ImportError:\n",
"    pass\n",
"\n",
"import time\n",
"import os\n",
"import tensorflow as tf\n",
"from tensorflow.contrib.framework.python.ops import arg_scope\n",
"import numpy as np\n",
"from tqdm import tqdm_notebook as tqdm\n",
"\n",
"from tensorflow.examples.tutorials.mnist import input_data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"flags = tf.app.flags\n",
"FLAGS = flags.FLAGS\n",
"run_tag = os.environ.get('RUN_TAG', '')\n",
"flags.DEFINE_string('summary_dir', \"/tmp/relational/\",\n",
" 'Summaries directory')\n",
"flags.DEFINE_string('checkpoint_file', None,\n",
" 'Name of file, found in env variable CHECKPOINT_DIR, to load.')\n",
"flags.DEFINE_bool(\"stochastic\", False, \"true to activate sampling in prediction for regularisation\")\n",
"flags.DEFINE_bool(\"weightnorm\", False, \"True to use weightnorm instead of batchnorm.\")\n",
"flags.DEFINE_bool(\"pseudocount\", False, \"true to estimate appropriate pseudocount from support\")\n",
"flags.DEFINE_bool(\"unbalanced\", False, \"True to train and test on unbalanced \"\n",
" \"classes (number of samples in support\"\n",
" \" and query will equal \"\n",
"                  \"y_dim*n_samples_per_class in this case).\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# delete any stale summaries from previous runs, then recreate the directory\n",
"if tf.gfile.Exists(FLAGS.summary_dir):\n",
"    tf.gfile.DeleteRecursively(FLAGS.summary_dir)\n",
"tf.gfile.MakeDirs(FLAGS.summary_dir)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# set some variables\n",
"cur_time = time.time()\n",
"mb_dim = 2 #episodes (independent few-shot tasks) per minibatch\n",
"x_dim = 28 #size of one side of square image\n",
"y_dim = 5 #possible classes\n",
"n_samples_per_class = 2 #samples of each class\n",
"n_samples = y_dim*n_samples_per_class #total number of labeled samples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loading Omniglot\n",
"\n",
"And defining an iterator:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"\n",
"def load_data():\n",
"    data = np.load(os.path.join(os.environ['OMNIGLOT_DIR'],'data.npy'))\n",
"    data = np.reshape(data,[-1,20,28,28]) #each of the 1600 classes has 20 examples\n",
"    data = np.random.permutation(data)\n",
"    train_data = data[:1200,:,:,:]\n",
"    val_data = data[1200:1200+122,:,:,:]\n",
"    test_data = data[1200+122:,:,:,:]\n",
"    return train_data, val_data, test_data\n",
"\n",
"# define a dataset iterator\n",
"def get_minibatch(cur_data, n_samples_per_class, y_dim):\n",
"    n_samples = y_dim*n_samples_per_class\n",
"    mb_x_i = np.zeros((mb_dim,n_samples,x_dim,x_dim,1))\n",
"    mb_y_i = np.zeros((mb_dim,n_samples))\n",
"    mb_x_hat = np.zeros((mb_dim,n_samples,x_dim,x_dim,1))\n",
"    mb_y_hat = np.zeros((mb_dim,n_samples))\n",
"    for i in range(mb_dim):\n",
"        ind = 0\n",
"        pinds = np.random.permutation(n_samples)\n",
"        qpinds = np.random.permutation(n_samples)\n",
"        if FLAGS.unbalanced:\n",
"            # without replacement, sample a number of classes to work with\n",
"            classes = np.random.choice(cur_data.shape[0], y_dim, replace=False)\n",
"            # sample from this class set\n",
"            # (for query and support)\n",
"            example_classes = np.random.choice(y_dim, n_samples, replace=True)\n",
"            # sample example indexes to go with this set of classes\n",
"            example_inds = np.random.choice(cur_data.shape[1]//2, n_samples, replace=True)\n",
"            qexample_inds = np.random.choice(cur_data.shape[1]//2, n_samples, replace=True)\n",
"            qexample_inds = cur_data.shape[1]//2 + qexample_inds\n",
"            for ex_class, eind, qeind in zip(example_classes, example_inds, qexample_inds):\n",
"                cur_class = classes[ex_class]\n",
"                mb_x_i[i, pinds[ind],:,:,0] = np.rot90(cur_data[cur_class][eind],np.random.randint(4))\n",
"                mb_y_i[i,pinds[ind]] = ex_class\n",
"                mb_x_hat[i, qpinds[ind],:,:,0] = np.rot90(cur_data[cur_class][qeind],np.random.randint(4))\n",
"                mb_y_hat[i, qpinds[ind]] = ex_class\n",
"                ind += 1\n",
"        else:\n",
"            classes = np.random.choice(cur_data.shape[0],y_dim,False)\n",
"            for j,cur_class in enumerate(classes): #each class\n",
"                example_inds = np.random.choice(cur_data.shape[1],2*n_samples_per_class,False)\n",
"                example_inds = example_inds.reshape(2,-1)\n",
"                for eind, qeind in zip(example_inds[0], example_inds[1]):\n",
"                    mb_x_i[i,pinds[ind],:,:,0] = np.rot90(cur_data[cur_class][eind],np.random.randint(4))\n",
"                    mb_y_i[i,pinds[ind]] = j\n",
"                    mb_x_hat[i,qpinds[ind],:,:,0] = np.rot90(cur_data[cur_class][qeind],np.random.randint(4))\n",
"                    mb_y_hat[i,qpinds[ind]] = j\n",
"                    ind +=1\n",
"    return mb_x_i,mb_y_i,mb_x_hat,mb_y_hat"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"train_data, val_data, test_data = load_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Defining the Network\n",
"\n",
"We're going to use the standard four-layer convolutional network used in [\"Prototypical Networks for Few-shot Learning\"][proto]. We also have two MLPs, `g_phi` and `f_theta`, which play the roles of $g_{\\thetaB}$ and $f_{\\phiB}$ in the equations above.\n",
"\n",
"[proto]: https://arxiv.org/abs/1703.05175"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# %load ../oneshot/model.py\n",
"# Model building functions for prototypical or class conditional models\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow.contrib.slim as slim\n",
"\n",
"n_latent = 32\n",
"\n",
"def conv2d_bn_relu_maxpool(x, final=False, phase=True):\n",
"    \"\"\"One conv -> batchnorm -> ReLU -> max-pool block, using slim's batchnorm.\"\"\"\n",
"    x = slim.conv2d(x, 64, [3,3], activation_fn=None)\n",
"    x = slim.batch_norm(x, is_training=phase, updates_collections=None, center=True, scale=True)\n",
"    #if not final:\n",
"    x = tf.nn.relu(x)\n",
"    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='VALID')\n",
"\n",
"def encoder_bn(x, is_train, scope='encoder', reuse=False):\n",
"    with tf.variable_scope(scope, reuse=reuse) as varscope:\n",
"        #if reuse: varscope.reuse_variables()\n",
"        net = x\n",
"        for i in range(3):\n",
"            net = conv2d_bn_relu_maxpool(net, phase=is_train)\n",
"        net = conv2d_bn_relu_maxpool(net, phase=is_train, final=True)\n",
"        net = tf.contrib.layers.flatten(net)\n",
"        net = slim.fully_connected(net, n_latent)\n",
"        return net\n",
"\n",
"def g_phi(x, n_latent, is_train, reuse=False, scope='g_phi'):\n",
"    # batchnorm mode comes from the is_train argument, not the global `phase` placeholder\n",
"    with tf.variable_scope(scope, reuse=reuse) as varscope:\n",
"        #if reuse: varscope.reuse_variables()\n",
"        net = slim.fully_connected(x, n_latent, activation_fn=None)\n",
"        net = slim.batch_norm(net, is_training=is_train, updates_collections=None, center=True, scale=True)\n",
"        net = tf.nn.relu(net)\n",
"        net = slim.fully_connected(net, n_latent, activation_fn=None)\n",
"        net = slim.batch_norm(net, is_training=is_train, updates_collections=None, center=True, scale=True)\n",
"        net = tf.nn.relu(net)\n",
"        return net\n",
"\n",
"def f_theta(x, n_latent, is_train, reuse=False, scope='f_theta'):\n",
"    with tf.variable_scope(scope, reuse=reuse) as varscope:\n",
"        #if reuse: varscope.reuse_variables()\n",
"        net = slim.fully_connected(x, n_latent, activation_fn=None)\n",
"        net = slim.batch_norm(net, is_training=is_train, updates_collections=None, center=True, scale=True)\n",
"        net = tf.nn.relu(net)\n",
"        net = slim.fully_connected(net, 1, activation_fn=None)\n",
"        net = slim.batch_norm(net, is_training=is_train, updates_collections=None, center=True, scale=True)\n",
"        return net"
]
},
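{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check on the encoder's output size: `slim.conv2d` pads with `'SAME'` by default, so only the $2 \\times 2$, stride-2 `'VALID'` max-pools change the spatial size. Each pool maps a side of length $s$ to $\\lfloor (s - 2)/2 \\rfloor + 1$. For $28 \\times 28$ Omniglot images, the four blocks give\n",
"\n",
"$$\n",
"28 \\to 14 \\to 7 \\to 3 \\to 1,\n",
"$$\n",
"\n",
"so the flattened feature vector has $1 \\times 1 \\times 64 = 64$ entries. The final `slim.fully_connected` layer then projects it to `n_latent` $= 32$ dimensions, using slim's default ReLU activation."
]
},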
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Defining the Loss Functions\n",
"\n",
"Now we will incorporate the theory defined above. In training iterations we're going to be propagating through a probability that we use to sample Bernoulli-distributed random vectors. At training time we'll be sampling Bernoulli variables, but at test time we will *questionably* be using the mean values (i.e. the probabilities output by our network).\n",
"\n",
"## Placeholders\n",
"\n",
"We're going to have placeholders for a support set and a query set. The support will define the statistics that we use for classification on the query set, as in few-shot learning. *Unlike in few-shot learning* we have to also supply priors $\\alphaB$ and $\\betaB$."
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"tf.reset_default_graph()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"x = tf.placeholder(tf.float32, [mb_dim, None, x_dim, x_dim, 1],\n",
" name='SupportFeatures')\n",
"# placeholder for the labels\n",
"y = tf.placeholder(tf.int32, [mb_dim, None], name='SupportLabels')\n",
"temperature = tf.placeholder(tf.float32, (), name='Temperature')\n",
"# and pass query set through network as well\n",
"# query input placeholders\n",
"x_hat = tf.placeholder(tf.float32, [mb_dim, None, x_dim, x_dim, 1], name='QueryFeatures')\n",
"y_hat = tf.placeholder(tf.int32, [mb_dim, None], name='QueryLabels')\n",
"# and a placeholder for train/test phase\n",
"phase = tf.placeholder(tf.bool, None, 'Phase')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Making the network\n",
"\n",
"We have to pass one example through the network to initialise the variables, because we're going to be sharing parameters a lot, calling the function for losses on every example independently in the training batch."
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"encoder = encoder_bn\n",
"# pass through once for initialisation\n",
"init = encoder(x[0], is_train=phase)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"#mb_x_i,mb_y_i,mb_x_hat,mb_y_hat = get_minibatch(train_data, 2, y_dim)\n",
"#fd = {x_hat: mb_x_hat, y_hat: mb_y_hat, x: mb_x_i, y: mb_y_i,\n",
"# phase:True}\n",
"#first=True\n",
"#def get_losses(x, y, x_hat, y_hat, fd=fd):\n",
"def get_losses(x, y, x_hat, y_hat):\n",
"    # encode support set\n",
"    support_embedded = encoder(x, is_train=phase, reuse=True)\n",
"\n",
"    # pass x_hat through our inference network\n",
"    query_embedded = encoder(x_hat, is_train=phase, reuse=True)\n",
"\n",
"    # we will need y in a onehot encoding\n",
"    depth = tf.reduce_max(y)+1\n",
"    oh_y = tf.one_hot(y, depth)\n",
"\n",
"    ######################\n",
"    # Relational Network #\n",
"    ######################\n",
"    # count the number of support and query samples we're dealing with\n",
"    n_samples = tf.shape(x)[0]\n",
"    n_query = tf.shape(x_hat)[0]\n",
"    # we're going to need more dimensions\n",
"    def expand_more_dims(x, dims):\n",
"        for d in dims:\n",
"            x = tf.expand_dims(x, d)\n",
"        return x\n",
"    support_i = tf.expand_dims(support_embedded, 0) # (1, n_samples, n_latent)\n",
"    support_i = tf.pad(support_i, [[0,0],[0,0],[0,n_latent*2]]) # (1, n_samples, 3*n_latent)\n",
"    support_k = tf.expand_dims(support_embedded, 1) # (n_samples, 1, n_latent)\n",
"    support_k = tf.pad(support_k, [[0,0],[0,0],[n_latent,n_latent]]) # (n_samples, 1, 3*n_latent)\n",
"    # add them all up (easier than concat, since the padded slots don't overlap)\n",
"    support_ik = support_i+support_k # (n_samples, n_samples, 3*n_latent)\n",
"    # concatenate on the query points\n",
"    query_expanded = expand_more_dims(query_embedded, [0,0]) # (1, 1, n_query, n_latent)\n",
"    query_expanded = tf.pad(query_expanded, [[0,0],[0,0],[0,0],[2*n_latent,0]]) # (1, 1, n_query, 3*n_latent)\n",
"    support_ijk = tf.expand_dims(support_ik, 2) # (n_samples, n_samples, 1, 3*n_latent)\n",
"    support_concat = support_ijk+query_expanded # (n_samples, n_samples, n_query, 3*n_latent)\n",
"\n",
"    # reshape and pass through comparison function g_phi\n",
"    g_phi_inp = tf.reshape(support_concat, (-1, 3*n_latent))\n",
"    g_phi_op = g_phi(g_phi_inp, n_latent, is_train=phase, reuse=not first)\n",
"    # then reshape back to original shape\n",
"    g_phi_rs = tf.reshape(g_phi_op, (n_samples, n_samples, n_query, n_latent))\n",
"\n",
"    # zero out terms whose axis-1 support point is not in the class;\n",
"    # (n_samples, n_samples, n_classes, n_query, n_latent)\n",
"    class_sep = tf.expand_dims(g_phi_rs, 2)*expand_more_dims(oh_y, [0,-1,-1])\n",
"\n",
"    # sum out both support axes; (n_classes, n_query, n_latent)\n",
"    comparison_sum = tf.reduce_sum(class_sep, axis=[0,1])\n",
"\n",
"    # reshape and pass through f_theta function\n",
"    f_theta_inp = tf.reshape(comparison_sum, (-1, n_latent))\n",
"    f_theta_op = f_theta(f_theta_inp, n_latent, is_train=phase, reuse=not first)\n",
"    f_theta_rs = tf.reshape(f_theta_op, (depth, n_query, 1))\n",
"\n",
"    # drop the trailing axis and transpose to (n_query, n_classes) for softmax\n",
"    logits = tf.transpose(tf.squeeze(f_theta_rs, axis=2))\n",
"\n",
"    # need this one hot encoded with same depth\n",
"    query_oh_y = tf.one_hot(y_hat, depth)\n",
"    def query_log_loss_acc(logits):\n",
"        # then make predictions, to calculate accuracy and loss\n",
"        predictions = tf.nn.softmax(logits=logits)\n",
"\n",
"        log_p_yGz = tf.reduce_sum(query_oh_y*tf.log(predictions+1e-8),\n",
"                                  axis=1)\n",
"\n",
"        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(query_oh_y, 1))\n",
"        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
"        return log_p_yGz, accuracy\n",
"    query_log_p_yGz, query_accuracy = query_log_loss_acc(logits)\n",
"    mean_log_loss = -tf.reduce_mean(query_log_p_yGz, axis=0)\n",
"\n",
"    #sess = tf.InteractiveSession()\n",
"    #sess.run(tf.global_variables_initializer(), fd)\n",
"    #fd = {x_hat: mb_x_hat[0], y_hat: mb_y_hat[0], x: mb_x_i[0], y: mb_y_i[0],\n",
"    #      phase:True}\n",
"    #import pdb\n",
"    #pdb.set_trace()\n",
"\n",
"    return query_accuracy, mean_log_loss\n",
"#get_losses(x[0], y[0], x_hat[0], y_hat[0])"
]
},
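{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two identities explain the tensor manipulation in `get_losses` above.\n",
"\n",
"First, each input is zero-padded into its own slot of a vector of length $3 \\times$ `n_latent`. `support_i` fills the first slot and varies along axis 1, `support_k` fills the second and varies along axis 0, and `query_expanded` fills the third and varies along axis 2. Adding the padded vectors therefore equals concatenating them, and broadcasting the sum builds every triple at once, with shape `(n, n, n_q, 3*n_latent)` for $n$ support and $n_q$ query points. For embeddings $\\aB, \\bB, \\qB \\in \\mathbb{R}^{d}$ with $d$ = `n_latent`:\n",
"\n",
"$$\n",
"[\\aB; \\mathbf{0}; \\mathbf{0}] + [\\mathbf{0}; \\bB; \\mathbf{0}] + [\\mathbf{0}; \\mathbf{0}; \\qB] = [\\aB; \\bB; \\qB]\n",
"$$\n",
"\n",
"Second, multiplying by the one-hot labels `oh_y` and summing turns a sum over the whole support set into a per-class sum. Here $\\hB_l$ is a $g$ output whose masked support point is $l$, and $[\\cdot]$ is the indicator function:\n",
"\n",
"$$\n",
"\\sum_{l \\in S(k)} \\hB_l = \\sum_{l \\in \\SB} [y_l = k] \\, \\hB_l\n",
"$$"
]
},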
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"# iterate over independent classification problems\n",
"# if I try hard I may be able to make this more unreadable\n",
"losses = []\n",
"first = True\n",
"for p in range(mb_dim):\n",
"    losses.append(get_losses(x[p], y[p], x_hat[p], y_hat[p]))\n",
"    first=False\n",
"query_accuracy, mean_log_loss = \\\n",
" map(lambda x: tf.stack(x), zip(*losses))\n",
"query_accuracy = tf.reduce_mean(query_accuracy, axis=0)\n",
"mean_log_loss = tf.reduce_mean(mean_log_loss, axis=0)\n",
"learning_rate = 0.001\n",
"tf_lr = tf.placeholder(tf.float32, None, name='learning_rate')\n",
"tf.summary.scalar('learning_rate', tf_lr) # to track learning rate passed\n",
"all_params = tf.trainable_variables()\n",
"optimizer = \\\n",
" tf.train.AdamOptimizer(learning_rate=tf_lr).minimize(mean_log_loss, var_list=all_params)\n",
"# initialisation op\n",
"init = tf.global_variables_initializer() \n",
"# summary op\n",
"merged = tf.summary.merge_all()"
]
},
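{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training objective is the negative log-likelihood of the query labels (natural log, via `tf.log`). It is averaged over the query points of each episode and then over the `mb_dim` episodes. A classifier that guesses uniformly over the $K = 5$ classes gives a useful reference point. It assigns probability $1/5$ to every label, so its loss and accuracy are\n",
"\n",
"$$\n",
"-\\log \\tfrac{1}{5} = \\log 5 \\approx 1.609, \\qquad \\text{accuracy} = \\tfrac{1}{5} = 0.2.\n",
"$$\n",
"\n",
"This chance-level loss is what the initial `best_test_loss = 1.6` in the training cell below corresponds to."
]
},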
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training the Network\n",
"\n",
"Training the encoder, `g_phi` and `f_theta` together on few-shot classification on Omniglot."
]
},
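{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training cell below builds a cosine learning-rate schedule with `np.linspace(0.0, np.pi, NUM_ITERS)`. Write $T$ for `NUM_ITERS` and $\\eta_0$ for `learning_rate`. Entry $t = 0, \\dots, T-1$ of `lr_schedule` is then\n",
"\n",
"$$\n",
"\\eta_t = \\frac{\\eta_0}{2} \\left( 1 + \\cos \\frac{\\pi t}{T - 1} \\right).\n",
"$$\n",
"\n",
"The schedule starts at $\\eta_0 = 10^{-3}$, is roughly $5 \\times 10^{-4}$ halfway through the $T = 40000$ steps, and reaches $0$ at the last step."
]
},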
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "93247227dbcb4d64b79710d22cd55cc6"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"initializing the model...\n",
"Step1000\n",
"\tLoss: 1.362\tAccuracy: 0.547\tTest Loss: 1.605\tTest Accuracy: 0.330\n",
"Step2000\n",
"\tLoss: 0.643\tAccuracy: 0.782\tTest Loss: 1.557\tTest Accuracy: 0.310\n",
"Step3000\n",
"\tLoss: 0.493\tAccuracy: 0.832\tTest Loss: 1.551\tTest Accuracy: 0.400\n",
"Step4000\n",
"\tLoss: 0.409\tAccuracy: 0.861\tTest Loss: 1.599\tTest Accuracy: 0.260\n",
"Step5000\n",
"\tLoss: 0.355\tAccuracy: 0.879\tTest Loss: 1.518\tTest Accuracy: 0.490\n",
"Step6000\n",
"\tLoss: 0.313\tAccuracy: 0.894\tTest Loss: 1.507\tTest Accuracy: 0.370\n",
"Step7000\n",
"\tLoss: 0.294\tAccuracy: 0.899\tTest Loss: 1.427\tTest Accuracy: 0.570\n",
"Step8000\n",
"\tLoss: 0.268\tAccuracy: 0.908\tTest Loss: 1.448\tTest Accuracy: 0.560\n",
"Step9000\n",
"\tLoss: 0.263\tAccuracy: 0.911\tTest Loss: 1.456\tTest Accuracy: 0.520\n",
"Step10000\n",
"\tLoss: 0.251\tAccuracy: 0.914\tTest Loss: 1.458\tTest Accuracy: 0.480\n",
"Step11000\n",
"\tLoss: 0.237\tAccuracy: 0.918\tTest Loss: 1.514\tTest Accuracy: 0.330\n",
"Step12000\n",
"\tLoss: 0.230\tAccuracy: 0.921\tTest Loss: 1.237\tTest Accuracy: 0.710\n",
"Step13000\n",
"\tLoss: 0.226\tAccuracy: 0.923\tTest Loss: 1.400\tTest Accuracy: 0.550\n",
"Step14000\n",
"\tLoss: 0.220\tAccuracy: 0.924\tTest Loss: 1.379\tTest Accuracy: 0.550\n",
"Step15000\n",
"\tLoss: 0.213\tAccuracy: 0.927\tTest Loss: 1.355\tTest Accuracy: 0.510\n",
"Step16000\n",
"\tLoss: 0.197\tAccuracy: 0.931\tTest Loss: 1.193\tTest Accuracy: 0.810\n",
"Step17000\n",
"\tLoss: 0.199\tAccuracy: 0.930\tTest Loss: 1.548\tTest Accuracy: 0.360\n",
"Step18000\n",
"\tLoss: 0.191\tAccuracy: 0.934\tTest Loss: 1.443\tTest Accuracy: 0.430\n",
"Step19000\n",
"\tLoss: 0.185\tAccuracy: 0.935\tTest Loss: 1.388\tTest Accuracy: 0.620\n",
"Step20000\n",
"\tLoss: 0.182\tAccuracy: 0.936\tTest Loss: 1.201\tTest Accuracy: 0.860\n",
"Step21000\n",
"\tLoss: 0.174\tAccuracy: 0.940\tTest Loss: 1.421\tTest Accuracy: 0.490\n",
"Step22000\n",
"\tLoss: 0.169\tAccuracy: 0.941\tTest Loss: 1.163\tTest Accuracy: 0.880\n",
"Step23000\n",
"\tLoss: 0.166\tAccuracy: 0.942\tTest Loss: 1.232\tTest Accuracy: 0.800\n",
"Step24000\n",
"\tLoss: 0.159\tAccuracy: 0.944\tTest Loss: 1.135\tTest Accuracy: 0.820\n",
"Step25000\n",
"\tLoss: 0.158\tAccuracy: 0.945\tTest Loss: 1.252\tTest Accuracy: 0.660\n",
"Step26000\n",
"\tLoss: 0.152\tAccuracy: 0.947\tTest Loss: 1.208\tTest Accuracy: 0.820\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-48-825312211e62>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;31m# train for one step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0msummary\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmerged\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquery_accuracy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmean_log_loss\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;31m# accumulate for averages and write immediate summary to logs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mwriter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_summary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msummary\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexample_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/disk/scratch/deepvenv/miniconda2/envs/oneshot/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 777\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 778\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 779\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 780\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/disk/scratch/deepvenv/miniconda2/envs/oneshot/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 980\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 981\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 982\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 983\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 984\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/disk/scratch/deepvenv/miniconda2/envs/oneshot/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1030\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1031\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1032\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 1033\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1034\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n",
"\u001b[0;32m/disk/scratch/deepvenv/miniconda2/envs/oneshot/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1037\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1038\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1039\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1040\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1041\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/disk/scratch/deepvenv/miniconda2/envs/oneshot/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1019\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[1;32m 1020\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1021\u001b[0;31m status, run_metadata)\n\u001b[0m\u001b[1;32m 1022\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"sess = tf.InteractiveSession()\n",
"\n",
"# initialise\n",
"CKPT_FILE = None\n",
"writer = tf.summary.FileWriter(FLAGS.summary_dir,graph=sess.graph)\n",
"saver = tf.train.Saver()\n",
"NUM_ITERS = 40000\n",
"wm = 0.0\n",
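"example_idx = 0\n",
"# cosine-annealed learning rate: decays from learning_rate to 0 over NUM_ITERS steps\n",
"lr_schedule = learning_rate*0.5*(np.cos(np.linspace(0.0, np.pi, NUM_ITERS))+1.0)\n",
"# best loss seen so far: start from roughly the log-loss of naive (uniform) guessing\n",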
"best_test_loss = 1.6\n",
"test_loss_avg = 1.7\n",
"\n",
"# initialise rng\n",
"rng = np.random.RandomState(42)\n",
"\n",
"# initialise accumulators\n",
"tr_acc_avg, tr_loss_avg = 0., 0.\n",
"tr_n_batches = 0.\n",
"progress = tqdm(range(NUM_ITERS))\n",
"for i in progress:\n",
"    # on the first iteration, initialise (or restore) the model\n",
"    if i == 0:\n",
"        mb_x_i,mb_y_i,mb_x_hat,mb_y_hat = get_minibatch(train_data, 2, y_dim)\n",
"        fd = {x_hat: mb_x_hat, y_hat: mb_y_hat, x: mb_x_i,\n",
"              y: mb_y_i,\n",
"              phase:True}\n",
"        print('initializing the model...')\n",
"        sess.run(init, fd)\n",
"        if CKPT_FILE is not None:\n",
"            print('restoring parameters from %s'%CKPT_FILE)\n",
"            saver.restore(sess, CKPT_FILE)\n",
"\n",
"    # helper that writes a scalar summary at the current step i\n",
"    def write_scalar_summary(name, scalar_value):\n",
"        summary = tf.Summary()\n",
"        summary.value.add(tag=name, simple_value=scalar_value)\n",
"        writer.add_summary(summary, i)\n",
"        return None\n",
"\n",
"    # get a training minibatch\n",
"    #tr_batch_size = rng.randint(1,4)\n",
"    tr_batch_size = 2\n",
"    mb_x_i,mb_y_i,mb_x_hat,mb_y_hat = get_minibatch(train_data, tr_batch_size, 4*y_dim)\n",
"    np_lr = lr_schedule[i]\n",
"    fd = {x_hat: mb_x_hat, y_hat: mb_y_hat, x: mb_x_i, y: mb_y_i,\n",
"          tf_lr: np_lr,\n",
"          phase:True}\n",
"\n",
"    # train for one step\n",
"    summary, _, tr_acc, tr_loss = \\\n",
"        sess.run([merged, optimizer, query_accuracy, mean_log_loss], fd)\n",
"    # accumulate for averages and write immediate summary to logs\n",
"    writer.add_summary(summary, example_idx)\n",
"    tr_acc_avg += tr_acc\n",
"    tr_loss_avg += tr_loss\n",
"    tr_n_batches += 1.\n",
"    example_idx += 1\n",
"\n",
"    if (i+1) % 1000 == 0:\n",
"        # run on holdout set, with batch norm in inference mode (phase=False)\n",
"        test_acc_avg, test_loss_avg = 0., 0.\n",
"        n_batches = 0.\n",
"        for j in range(10):\n",
"            mb_x_i,mb_y_i,mb_x_hat,mb_y_hat = get_minibatch(val_data, 1, y_dim)\n",
"            fd = {x_hat: mb_x_hat, y_hat: mb_y_hat, x: mb_x_i, y: mb_y_i,\n",
"                  tf_lr: np_lr,\n",
"                  phase:False}\n",
"            test_acc, test_loss = sess.run([query_accuracy,\n",
"                                            mean_log_loss], fd)\n",
"            test_acc_avg += test_acc\n",
"            test_loss_avg += test_loss\n",
"            n_batches += 1.\n",
"\n",
"        test_acc_avg, test_loss_avg = map(lambda v: v/n_batches,\n",
"                                          [test_acc_avg, test_loss_avg])\n",
"        tr_acc_avg, tr_loss_avg = map(lambda v: v/tr_n_batches,\n",
"                                      [tr_acc_avg, tr_loss_avg])\n",
"        desc = 'Loss: %0.3f\\tAccuracy: %0.3f'%(tr_loss_avg, tr_acc_avg)\n",
"        desc +='\\tTest Loss: %0.3f\\tTest Accuracy: %0.3f'%(test_loss_avg, test_acc_avg)\n",
"        print(\"Step%i\\n\\t\"%(i+1)+desc)\n",
"        # then log these results\n",
"        write_scalar_summary(\"Test Loss\", test_loss_avg)\n",
"        write_scalar_summary(\"Test Accuracy\", test_acc_avg)\n",
"\n",
"        # training traces\n",
"        write_scalar_summary(\"Train Loss\", tr_loss_avg)\n",
"        write_scalar_summary(\"Train Accuracy\", tr_acc_avg)\n",
"\n",
"        # reset accumulators\n",
"        tr_acc_avg, tr_loss_avg = 0., 0.\n",
"        tr_n_batches = 0.\n",
"\n",
"    # show training loss and accuracy on progress bar\n",
"    progress.set_postfix(loss=tr_loss, acc=tr_acc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
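"The problem with the holdout evaluation appears to be caused by batch norm: it disappears if batch norm is run in training mode at test time (`phase: True`, which presumably makes the layers normalise with the current batch's statistics instead of their moving averages), as in the cell below. A likely cause is that the batch size seen by the $f_{\\thetaB}$ and $g_{\\phiB}$ networks differs between training (`get_minibatch(train_data, 2, 4*y_dim)`) and test time (`get_minibatch(val_data, 1, y_dim)`), so the moving averages collected during training would not match the statistics of the test-time batches."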
]
},
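{
"cell_type": "markdown",
"metadata": {},
"source": [
"The batch-size hypothesis above can be illustrated outside TensorFlow. The next cell is a minimal NumPy sketch, not the batch-norm layer used in this notebook: `batch_norm` and `train_running_stats` are hypothetical helpers, the learned scale/offset is omitted, and the features are assumed to be drawn from $\\mathcal{N}(5, 2^2)$. In training mode the sketch normalises with the current batch's mean and biased variance and updates moving averages; in inference mode it uses the moving averages. Since the biased variance of a batch of $n$ samples has expectation $\\sigma^2 (n-1)/n$, moving averages collected from very small batches should underestimate the variance (by about half for $n=2$), so inference-mode activations would come out larger than anything the next layer saw during training. This is one mechanism consistent with the note above, not a confirmed diagnosis."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def batch_norm(x, running_mean, running_var, training, momentum=0.99, eps=1e-3):\n",
"    # Hypothetical minimal batch norm over axis 0, without learned scale/offset.\n",
"    # training=True: normalise with this batch's mean and biased variance,\n",
"    #                and update the moving averages.\n",
"    # training=False: normalise with the moving averages.\n",
"    if training:\n",
"        mean, var = x.mean(axis=0), x.var(axis=0)\n",
"        running_mean = momentum * running_mean + (1 - momentum) * mean\n",
"        running_var = momentum * running_var + (1 - momentum) * var\n",
"    else:\n",
"        mean, var = running_mean, running_var\n",
"    return (x - mean) / np.sqrt(var + eps), running_mean, running_var\n",
"\n",
"\n",
"def train_running_stats(batch_size, n_steps=3000, n_features=50, seed=0):\n",
"    # Hypothetical training run: feed batches drawn from N(5, 2^2) through\n",
"    # batch_norm in training mode and return the final moving averages.\n",
"    rng = np.random.RandomState(seed)\n",
"    running_mean, running_var = np.zeros(n_features), np.ones(n_features)\n",
"    for _ in range(n_steps):\n",
"        batch = rng.normal(5.0, 2.0, size=(batch_size, n_features))\n",
"        _, running_mean, running_var = batch_norm(\n",
"            batch, running_mean, running_var, training=True)\n",
"    return running_mean, running_var\n",
"\n",
"\n",
"# The moving variance should land near 4 * (n - 1) / n: about 2 for n = 2.\n",
"for bs in (2, 8, 64):\n",
"    m, v = train_running_stats(bs)\n",
"    print('batch size %2d: moving mean %.2f, moving var %.2f (true var 4.00)'\n",
"          % (bs, m.mean(), v.mean()))\n",
"\n",
"# With moving stats from size-2 batches, inference-mode outputs should have a\n",
"# std near 2 / sqrt(2), while training-mode outputs are normalised to about 1.\n",
"m, v = train_running_stats(2)\n",
"test_batch = np.random.RandomState(1).normal(5.0, 2.0, size=(256, 50))\n",
"out_train, _, _ = batch_norm(test_batch, m, v, training=True)\n",
"out_infer, _, _ = batch_norm(test_batch, m, v, training=False)\n",
"print('std of outputs: training mode %.2f, inference mode %.2f'\n",
"      % (out_train.std(), out_infer.std()))"
]
},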
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a07d547e50884256a8a1ee794dd3627f"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Step1000\n",
"\tLoss: 0.187\tAccuracy: 0.935\tTest Loss: 0.093\tTest Accuracy: 0.950\n"
]
}
],
"source": [
"# initialise accumulators\n",
"tr_acc_avg, tr_loss_avg = 0., 0.\n",
"tr_n_batches = 0.\n",
"progress = tqdm(range(NUM_ITERS))\n",
"for i in progress:\n",
"    # helper that writes a scalar summary at the current step i\n",
"    def write_scalar_summary(name, scalar_value):\n",
"        summary = tf.Summary()\n",
"        summary.value.add(tag=name, simple_value=scalar_value)\n",
"        writer.add_summary(summary, i)\n",
"        return None\n",
"\n",
"    # get a training minibatch\n",
"    #tr_batch_size = rng.randint(1,4)\n",
"    tr_batch_size = 2\n",
"    mb_x_i,mb_y_i,mb_x_hat,mb_y_hat = get_minibatch(train_data, tr_batch_size, 4*y_dim)\n",
"    np_lr = lr_schedule[i]\n",
"    fd = {x_hat: mb_x_hat, y_hat: mb_y_hat, x: mb_x_i, y: mb_y_i,\n",
"          tf_lr: np_lr,\n",
"          phase:True}\n",
"\n",
"    # train for one step\n",
"    summary, _, tr_acc, tr_loss = \\\n",
"        sess.run([merged, optimizer, query_accuracy, mean_log_loss], fd)\n",
"    # accumulate for averages and write immediate summary to logs\n",
"    writer.add_summary(summary, example_idx)\n",
"    tr_acc_avg += tr_acc\n",
"    tr_loss_avg += tr_loss\n",
"    tr_n_batches += 1.\n",
"    example_idx += 1\n",
"\n",
"    if (i+1) % 1000 == 0:\n",
"        # run on holdout set, but with batch norm in training mode (phase=True)\n",
"        # to test the batchnorm hypothesis above\n",
"        test_acc_avg, test_loss_avg = 0., 0.\n",
"        n_batches = 0.\n",
"        for j in range(10):\n",
"            mb_x_i,mb_y_i,mb_x_hat,mb_y_hat = get_minibatch(val_data, 1, y_dim)\n",
"            fd = {x_hat: mb_x_hat, y_hat: mb_y_hat, x: mb_x_i, y: mb_y_i,\n",
"                  tf_lr: np_lr,\n",
"                  phase:True}\n",
"            test_acc, test_loss = sess.run([query_accuracy,\n",
"                                            mean_log_loss], fd)\n",
"            test_acc_avg += test_acc\n",
"            test_loss_avg += test_loss\n",
"            n_batches += 1.\n",
"\n",
"        test_acc_avg, test_loss_avg = map(lambda v: v/n_batches,\n",
"                                          [test_acc_avg, test_loss_avg])\n",
"        tr_acc_avg, tr_loss_avg = map(lambda v: v/tr_n_batches,\n",
"                                      [tr_acc_avg, tr_loss_avg])\n",
"        desc = 'Loss: %0.3f\\tAccuracy: %0.3f'%(tr_loss_avg, tr_acc_avg)\n",
"        desc +='\\tTest Loss: %0.3f\\tTest Accuracy: %0.3f'%(test_loss_avg, test_acc_avg)\n",
"        print(\"Step%i\\n\\t\"%(i+1)+desc)\n",
"        # then log these results\n",
"        write_scalar_summary(\"Test Loss\", test_loss_avg)\n",
"        write_scalar_summary(\"Test Accuracy\", test_acc_avg)\n",
"\n",
"        # training traces\n",
"        write_scalar_summary(\"Train Loss\", tr_loss_avg)\n",
"        write_scalar_summary(\"Train Accuracy\", tr_acc_avg)\n",
"\n",
"        # reset accumulators\n",
"        tr_acc_avg, tr_loss_avg = 0., 0.\n",
"        tr_n_batches = 0.\n",
"\n",
"    # show training loss and accuracy on progress bar\n",
"    progress.set_postfix(loss=tr_loss, acc=tr_acc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# run on holdout set with batch norm in inference mode (phase=False)\n",
"test_acc_avg, test_loss_avg = 0., 0.\n",
"n_batches = 0.\n",
"for j in range(100):\n",
"    mb_x_i,mb_y_i,mb_x_hat,mb_y_hat = get_minibatch(val_data, 1, y_dim)\n",
"    fd = {x_hat: mb_x_hat, y_hat: mb_y_hat, x: mb_x_i, y: mb_y_i,\n",
"          tf_lr: np_lr,\n",
"          phase:False}\n",
"    test_acc, test_loss = sess.run([query_accuracy,\n",
"                                    mean_log_loss], fd)\n",
"    test_acc_avg += test_acc\n",
"    test_loss_avg += test_loss\n",
"    n_batches += 1.\n",
"\n",
"test_acc_avg, test_loss_avg = map(lambda v: v/n_batches,\n",
"                                  [test_acc_avg, test_loss_avg])\n",
"print('Test Loss: %0.3f\\tTest Accuracy: %0.3f'%(test_loss_avg, test_acc_avg))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"!mkdir -p /disk/scratch/gavin/models/simpler_relational/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# save a model checkpoint, refusing to overwrite an existing one\n",
"savedir = os.path.join(\"/disk/scratch/gavin/models/simpler_relational/\",\"simpler_relational.ckpt\")\n",
"if os.path.exists(savedir+\".index\"):\n",
"    raise OSError(\"Checkpoint already exists at %s; refusing to overwrite it\"%savedir)\n",
"best_test_loss = test_loss_avg\n",
"save_path = saver.save(sess, savedir)\n",
"print(\"saved model checkpoint to: %s\"%save_path)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}