Skip to content

Instantly share code, notes, and snippets.

@julie-is-late
Created October 24, 2016 23:14
Show Gist options
  • Save julie-is-late/4012ccff769bc92cb188370f3ec729cb to your computer and use it in GitHub Desktop.
Save julie-is-late/4012ccff769bc92cb188370f3ec729cb to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import pandas as pd\n",
"from sklearn.preprocessing import scale, OneHotEncoder\n",
"from sklearn.metrics import accuracy_score\n",
"import time, math"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def MinBatch(batch_size, n):\n",
"    \"\"\"Partition a random permutation of range(n) into mini-batches.\n",
"\n",
"    Returns an object array of ceil(n / batch_size) int index arrays;\n",
"    every batch holds batch_size indices except possibly the last.\n",
"    \"\"\"\n",
"    ix = np.random.permutation(n)\n",
"    num_batches = math.ceil(n / batch_size)\n",
"    k = np.empty([num_batches], dtype=object)\n",
"    for y in range(num_batches):\n",
"        # Slice the permutation directly instead of growing each batch one\n",
"        # element at a time with np.append (which copies the array per call).\n",
"        k[y] = ix[y * batch_size : (y + 1) * batch_size]\n",
"    return k"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"MNIST = np.load('../data/MNIST_train_40000.npz')\n",
"\n",
"images_train = MNIST['train_images']\n",
"labels_train = MNIST['train_labels']\n",
"\n",
"inputs_train = images_train.reshape(-1,images_train.shape[1]*images_train.shape[2])\n",
"\n",
"# one-hot-encode labels\n",
"outputs_train = OneHotEncoder(sparse=False).fit_transform(labels_train.reshape(-1,1))\n",
"# outputs_train = pd.get_dummies(labels_train).values"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def inittensors():\n",
"    \"\"\"Build a 784 -> 1024 -> 10 ReLU network for MNIST.\n",
"\n",
"    Returns (session, optimizer op, cross-entropy op, softmax prediction\n",
"    op, input placeholder, label placeholder), with variables already\n",
"    initialized in the session.\n",
"    \"\"\"\n",
"    layer_node_count = 1024\n",
"\n",
"    x = tf.placeholder(tf.float32, shape=[None, inputs_train.shape[1]])\n",
"    y = tf.placeholder(tf.float32, shape=[None, 10])\n",
"\n",
"    # Small truncated-normal init; fixed seed keeps runs reproducible.\n",
"    w1 = tf.Variable(tf.truncated_normal([inputs_train.shape[1],layer_node_count], stddev=0.1, seed=0))\n",
"    b1 = tf.Variable(tf.truncated_normal([1,layer_node_count], stddev=0.1, seed=0))\n",
"\n",
"    w2 = tf.Variable(tf.truncated_normal([layer_node_count,10], stddev=0.1, seed=0))\n",
"    b2 = tf.Variable(tf.truncated_normal([1,10], stddev=0.1, seed=0))\n",
"\n",
"    layer1 = tf.nn.relu(tf.matmul(x,w1) + b1)\n",
"    layer2 = tf.matmul(layer1,w2) + b2\n",
"\n",
"    # Keyword arguments: the positional order of this op changed between\n",
"    # TF releases, so name them explicitly to avoid swapping logits/labels.\n",
"    CE = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=layer2, labels=y))\n",
"\n",
"    optimizer = tf.train.AdamOptimizer().minimize(CE)\n",
"\n",
"    # Reuse the logits node instead of rebuilding the whole forward pass\n",
"    # (the original duplicated every matmul/relu as new graph nodes).\n",
"    y_pred = tf.nn.softmax(layer2)\n",
"\n",
"    # initialization of variables\n",
"    # NOTE(review): initialize_all_variables() is deprecated from TF 0.12;\n",
"    # switch to tf.global_variables_initializer() when upgrading.\n",
"    init = tf.initialize_all_variables()\n",
"\n",
"    # initialize a computation session\n",
"    sess = tf.Session()\n",
"    sess.run(init)\n",
"    return sess, optimizer, CE, y_pred, x, y"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def runBatch(BATCH_SIZE):\n",
" sess, optimizer, CE, y_pred, x, y = inittensors()\n",
" np.random.seed(0)\n",
" \n",
" runtime = 0\n",
" last_full_run = time.time()\n",
" \n",
" epoch = 0\n",
" ce_full = float(\"inf\")\n",
" prev_epoch = -1\n",
" \n",
" done = False\n",
" \n",
" while not done:\n",
" curr_batch = MinBatch(BATCH_SIZE, images_train.shape[0])\n",
" \n",
" for batch in range(curr_batch.shape[0]):\n",
" start = time.time()\n",
" sess.run([optimizer],\n",
" feed_dict={ x : inputs_train[curr_batch[batch]],\n",
" y : outputs_train[curr_batch[batch]] })\n",
" runtime += time.time() - start\n",
" \n",
" curr_time = time.time()\n",
" if (curr_time - last_full_run) / 60 > 0.1:\n",
" last_full_run = curr_time\n",
" ce_full = sess.run(CE,\n",
" feed_dict={ x : inputs_train,\n",
" y : outputs_train })\n",
" \n",
" if epoch != prev_epoch:\n",
" print('\\t epoch = %d: ' % (epoch))\n",
" prev_epoch = epoch\n",
" \n",
" print('\\t\\t batch = %6d -' % (batch),\n",
" 'cross entropy = %.5f' % (ce_full))\n",
" \n",
" if (ce_full < 0.01):\n",
" done = True\n",
" break\n",
" \n",
" epoch += 1\n",
"\n",
" \n",
" runtime = runtime / 60\n",
" \n",
" y_p, ce_full = sess.run([y_pred, CE],\n",
" feed_dict={ x : inputs_train,\n",
" y : outputs_train })\n",
"\n",
" labels_train_pred = np.argmax(y_p, axis=1)\n",
" err_train = 1 - accuracy_score(np.argmax(outputs_train, axis=1), labels_train_pred)\n",
" \n",
" print('num epochs = %d: \\n' % (epoch) +\n",
" ' training cross-entropy = %-10.5f \\n' % (ce_full) + \n",
" ' training error rate = %-10.5f \\n' % (err_train) + \n",
" ' total runtime (minutes) = %-10.2f' % runtime)\n",
" \n",
" print('Done!')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batchsize: 10\n",
"\t epoch = 0: \n",
"\t\t batch = 1322 - cross entropy = 0.19082\n",
"\t\t batch = 2520 - cross entropy = 0.10464\n",
"\t\t batch = 3679 - cross entropy = 0.09296\n",
"\t epoch = 1: \n",
"\t\t batch = 794 - cross entropy = 0.06923\n",
"\t\t batch = 1969 - cross entropy = 0.07965\n",
"\t\t batch = 3150 - cross entropy = 0.07305\n",
"\t epoch = 2: \n",
"\t\t batch = 230 - cross entropy = 0.04002\n",
"\t\t batch = 1380 - cross entropy = 0.03867\n",
"\t\t batch = 2535 - cross entropy = 0.05152\n",
"\t\t batch = 3675 - cross entropy = 0.03350\n",
"\t epoch = 3: \n",
"\t\t batch = 728 - cross entropy = 0.03371\n",
"\t\t batch = 1823 - cross entropy = 0.05332\n",
"\t\t batch = 2962 - cross entropy = 0.03908\n",
"\t epoch = 4: \n",
"\t\t batch = 17 - cross entropy = 0.02980\n",
"\t\t batch = 1150 - cross entropy = 0.02370\n",
"\t\t batch = 2280 - cross entropy = 0.03139\n",
"\t\t batch = 3397 - cross entropy = 0.02774\n",
"\t epoch = 5: \n",
"\t\t batch = 441 - cross entropy = 0.02800\n",
"\t\t batch = 1558 - cross entropy = 0.01768\n",
"\t\t batch = 2673 - cross entropy = 0.01929\n",
"\t\t batch = 3793 - cross entropy = 0.02150\n",
"\t epoch = 6: \n",
"\t\t batch = 831 - cross entropy = 0.01521\n",
"\t\t batch = 1946 - cross entropy = 0.01594\n",
"\t\t batch = 3029 - cross entropy = 0.03753\n",
"\t epoch = 7: \n",
"\t\t batch = 72 - cross entropy = 0.02615\n",
"\t\t batch = 1178 - cross entropy = 0.01346\n",
"\t\t batch = 2276 - cross entropy = 0.01034\n",
"\t\t batch = 3377 - cross entropy = 0.02299\n",
"\t epoch = 8: \n",
"\t\t batch = 397 - cross entropy = 0.03562\n",
"\t\t batch = 1494 - cross entropy = 0.01369\n",
"\t\t batch = 2586 - cross entropy = 0.02296\n",
"\t\t batch = 3676 - cross entropy = 0.01816\n",
"\t epoch = 9: \n",
"\t\t batch = 662 - cross entropy = 0.01188\n",
"\t\t batch = 1660 - cross entropy = 0.02637\n",
"\t\t batch = 2736 - cross entropy = 0.01287\n",
"\t\t batch = 3815 - cross entropy = 0.01446\n",
"\t epoch = 10: \n",
"\t\t batch = 823 - cross entropy = 0.00638\n",
"num epochs = 11: \n",
" training cross-entropy = 0.00638 \n",
" training error rate = 0.00208 \n",
" total runtime (minutes) = 3.17 \n",
"Done!\n",
"batchsize: 100\n",
"\t epoch = 1: \n",
"\t\t batch = 210 - cross entropy = 0.09613\n",
"\t epoch = 2: \n",
"\t\t batch = 369 - cross entropy = 0.03662\n",
"\t epoch = 4: \n",
"\t\t batch = 75 - cross entropy = 0.01671\n",
"\t epoch = 5: \n",
"\t\t batch = 227 - cross entropy = 0.01327\n",
"\t epoch = 6: \n",
"\t\t batch = 379 - cross entropy = 0.00814\n",
"num epochs = 7: \n",
" training cross-entropy = 0.00814 \n",
" training error rate = 0.00162 \n",
" total runtime (minutes) = 0.41 \n",
"Done!\n",
"batchsize: 1000\n",
"\t epoch = 2: \n",
"\t\t batch = 5 - cross entropy = 0.16706\n",
"\t epoch = 4: \n",
"\t\t batch = 3 - cross entropy = 0.10234\n",
"\t epoch = 6: \n",
"\t\t batch = 2 - cross entropy = 0.06757\n",
"\t epoch = 8: \n",
"\t\t batch = 1 - cross entropy = 0.04785\n",
"\t epoch = 10: \n",
"\t\t batch = 1 - cross entropy = 0.03324\n",
"\t epoch = 12: \n",
"\t\t batch = 0 - cross entropy = 0.02497\n",
"\t epoch = 14: \n",
"\t\t batch = 0 - cross entropy = 0.01869\n",
"\t epoch = 16: \n",
"\t\t batch = 0 - cross entropy = 0.01453\n",
"\t epoch = 18: \n",
"\t\t batch = 0 - cross entropy = 0.01089\n",
"\t epoch = 20: \n",
"\t\t batch = 1 - cross entropy = 0.00861\n",
"num epochs = 21: \n",
" training cross-entropy = 0.00861 \n",
" training error rate = 0.00018 \n",
" total runtime (minutes) = 0.77 \n",
"Done!\n",
"batchsize: 10000\n",
"\t epoch = 2: \n",
"\t\t batch = 0 - cross entropy = 0.59122\n",
"\t epoch = 4: \n",
"\t\t batch = 1 - cross entropy = 0.37145\n",
"\t epoch = 6: \n",
"\t\t batch = 1 - cross entropy = 0.30379\n",
"\t epoch = 8: \n",
"\t\t batch = 1 - cross entropy = 0.26197\n",
"\t epoch = 10: \n",
"\t\t batch = 1 - cross entropy = 0.22950\n",
"\t epoch = 12: \n",
"\t\t batch = 1 - cross entropy = 0.20526\n",
"\t epoch = 14: \n",
"\t\t batch = 1 - cross entropy = 0.18593\n",
"\t epoch = 16: \n",
"\t\t batch = 2 - cross entropy = 0.16816\n",
"\t epoch = 18: \n",
"\t\t batch = 2 - cross entropy = 0.15494\n",
"\t epoch = 20: \n",
"\t\t batch = 3 - cross entropy = 0.14198\n",
"\t epoch = 22: \n",
"\t\t batch = 3 - cross entropy = 0.13179\n",
"\t epoch = 25: \n",
"\t\t batch = 0 - cross entropy = 0.12163\n",
"\t epoch = 27: \n",
"\t\t batch = 1 - cross entropy = 0.11251\n",
"\t epoch = 29: \n",
"\t\t batch = 1 - cross entropy = 0.10503\n",
"\t epoch = 31: \n",
"\t\t batch = 1 - cross entropy = 0.09825\n",
"\t epoch = 33: \n",
"\t\t batch = 1 - cross entropy = 0.09196\n",
"\t epoch = 35: \n",
"\t\t batch = 1 - cross entropy = 0.08620\n",
"\t epoch = 37: \n",
"\t\t batch = 1 - cross entropy = 0.08080\n",
"\t epoch = 39: \n",
"\t\t batch = 2 - cross entropy = 0.07520\n",
"\t epoch = 41: \n",
"\t\t batch = 2 - cross entropy = 0.07073\n",
"\t epoch = 43: \n",
"\t\t batch = 2 - cross entropy = 0.06652\n",
"\t epoch = 45: \n",
"\t\t batch = 2 - cross entropy = 0.06255\n",
"\t epoch = 47: \n",
"\t\t batch = 2 - cross entropy = 0.05885\n",
"\t epoch = 49: \n",
"\t\t batch = 3 - cross entropy = 0.05506\n",
"\t epoch = 52: \n",
"\t\t batch = 0 - cross entropy = 0.05153\n",
"\t epoch = 54: \n",
"\t\t batch = 1 - cross entropy = 0.04825\n",
"\t epoch = 56: \n",
"\t\t batch = 2 - cross entropy = 0.04530\n",
"\t epoch = 58: \n",
"\t\t batch = 2 - cross entropy = 0.04279\n",
"\t epoch = 60: \n",
"\t\t batch = 2 - cross entropy = 0.04049\n",
"\t epoch = 62: \n",
"\t\t batch = 3 - cross entropy = 0.03802\n",
"\t epoch = 64: \n",
"\t\t batch = 3 - cross entropy = 0.03607\n",
"\t epoch = 66: \n",
"\t\t batch = 3 - cross entropy = 0.03419\n",
"\t epoch = 69: \n",
"\t\t batch = 0 - cross entropy = 0.03212\n",
"\t epoch = 71: \n",
"\t\t batch = 0 - cross entropy = 0.03049\n",
"\t epoch = 73: \n",
"\t\t batch = 1 - cross entropy = 0.02874\n",
"\t epoch = 75: \n",
"\t\t batch = 2 - cross entropy = 0.02712\n",
"\t epoch = 77: \n",
"\t\t batch = 3 - cross entropy = 0.02564\n",
"\t epoch = 80: \n",
"\t\t batch = 0 - cross entropy = 0.02423\n",
"\t epoch = 82: \n",
"\t\t batch = 1 - cross entropy = 0.02296\n",
"\t epoch = 84: \n",
"\t\t batch = 2 - cross entropy = 0.02177\n",
"\t epoch = 86: \n",
"\t\t batch = 2 - cross entropy = 0.02073\n",
"\t epoch = 88: \n",
"\t\t batch = 3 - cross entropy = 0.01966\n",
"\t epoch = 91: \n",
"\t\t batch = 0 - cross entropy = 0.01864\n",
"\t epoch = 93: \n",
"\t\t batch = 0 - cross entropy = 0.01783\n",
"\t epoch = 95: \n",
"\t\t batch = 0 - cross entropy = 0.01702\n",
"\t epoch = 97: \n",
"\t\t batch = 1 - cross entropy = 0.01622\n",
"\t epoch = 99: \n",
"\t\t batch = 2 - cross entropy = 0.01542\n",
"\t epoch = 101: \n",
"\t\t batch = 3 - cross entropy = 0.01467\n",
"\t epoch = 104: \n",
"\t\t batch = 0 - cross entropy = 0.01397\n",
"\t epoch = 106: \n",
"\t\t batch = 0 - cross entropy = 0.01341\n",
"\t epoch = 108: \n",
"\t\t batch = 1 - cross entropy = 0.01279\n",
"\t epoch = 110: \n",
"\t\t batch = 2 - cross entropy = 0.01224\n",
"\t epoch = 112: \n",
"\t\t batch = 3 - cross entropy = 0.01166\n",
"\t epoch = 115: \n",
"\t\t batch = 0 - cross entropy = 0.01117\n",
"\t epoch = 117: \n",
"\t\t batch = 1 - cross entropy = 0.01069\n",
"\t epoch = 119: \n",
"\t\t batch = 1 - cross entropy = 0.01032\n",
"\t epoch = 121: \n",
"\t\t batch = 2 - cross entropy = 0.00987\n",
"num epochs = 122: \n",
" training cross-entropy = 0.00987 \n",
" training error rate = 0.00033 \n",
" total runtime (minutes) = 4.52 \n",
"Done!\n"
]
}
],
"source": [
"# Sweep batch sizes 10, 100, 1000, 10000 and train to convergence at each.\n",
"for exponent in range(1, 5):\n",
"    batchsize = 10 ** exponent\n",
"    print('batchsize:', batchsize)\n",
"    runBatch(batchsize)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"------------------------------"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| batch size | 10 | 100 | 1000 | 10000 |\n",
"| ----------- | ------: | ------: | ------: | ------: |\n",
"| number of epochs | 11 | 7 | 21 | 122 |\n",
"| final training cross-entropy* | 0.00638 | 0.00814 | 0.00861 | 0.00987 |\n",
"| total elapsed optimization time (minutes) | 3.17 | 0.41 | 0.77 | 4.52 |\n",
"\n",
"What is your computing device? - laptop (W530)\n",
"\n",
"`*note that I stopped at the end of the epoch of \n",
" the first time CE < 0.01, even if it was \n",
" mid-epoch. If I have it wait until that \n",
" condition is true at the end of the epoch it \n",
" takes anywhere between 1.5x and 3x the amount \n",
" of time to complete`"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment