Skip to content

Instantly share code, notes, and snippets.

@AidanRocke
Created August 20, 2018 12:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AidanRocke/0437b624b68337aa0de33d7810146316 to your computer and use it in GitHub Desktop.
Save AidanRocke/0437b624b68337aa0de33d7810146316 to your computer and use it in GitHub Desktop.
Correct example of gradient accumulation in TensorFlow 1.x
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"source": [
"from tensorflow.examples.tutorials.mnist import input_data\n",
"mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"WARNING:tensorflow:From <ipython-input-1-8bf8ae5a5303>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
"WARNING:tensorflow:From /Users/aidanrockea/anaconda/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please write your own downloading logic.\n",
"WARNING:tensorflow:From /Users/aidanrockea/anaconda/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.data to implement this functionality.\n",
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
"WARNING:tensorflow:From /Users/aidanrockea/anaconda/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.data to implement this functionality.\n",
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
"WARNING:tensorflow:From /Users/aidanrockea/anaconda/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.one_hot on tensors.\n",
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
"WARNING:tensorflow:From /Users/aidanrockea/anaconda/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n"
]
}
],
"execution_count": 1,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"#!/usr/bin/env python3\n",
"# -*- coding: utf-8 -*-\n",
"\"\"\"\n",
"Created on Sat May 5 23:14:10 2018\n",
"\n",
"@author: aidanrocke\n",
"\"\"\"\n",
"\n",
"import tensorflow as tf\n",
"\n",
"class mnist_network:\n",
" \n",
" def __init__(self,seed):\n",
" self.seed = seed\n",
" self.x = tf.placeholder(tf.float32, [None, 784])\n",
" self.y_ = tf.placeholder(tf.float32, [None, 10])\n",
" self.y = self.mnist_net()\n",
" self.cross_entropy = tf.reduce_mean(-tf.reduce_sum(self.y_ * tf.log(self.y), reduction_indices=[1]))\n",
"\n",
" self.correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.y_, 1))\n",
" self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))\n",
"\n",
" ## get trainable variables:\n",
" self.TV = tf.get_collection(key = tf.GraphKeys.TRAINABLE_VARIABLES,\n",
" scope= \"mnist_net\")\n",
" \n",
" ## define training operations:\n",
" self.optimizer = tf.train.AdagradOptimizer(0.01)\n",
" self.accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in self.TV] \n",
" self.zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in self.accum_vars]\n",
" self.gvs = self.optimizer.compute_gradients(self.cross_entropy, self.TV)\n",
" self.accum_ops = [self.accum_vars[i].assign_add(gv[0]) for i, gv in enumerate(self.gvs)]\n",
" self.train_step = self.optimizer.apply_gradients([(self.accum_vars[i], gv[1]) for i, gv in enumerate(self.gvs)])\n",
"\n",
" def init_weights(self,shape,var_name):\n",
" \"\"\"\n",
" Xavier initialisation of neural networks\n",
" \"\"\"\n",
" initializer = tf.contrib.layers.xavier_initializer(seed=self.seed)\n",
" \n",
" return tf.Variable(initializer(shape),name = var_name)\n",
" \n",
" def mnist_net(self):\n",
" \n",
" with tf.variable_scope(\"mnist_net\"):\n",
" \n",
" tf.set_random_seed(self.seed)\n",
" \n",
" w_h = self.init_weights([784, 1200],\"W_h\")\n",
" w_h2 = self.init_weights([1200, 1200],\"W_h2\")\n",
" w_h3 = self.init_weights([1200,10],\"W_h3\")\n",
" \n",
" # define bias terms:\n",
" bias_1 = self.init_weights([1200],\"bias_1\")\n",
" bias_2 = self.init_weights([1200],\"bias_2\")\n",
" bias_3 = self.init_weights([10],\"bias_3\")\n",
" \n",
" h = tf.nn.elu(tf.add(tf.matmul(self.x, w_h),bias_1))\n",
" h2 = tf.nn.elu(tf.add(tf.matmul(h, w_h2),bias_2))\n",
" \n",
" return tf.nn.softmax(tf.add(tf.matmul(h2, w_h3),bias_3))\n"
],
"outputs": [],
"execution_count": 2,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"batch_size = 10\n",
"\n",
"tf.reset_default_graph()\n",
"\n",
"model = mnist_network(42)\n",
"\n",
"with tf.Session() as sess:\n",
" \n",
" ### initialise the variables:\n",
" sess.run(tf.global_variables_initializer())\n",
" \n",
" for i in range(700):\n",
" \n",
" sess.run(model.zero_ops)\n",
"\n",
" for j in range(5):\n",
"\n",
" input_images, correct_predictions = mnist.train.next_batch(batch_size)\n",
"\n",
" train_feed = {model.x: np.reshape(input_images,[-1,784]), model.y_: np.reshape(correct_predictions,[-1,10])}\n",
"\n",
" sess.run(model.accum_ops,feed_dict = train_feed)\n",
" \n",
" # check accuracy:\n",
" if i % 50 == 0:\n",
" train_accuracy = sess.run(model.accuracy, feed_dict={model.x: np.reshape(input_images,[-1,784]), model.y_: np.reshape(correct_predictions,[-1,10])})\n",
" print(\"step %d, training accuracy %.2f\" % (i, train_accuracy))\n",
" \n",
" sess.run(model.train_step)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"step 0, training accuracy 0.00\n",
"step 50, training accuracy 0.90\n",
"step 100, training accuracy 1.00\n",
"step 150, training accuracy 0.80\n",
"step 200, training accuracy 0.90\n",
"step 250, training accuracy 1.00\n",
"step 300, training accuracy 0.80\n",
"step 350, training accuracy 0.90\n",
"step 400, training accuracy 0.90\n",
"step 450, training accuracy 0.90\n",
"step 500, training accuracy 1.00\n",
"step 550, training accuracy 0.90\n",
"step 600, training accuracy 0.90\n",
"step 650, training accuracy 1.00\n"
]
}
],
"execution_count": 3,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [],
"outputs": [],
"execution_count": null,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
}
],
"metadata": {
"kernel_info": {
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.6.1",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3"
},
"nteract": {
"version": "0.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment