@wendazhou
Last active January 3, 2021 21:41
Quick tensorflow tutorial
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A Modern Introduction to Tensorflow\n",
"\n",
"This notebook is a quick introduction to tensorflow from the point of view of current best practices. In particular, it makes heavy use of the `tf.data` and `tf.estimator` APIs. This introduction presupposes some familiarity with both the [Python programming language](https://www.python.org/) and the mathematics of deep learning.\n",
"\n",
"## About Tensorflow\n",
"\n",
"[Tensorflow](https://www.tensorflow.org/) is an open-source framework for high-performance machine learning. It combines two important capabilities for machine learning: the ability to define computations symbolically (and in particular to extract derivatives automatically by back-propagation), and the ability to execute these computations on a variety of hardware, including GPUs and TPUs. It is currently in very active development, with a new minor version every couple of months, and best practices are evolving rapidly.\n",
"\n",
"There exist some higher-level APIs, such as [keras](https://keras.io/), which abstract away the direct use of tensorflow. However, they tend to be less flexible and less adapted to making"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\wenda\\AppData\\Local\\conda\\conda\\envs\\idp\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
}
],
"source": [
"# Let's get started. Everything is in the tensorflow package.\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import tempfile\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(\"mul:0\", shape=(), dtype=float32)\n"
]
}
],
"source": [
"# Everything in tensorflow lives in a graph.\n",
"# This graph defines the symbolic computations.\n",
"with tf.Graph().as_default():\n",
" x = tf.random_normal(shape=[])\n",
" y = 2 * x\n",
" \n",
"print(y)"
]
},
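{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch, assuming TF 1.x graph mode as in the rest of this notebook, of what it\n",
"# means for tensors to live in a graph: the tensor y belongs to the graph g that was the\n",
"# default when it was created, and g records every operation needed to compute it.\n",
"# The name g is illustrative only.\n",
"g = tf.Graph()\n",
"with g.as_default():\n",
"    x = tf.random_normal(shape=[])\n",
"    y = 2 * x\n",
"\n",
"print(y.graph is g)  # True: y is a symbolic node of g, not a number\n",
"print([op.name for op in g.get_operations()])  # the operations recorded in g"
]
},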
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-1.3369544\n",
"0.6518975\n"
]
}
],
"source": [
"# To execute computation, we need to use a Session. Usually, this will be encapsulated\n",
"# away from us by the API we are using. However, we create it explicitly here for simplicity.\n",
"with tf.Graph().as_default():\n",
"    x = tf.random_normal(shape=[])\n",
"    y = 2 * x\n",
"    \n",
"    with tf.Session() as session:\n",
"        # Each run call executes the graph again, so each call draws a new random x.\n",
"        print(session.run(y))\n",
"        print(session.run(y))"
]
},
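{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of automatic differentiation, assuming TF 1.x as elsewhere here.\n",
"# tf.gradients adds the derivative computation to the graph symbolically. Fetching\n",
"# several tensors in a single session.run call evaluates them together, so all three\n",
"# values below share one sample of x, unlike the two separate run calls above.\n",
"with tf.Graph().as_default():\n",
"    x = tf.random_normal(shape=[])\n",
"    y = x ** 2\n",
"    dy_dx, = tf.gradients(y, x)  # tf.gradients returns a list with one entry per x\n",
"\n",
"    with tf.Session() as session:\n",
"        x_val, y_val, grad_val = session.run([x, y, dy_dx])\n",
"        print(x_val, y_val, grad_val)  # grad_val should equal 2 * x_val"
]
},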
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The main APIs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The dataset API\n",
"\n",
"The dataset API (`tf.data`) is the main, and recommended, way to load data into tensorflow. It works with plain text files, binary files, and TFRecord files. You can also load data directly from numpy arrays, but this does not scale well. Pre-existing files or conversion utilities are available for the MNIST, CIFAR and ImageNet datasets. This tutorial uses the MNIST dataset, and we have included the code below to download and extract it."
]
},
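{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of loading in-memory numpy arrays with\n",
"# tf.data.Dataset.from_tensor_slices, assuming TF 1.x. The arrays are made up for\n",
"# illustration. As we understand it, the arrays are embedded in the graph as constants,\n",
"# which is one reason this approach does not scale to large datasets.\n",
"features = np.arange(6, dtype=np.float32).reshape(3, 2)\n",
"labels = np.array([0, 1, 0], dtype=np.int32)\n",
"\n",
"with tf.Graph().as_default():\n",
"    # Slices along the first dimension: each element is a (feature_row, label) pair.\n",
"    dataset = tf.data.Dataset.from_tensor_slices((features, labels))\n",
"    next_sample = dataset.make_one_shot_iterator().get_next()\n",
"\n",
"    with tf.Session() as session:\n",
"        print(session.run(next_sample))  # the first row, [0., 1.], with label 0"
]
},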
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'data/train-labels-idx1-ubyte'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Download the data. Note that this sometimes reports an error but still downloads the data.\n",
"import mnist_data\n",
"mnist_data.download('data/', 'train-images-idx3-ubyte')\n",
"mnist_data.download('data/', 'train-labels-idx1-ubyte')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x03\\x12\\x12\\x12~\\x88\\xaf\\x1a\\xa6\\xff\\xf7\\x7f\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x1e$^\\x9a\\xaa\\xfd\\xfd\\xfd\\xfd\\xfd\\xe1\\xac\\xfd\\xf2\\xc3@\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x001\\xee\\xfd\\xfd\\xfd\\xfd\\xfd\\xfd\\xfd\\xfd\\xfb]RR8'\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x12\\xdb\\xfd\\xfd\\xfd\\xfd\\xfd\\xc6\\xb6\\xf7\\xf1\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00P\\x9ck\\xfd\\xfd\\xcd\\x0b\\x00+\\x9a\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x0e\\x01\\x9a\\xfdZ\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x8b\\xfd\\xbe\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x0b\\xbe\\xfdF\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00#\\xf1\\xe1\\xa0l\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00Q\\xf0\\xfd\\xfdw\\x19\\x00\\x00\\x00\\
x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00-\\xba\\xfd\\xfd\\x96\\x1b\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x10]\\xfc\\xfd\\xbb\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\xf9\\xfd\\xf9@\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00.\\x82\\xb7\\xfd\\xfd\\xcf\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00'\\x94\\xe5\\xfd\\xfd\\xfd\\xfa\\xb6\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x18r\\xdd\\xfd\\xfd\\xfd\\xfd\\xc9N\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x17B\\xd5\\xfd\\xfd\\xfd\\xfd\\xc6Q\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x12\\xab\\xdb\\xfd\\xfd\\xfd\\xfd\\xc3P\\t\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x007\\xac\\xe2\\xfd\\xfd\\xfd\\xfd\\xf4\\x85\\x0b\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x88\\xfd\\xfd\\xfd\\xd4\\x87\\x84\\x10\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n",
"784\n"
]
}
],
"source": [
"# Let's load the dataset and have a look through it.\n",
"# Currently, it's just a bunch of bytes. Not the most\n",
"# interesting.\n",
"\n",
"with tf.Graph().as_default():\n",
"    dataset = tf.data.FixedLengthRecordDataset('data/train-images-idx3-ubyte', 28 * 28, header_bytes=16)\n",
"    iterator = dataset.make_one_shot_iterator()\n",
"    next_sample = iterator.get_next()\n",
"    \n",
"    with tf.Session() as session:\n",
"        print(session.run(next_sample))\n",
"        print(len(session.run(next_sample)))"
]
},
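{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Why header_bytes=16 here (and header_bytes=8 for the labels later on)? The MNIST files\n",
"# use the IDX format: a 4-byte magic number followed by one 4-byte big-endian integer per\n",
"# dimension. This sketch reads both headers with Python's struct module, assuming the\n",
"# uncompressed files downloaded above are in data/.\n",
"import struct\n",
"\n",
"with open('data/train-images-idx3-ubyte', 'rb') as f:\n",
"    magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))\n",
"print(magic, num_images, rows, cols)  # expected: 2051 60000 28 28\n",
"\n",
"with open('data/train-labels-idx1-ubyte', 'rb') as f:\n",
"    magic, num_labels = struct.unpack('>II', f.read(8))\n",
"print(magic, num_labels)  # expected: 2049 60000"
]
},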
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADu1JREFUeJzt3X+QVfV5x/HPw3bll+BIDUgIlqis\nhNIG4gZjTYKJowNJpuhMNWE6hlLTzUyixWjbOExn4qTTDs2YGJNgEhKJmERMZvzFdKjRUKbGhBAW\nNMGIRksW3UAhAi34C1n26R97SDe453sv9557z2Wf92uG2XvPc849z1z97Ll3v+ecr7m7AMQzouwG\nAJSD8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCOoPmrmzU2ykj9LYZu4SCOU1vazX/bBVs25d\n4Tez+ZJuk9Qm6Zvuvjy1/iiN1QV2ST27BJCwyddXvW7NH/vNrE3SCkkLJM2UtMjMZtb6egCaq57v\n/HMlPefuO9z9dUn3SFpYTFsAGq2e8E+R9MKg573Zst9jZl1m1m1m3Ud0uI7dAShSPeEf6o8Kb7g+\n2N1Xununu3e2a2QduwNQpHrC3ytp6qDnb5G0q752ADRLPeHfLGm6mb3VzE6R9BFJa4tpC0Cj1TzU\n5+59ZnatpB9oYKhvlbv/srDOADRUXeP87r5O0rqCegHQRJzeCwRF+IGgCD8QFOEHgiL8QFCEHwiK\n8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQTZ2i\nG8NP3/vPT9Z3fyJ/irafX7g6ue3bNy5O1t+84pRkvW3D1mQ9Oo78QFCEHwiK8ANBEX4gKMIPBEX4\ngaAIPxBUXeP8ZtYj6ZCko5L63L2ziKbQOvrnzUnWv7TqK8n6ue35/4v1V9j34xd+K1l/pvNosv73\n095VYQ+xFXGSz/vc/cUCXgdAE/GxHwiq3vC7pIfNbIuZdRXREIDmqPdj/0XuvsvMJkp6xMyedvdH\nB6+Q/VLokqRRGlPn7gAUpa4jv7vvyn7ulXS/pLlDrLPS3TvdvbNdI+vZHYAC1Rx+MxtrZuOOPZZ0\nmaQni2oMQGPV87F/kqT7zezY69zt7g8V0hWAhqs5/O6+Q9LbC+wFJThyWfrUjH+4/dvJekd7+pr6\n/sRo/o4jR5Lb/m9/+mvinArfIg8veGdubfSGbclt+197Lf3iwwBDfUBQhB8IivADQRF+ICjCDwRF\n+IGguHX3MNA2fnxu7eX3zkhu+6lb707W3zf6pQp7r/34ceeBP0vW199+YbL+45u/lKw/8s2v5dZm\nfufa5LZnf3pjsj4ccOQHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY5x8Geu+aklvb/M4VTezkxHx2\n4uZk/aFT0+cBLOm5LFlfPe2HubXxM/clt42AIz8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMU4/0mg\n7/3nJ+trZudPkz1C6VtrV7Jk5yXJevcP35asb7smv7cNr45Kbjux+9Vk/bkD6XsVtP/LhtzaCEtu\nGgJHfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8Iytw9vYLZKkkfkrTX3WdlyyZI+p6kaZJ6JF3l7gcq\n7Wy8TfALLD1uHFH/vDnJ+hdX356sn9te++kaf/70Fcl621+8nKzv/+B5yfq+WfkD6h0rXkhu2/dC\nb7Jeyb/9ZktubffR9DkEf734b5P1tg1ba+qp0Tb5eh30/VWdxVDNkf9OSfOPW3aTpPXuPl3S+uw5\ngJNIxfC7+6OS9h+3eKGk1dnj1ZIuL7gvAA1W63f+Se6+W5KynxOLawlAMzT83H4z65LUJUmjNKbR\nuwNQpVqP/HvMbLIkZT/35q3o7ivdvdPdO9s1
ssbdAShareFfK2lx9nixpAeLaQdAs1QMv5mtkbRR\n0nlm1mtm10haLulSM3tW0qXZcwAnkYrf+d19UU6JAfsq2fl/nKy/eEN6zLmjPX1N/pbD+bX/eGlm\nctt990xN1v/wQHqe+tO+89N0PVHrS27ZWJPa0l9B913/SrI+Mf9WAScNzvADgiL8QFCEHwiK8ANB\nEX4gKMIPBMWtuwswYkz6tOW+zx1M1n86475k/dd9ryfrNyy7Mbd2+o+eT247cWzuyZmSpKPJ6vA1\nd/LOZL2nOW00FEd+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf4CvDovfcnuD2akb71dyceWfipZ\nH/dA/mW1ZV42i9bGkR8IivADQRF+ICjCDwRF+IGgCD8QFOEHgmKcvwB/+k9PJOsjKvyOXbIzfRf0\n0Q/87IR7gtRubbm1I+mZ6dVmFVYYBjjyA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQFcf5zWyVpA9J\n2uvus7JlN0v6G0m/zVZb5u7rGtVkK/ifqy/Mrf3jpFuS2/arwhTbD6en0T5LP0nWMbQjnj/rQL/6\nk9s+tD3932S6ttbUUyup5sh/p6T5Qyy/1d1nZ/+GdfCB4ahi+N39UUn7m9ALgCaq5zv/tWb2CzNb\nZWanF9YRgKaoNfxflXSOpNmSdkv6fN6KZtZlZt1m1n1Eh2vcHYCi1RR+d9/j7kfdvV/SNyTNTay7\n0t073b2zXSNr7RNAwWoKv5lNHvT0CklPFtMOgGapZqhvjaSLJZ1hZr2SPiPpYjObLck1MFvxxxvY\nI4AGqBh+d180xOI7GtBLS+sbnV87bUR6HH/ja+mvO2fftSu972R1+BoxZkyy/vQtsyq8wpbcyl/u\nWJDccsbSXyfr+WcQnDw4ww8IivADQRF+ICjCDwRF+IGgCD8QFLfuboJ9R09N1vt29DSnkRZTaSjv\nmeV/kqw/vfAryfq/v3Jabm3XinOT2447kD/t+XDBkR8IivADQRF+ICjCDwRF+IGgCD8QFOEHgmKc\nvwn+7sdXJusdiUtPT3b98+bk1vbe8Gpy2+2d6XH8S7Z9OFkfO39Hbm2chv84fiUc+YGgCD8QFOEH\ngiL8QFCEHwiK8ANBEX4gKMb5q2X5pREVfofe9u41yfoKddTSUUvY+dn8qcsl6d6PfiG31tGevuX5\nO362OFl/8xVPJetI48gPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0FVHOc3s6mS7pJ0pqR+SSvd/TYz\nmyDpe5KmSeqRdJW7H2hcqyXz/FK/+pObzhu9L1m//s7zk/VzvpV+/fb/PpRb2zPvTcltJ3y4N1m/\n7qz1yfqCMel7Eax9eVJu7aPb5ie3PePrY5N11KeaI3+fpBvd/W2S3iXpk2Y2U9JNkta7+3RJ67Pn\nAE4SFcPv7rvdfWv2+JCk7ZKmSFooaXW22mpJlzeqSQDFO6Hv/GY2TdIcSZskTXL33dLALwhJE4tu\nDkDjVB1+MztV0r2Srnf3gyewXZeZdZtZ9xEdrqVHAA1QVfjNrF0Dwf+uu9+XLd5jZpOz+mRJe4fa\n1t1Xununu3e2a2QRPQMoQMXwm5lJukPSdncffInWWknHLrtaLOnB4tsD0CjVXNJ7kaSrJW0zsyey\nZcskLZf0fTO7RtLzktL3pw5slKXf5u2Xfi1Zf+w9o5L1Zw+fmVtbclpPctt6Ld31nmT9oZ/Mzq1N\nX8rts8tUMfzu/pjyr2a/pNh2ADQLZ/gBQRF+ICjCDwRF+IGgCD8QFOEHgjL3xLWqBRtvE/wCOzlH\nB9s6zsmtdazZmdz2X8/cWNe+K90avNIlxSmPH06/9qL/7ErWO5YM3+nFT0abfL0O+v7Ejeb/H0d+\nICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKKbqrdPRX/5Vbe/bKacltZ153XbL+1FVfrqWlqsxY94lk\n/bzbX0nW
Ox5nHH+44sgPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0FxPT8wjHA9P4CKCD8QFOEHgiL8\nQFCEHwiK8ANBEX4gqIrhN7OpZrbBzLab2S/NbGm2/GYz+42ZPZH9+0Dj2wVQlGpu5tEn6UZ332pm\n4yRtMbNHstqt7n5L49oD0CgVw+/uuyXtzh4fMrPtkqY0ujEAjXVC3/nNbJqkOZI2ZYuuNbNfmNkq\nMzs9Z5suM+s2s+4jOlxXswCKU3X4zexUSfdKut7dD0r6qqRzJM3WwCeDzw+1nbuvdPdOd+9s18gC\nWgZQhKrCb2btGgj+d939Pkly9z3uftTd+yV9Q9LcxrUJoGjV/LXfJN0habu7f2HQ8smDVrtC0pPF\ntwegUar5a/9Fkq6WtM3MnsiWLZO0yMxmS3JJPZI+3pAOATRENX/tf0zSUNcHryu+HQDNwhl+QFCE\nHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoJo6RbeZ/VbSzkGL\nzpD0YtMaODGt2lur9iXRW62K7O2P3P1N1azY1PC/Yedm3e7eWVoDCa3aW6v2JdFbrcrqjY/9QFCE\nHwiq7PCvLHn/Ka3aW6v2JdFbrUrprdTv/ADKU/aRH0BJSgm/mc03s2fM7Dkzu6mMHvKYWY+Zbctm\nHu4uuZdVZrbXzJ4ctGyCmT1iZs9mP4ecJq2k3lpi5ubEzNKlvnetNuN10z/2m1mbpF9JulRSr6TN\nkha5+1NNbSSHmfVI6nT30seEzey9kl6SdJe7z8qWfU7Sfndfnv3iPN3dP90ivd0s6aWyZ27OJpSZ\nPHhmaUmXS/orlfjeJfq6SiW8b2Uc+edKes7dd7j765LukbSwhD5anrs/Kmn/cYsXSlqdPV6tgf95\nmi6nt5bg7rvdfWv2+JCkYzNLl/reJfoqRRnhnyLphUHPe9VaU367pIfNbIuZdZXdzBAmZdOmH5s+\nfWLJ/Ryv4szNzXTczNIt897VMuN10coI/1Cz/7TSkMNF7v4OSQskfTL7eIvqVDVzc7MMMbN0S6h1\nxuuilRH+XklTBz1/i6RdJfQxJHfflf3cK+l+td7sw3uOTZKa/dxbcj+/00ozNw81s7Ra4L1rpRmv\nywj/ZknTzeytZnaKpI9IWltCH29gZmOzP8TIzMZKukytN/vwWkmLs8eLJT1YYi+/p1Vmbs6bWVol\nv3etNuN1KSf5ZEMZX5TUJmmVu/9z05sYgpmdrYGjvTQwiendZfZmZmskXayBq772SPqMpAckfV/S\nWZKel3Sluzf9D285vV2sgY+uv5u5+dh37Cb39m5JP5K0TVJ/tniZBr5fl/beJfpapBLeN87wA4Li\nDD8gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0H9HxK6HmPNl2xnAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x151b09c82b0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Datasets are similar to Python lists. Most of the ideas (filtering, mapping) carry across.\n",
"\n",
"with tf.Graph().as_default():\n",
"    # Although our intuition carries across, it is important to always remember\n",
"    # that we are working with symbolic computations. In particular, this function,\n",
"    # which describes an operation that will be executed on each image, is only\n",
"    # called once, while the graph is being built, and not once per image.\n",
"    def _format_image(raw_data):\n",
"        image = tf.decode_raw(raw_data, tf.uint8)\n",
"        image = tf.to_float(image)\n",
"        image = tf.reshape(image, [28, 28])\n",
"        image = image / 255\n",
"        return image\n",
"    \n",
"    dataset = tf.data.FixedLengthRecordDataset('data/train-images-idx3-ubyte', 28 * 28, header_bytes=16)\n",
"    dataset = dataset.map(_format_image)\n",
"    iterator = dataset.make_one_shot_iterator()\n",
"    next_sample = iterator.get_next()\n",
"    \n",
"    with tf.Session() as session:\n",
"        # Each run call yields the next image; show the first two side by side.\n",
"        _, axes = plt.subplots(1, 2)\n",
"        axes[0].imshow(session.run(next_sample))\n",
"        axes[1].imshow(session.run(next_sample))"
]
},
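{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch, assuming TF 1.x, of the list-like transformations mentioned above on\n",
"# a toy dataset of the integers 0 to 9. The Python list trace_calls (a name made up for\n",
"# this example) records each time _square actually runs in Python, which happens while\n",
"# the graph is built rather than once per element.\n",
"trace_calls = []\n",
"\n",
"def _square(x):\n",
"    trace_calls.append(x)  # x is a symbolic tensor here, not a number\n",
"    return x * x\n",
"\n",
"with tf.Graph().as_default():\n",
"    dataset = tf.data.Dataset.range(10)\n",
"    dataset = dataset.filter(lambda x: tf.equal(x % 2, 0))  # keep the even numbers\n",
"    dataset = dataset.map(_square)\n",
"    next_value = dataset.make_one_shot_iterator().get_next()\n",
"\n",
"    values = []\n",
"    with tf.Session() as session:\n",
"        while True:\n",
"            try:\n",
"                values.append(session.run(next_value))\n",
"            except tf.errors.OutOfRangeError:  # raised once the dataset is exhausted\n",
"                break\n",
"\n",
"print(values)  # same values as [x * x for x in range(10) if x % 2 == 0]\n",
"print(len(trace_calls))  # expected to stay small, not grow with the number of elements"
]
},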
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Let's put all this into a function.\n",
"def get_mnist_dataset():\n",
"    \"\"\"Create a dataset of (image, label) pairs from the MNIST training files.\"\"\"\n",
"    def _format_image(raw_data):\n",
"        image = tf.decode_raw(raw_data, tf.uint8)\n",
"        image = tf.to_float(image)\n",
"        image = tf.reshape(image, [28, 28, 1])\n",
"        image = image / 255\n",
"        return image\n",
"    \n",
"    def _format_label(raw_data):\n",
"        label = tf.decode_raw(raw_data, tf.uint8)\n",
"        label = tf.reshape(label, [])\n",
"        return tf.to_int32(label)\n",
"    \n",
"    # Image file: 16-byte header, then one 28 * 28 = 784-byte record per image.\n",
"    dataset_img = tf.data.FixedLengthRecordDataset('data/train-images-idx3-ubyte', 28 * 28, header_bytes=16)\n",
"    dataset_img = dataset_img.map(_format_image)\n",
"    \n",
"    # Label file: 8-byte header, then one byte per label.\n",
"    dataset_label = tf.data.FixedLengthRecordDataset('data/train-labels-idx1-ubyte', 1, header_bytes=8)\n",
"    dataset_label = dataset_label.map(_format_label)\n",
"    \n",
"    return tf.data.Dataset.zip((dataset_img, dataset_label))"
]
},
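{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A quick sanity check of get_mnist_dataset (a sketch, assuming TF 1.x and the files\n",
"# downloaded above). Each element should be an (image, label) pair: a float32 image of\n",
"# shape (28, 28, 1) with values in [0, 1], and a scalar int32 label.\n",
"with tf.Graph().as_default():\n",
"    dataset = get_mnist_dataset()\n",
"    print(dataset.output_types)\n",
"    print(dataset.output_shapes)\n",
"\n",
"    image, label = dataset.make_one_shot_iterator().get_next()\n",
"    with tf.Session() as session:\n",
"        image_val, label_val = session.run([image, label])\n",
"        print(image_val.shape, image_val.min(), image_val.max(), label_val)"
]
},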
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The Estimator API\n",
"\n",
"The estimator API is the main way to build and train models in tensorflow. It abstracts away numerous details of setting up the tensorflow graph and executing the training, whilst remaining flexible enough for most experiments. One peculiarity of the estimator API is that it hides model building behind a function, so we will be working with functions that return functions. Tensorflow also provides canned (pre-made) estimators, but we will not cover them here."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# The estimator API is based around the notion of a model_fn, which creates the\n",
"# model when called.\n",
"def model_fn(features, labels, mode):\n",
" images = features\n",
" images = tf.layers.flatten(images)\n",
" logits = tf.layers.dense(images, units=10)\n",
" \n",
" predictions = tf.argmax(logits, axis=1, output_type=tf.int32)\n",
" accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)\n",
" \n",
" metrics = {'accuracy': accuracy}\n",
" \n",
" loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)\n",
" \n",
" optimizer = tf.train.AdamOptimizer()\n",
" train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())\n",
" \n",
" # This gives tensorflow the description of what must be done\n",
" # to construct our model.\n",
" return tf.estimator.EstimatorSpec(\n",
" loss=loss,\n",
" mode=mode,\n",
" train_op=train_op,\n",
" predictions=predictions,\n",
" eval_metric_ops=metrics)"
]
},
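{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The model_fn above builds the loss, the metrics and the train_op in every mode. As we\n",
"# understand the API, labels is None when the estimator is asked to predict, so that\n",
"# model_fn could not be used with estimator.predict. A sketch of a mode-aware variant\n",
"# using tf.estimator.ModeKeys; model_fn_with_modes is a hypothetical name and is not\n",
"# used elsewhere in this notebook.\n",
"def model_fn_with_modes(features, labels, mode):\n",
"    images = tf.layers.flatten(features)\n",
"    logits = tf.layers.dense(images, units=10)\n",
"    predictions = tf.argmax(logits, axis=1, output_type=tf.int32)\n",
"\n",
"    if mode == tf.estimator.ModeKeys.PREDICT:\n",
"        # No labels are available when predicting: only return the predictions.\n",
"        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)\n",
"\n",
"    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)\n",
"\n",
"    if mode == tf.estimator.ModeKeys.EVAL:\n",
"        accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)\n",
"        return tf.estimator.EstimatorSpec(\n",
"            mode=mode, loss=loss, eval_metric_ops={'accuracy': accuracy})\n",
"\n",
"    # Remaining case: mode == tf.estimator.ModeKeys.TRAIN\n",
"    optimizer = tf.train.AdamOptimizer()\n",
"    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())\n",
"    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)"
]
},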
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From C:\\Users\\wenda\\AppData\\Local\\conda\\conda\\envs\\idp\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use the retry module or similar alternatives.\n"
]
}
],
"source": [
"# The other component of the estimator API is the input_fn, which creates the input\n",
"# pipeline for the model. We have done most of the work already; let's just batch it\n",
"# up and prefetch for performance.\n",
"\n",
"# The tensorflow.contrib.data module contains numerous optimized transformations\n",
"# that can be applied to a dataset. Pass them to the dataset's apply method to use them.\n",
"from tensorflow.contrib import data as contrib_data\n",
"\n",
"def input_fn():\n",
" dataset = get_mnist_dataset()\n",
" dataset = dataset.apply(contrib_data.shuffle_and_repeat(1000)) # Shuffle the dataset, and repeat as necessary\n",
" dataset = dataset.prefetch(128) # Prefetch enough for a single batch for performance\n",
" dataset = dataset.batch(batch_size=128) # Batch it up\n",
" dataset = dataset.prefetch(2) # Keep two batches ready ahead of the training step.\n",
" \n",
" return dataset"
]
},
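{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch, assuming TF 1.x, that pulls a single batch out of input_fn to check\n",
"# what the estimator will receive. Because the dataset is shuffled and repeated\n",
"# indefinitely, every batch should be full: images of shape (128, 28, 28, 1) and\n",
"# labels of shape (128,).\n",
"with tf.Graph().as_default():\n",
"    images, labels = input_fn().make_one_shot_iterator().get_next()\n",
"    with tf.Session() as session:\n",
"        image_batch, label_batch = session.run([images, labels])\n",
"        print(image_batch.shape, label_batch.shape)\n",
"        print(np.bincount(label_batch, minlength=10))  # how many of each digit in this batch"
]
},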
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"WARNING:tensorflow:Using temporary folder as model directory: C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpyq8gxkng\n",
"INFO:tensorflow:Using config: {'_model_dir': 'C:\\\\Users\\\\wenda\\\\AppData\\\\Local\\\\Temp\\\\tmpyq8gxkng', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x00000152085E1CC0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Saving checkpoints for 1 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpyq8gxkng\\model.ckpt.\n",
"INFO:tensorflow:loss = 2.4807503, step = 0\n",
"INFO:tensorflow:global_step/sec: 268.016\n",
"INFO:tensorflow:loss = 0.79649985, step = 100 (0.373 sec)\n",
"INFO:tensorflow:global_step/sec: 264.947\n",
"INFO:tensorflow:loss = 0.5194293, step = 200 (0.377 sec)\n",
"INFO:tensorflow:global_step/sec: 283.507\n",
"INFO:tensorflow:loss = 0.60917985, step = 300 (0.353 sec)\n",
"INFO:tensorflow:global_step/sec: 282.303\n",
"INFO:tensorflow:loss = 0.40655237, step = 400 (0.354 sec)\n",
"INFO:tensorflow:Saving checkpoints for 500 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpyq8gxkng\\model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 0.4451155.\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.estimator.estimator.Estimator at 0x152085e1978>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Everything's ready, let's try it out!\n",
"\n",
"# Note that we are passing functions as arguments to the constructor and to the train\n",
"# method, not their return values! This is very important, as tensorflow\n",
"# will be calling them.\n",
"estimator = tf.estimator.Estimator(model_fn=model_fn)\n",
"estimator.train(input_fn, steps=500)"
]
},
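{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of inspecting what training left behind. The estimator keeps its\n",
"# checkpoints in its model directory; since no model_dir was given above, a temporary\n",
"# folder is used (see the warning in the log).\n",
"print(estimator.model_dir)\n",
"print(tf.train.latest_checkpoint(estimator.model_dir))  # the checkpoint saved at step 500\n",
"print(estimator.get_variable_names())  # includes the dense layer's kernel and bias"
]
},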
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"INFO:tensorflow:Using config: {'_model_dir': 'C:\\\\Users\\\\wenda\\\\AppData\\\\Local\\\\Temp\\\\tmppmdt_spc', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x00000152089F5C18>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
"\n",
"---- Starting training ------\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Saving checkpoints for 1 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmppmdt_spc\\model.ckpt.\n",
"INFO:tensorflow:loss = 2.3400397, step = 0\n",
"INFO:tensorflow:global_step/sec: 271.353\n",
"INFO:tensorflow:loss = 0.64853764, step = 100 (0.369 sec)\n",
"INFO:tensorflow:global_step/sec: 284.66\n",
"INFO:tensorflow:loss = 0.59460866, step = 200 (0.356 sec)\n",
"INFO:tensorflow:global_step/sec: 280.029\n",
"INFO:tensorflow:loss = 0.43894875, step = 300 (0.357 sec)\n",
"INFO:tensorflow:global_step/sec: 279.368\n",
"INFO:tensorflow:loss = 0.44927344, step = 400 (0.358 sec)\n",
"INFO:tensorflow:Saving checkpoints for 500 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmppmdt_spc\\model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 0.318319.\n",
"---- Training Done ------\n",
"\n",
"---- Starting evaluation ------\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Starting evaluation at 2018-04-10-22:23:51\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Restoring parameters from C:\\Users\\wenda\\AppData\\Local\\Temp\\tmppmdt_spc\\model.ckpt-500\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Finished evaluation at 2018-04-10-22:23:53\n",
"INFO:tensorflow:Saving dict for global step 500: accuracy = 0.89281666, global_step = 500, loss = 0.39535224\n"
]
},
{
"data": {
"text/plain": [
"{'accuracy': 0.89281666, 'global_step': 500, 'loss': 0.39535224}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
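"# In fact, the estimator keeps all of its state (checkpoints, summaries) in a model\n",
"# directory. Since we do not pass model_dir here, a temporary directory is used.\n",
"# We will also want to evaluate our estimator. However, we must modify the input\n",
"# somewhat for that: we don't want to repeat the dataset forever when evaluating!\n",
"# (Note that we evaluate on the training images, as they are the only ones downloaded.)\n",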
"\n",
"# We now want a parametrised input function. So we will create\n",
"# a function that creates a function.\n",
"def make_input_fn(repeat_count=None, shuffle_size=1000):\n",
"    def input_fn():\n",
"        dataset = get_mnist_dataset()\n",
"        # Shuffle the dataset, and repeat as necessary\n",
"        if shuffle_size is not None and shuffle_size > 0:\n",
"            from tensorflow.contrib.data import shuffle_and_repeat\n",
"            dataset = dataset.apply(shuffle_and_repeat(shuffle_size, repeat_count))\n",
"        else:\n",
"            dataset = dataset.repeat(repeat_count)\n",
"        dataset = dataset.prefetch(128) # Prefetch enough for a single batch for performance\n",
"        dataset = dataset.batch(batch_size=128) # Batch it up\n",
"        dataset = dataset.prefetch(2) # Keep two batches ready ahead of the training step.\n",
"        \n",
"        return dataset\n",
"    return input_fn\n",
"\n",
"estimator = tf.estimator.Estimator(model_fn=model_fn)\n",
"print('\\n---- Starting training ------')\n",
"estimator.train(make_input_fn(), steps=500)\n",
"print('---- Training Done ------')\n",
"print('\\n---- Starting evaluation ------')\n",
"estimator.evaluate(input_fn=make_input_fn(1, 0))"
]
},
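{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch, assuming TF 1.x, checking that the evaluation input makes exactly one\n",
"# pass over the data. With the 60000 MNIST training images and batches of 128, we would\n",
"# expect ceil(60000 / 128) = 469 batches, the last one holding 96 images.\n",
"with tf.Graph().as_default():\n",
"    eval_input_fn = make_input_fn(repeat_count=1, shuffle_size=0)\n",
"    _, labels = eval_input_fn().make_one_shot_iterator().get_next()\n",
"\n",
"    num_batches, num_examples = 0, 0\n",
"    with tf.Session() as session:\n",
"        while True:\n",
"            try:\n",
"                label_batch = session.run(labels)\n",
"            except tf.errors.OutOfRangeError:  # raised after the single pass\n",
"                break\n",
"            num_batches += 1\n",
"            num_examples += len(label_batch)\n",
"\n",
"print(num_batches, num_examples)"
]
},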
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"INFO:tensorflow:Using config: {'_model_dir': 'C:\\\\Users\\\\wenda\\\\AppData\\\\Local\\\\Temp\\\\tmpa3dalfqn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000001520AD64B38>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
"\n",
"---- Starting training ------\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Saving checkpoints for 1 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpa3dalfqn\\model.ckpt.\n",
"INFO:tensorflow:loss = 2.4074893, step = 0\n",
"INFO:tensorflow:global_step/sec: 275.221\n",
"INFO:tensorflow:loss = 3.3473957, step = 100 (0.363 sec)\n",
"INFO:tensorflow:global_step/sec: 275.382\n",
"INFO:tensorflow:loss = 5.39507, step = 200 (0.363 sec)\n",
"INFO:tensorflow:global_step/sec: 282.829\n",
"INFO:tensorflow:loss = 4.791212, step = 300 (0.354 sec)\n",
"INFO:tensorflow:global_step/sec: 260.074\n",
"INFO:tensorflow:loss = 4.04576, step = 400 (0.385 sec)\n",
"INFO:tensorflow:Saving checkpoints for 500 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpa3dalfqn\\model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 3.730453.\n",
"---- Training Done ------\n"
]
}
],
"source": [
"# Suppose we wish to explore different learning rates. How can we parametrise our\n",
"# estimators? Tensorflow allows us to pass a params dictionary to the Estimator, which\n",
"# propagates it to the model function.\n",
"\n",
"def model_fn(features, labels, mode, params):\n",
" # Here, the params parameter is passed in from tensorflow\n",
" images = features\n",
" images = tf.layers.flatten(images)\n",
" logits = tf.layers.dense(images, units=10)\n",
" \n",
" predictions = tf.argmax(logits, axis=1, output_type=tf.int32)\n",
" accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)\n",
" \n",
" metrics = {'accuracy': accuracy}\n",
" \n",
" loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)\n",
" \n",
" optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])\n",
" train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())\n",
" \n",
" # This gives tensorflow the description of what must be done\n",
" # to construct our model.\n",
" return tf.estimator.EstimatorSpec(\n",
" loss=loss,\n",
" mode=mode,\n",
" train_op=train_op,\n",
" predictions=predictions,\n",
" eval_metric_ops=metrics)\n",
"\n",
"tempdir = tempfile.mkdtemp()\n",
"estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=tempdir, params={'learning_rate': 0.5})\n",
"print('\\n---- Starting training ------')\n",
"estimator.train(make_input_fn(), steps=500)\n",
"print('---- Training Done ------')"
]
},
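{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A sketch of a small learning-rate sweep using params. The candidate rates are arbitrary\n",
"# illustrative choices (0.5 is the rate used above, whose training loss rose instead of\n",
"# falling). Each estimator gets its own model_dir so the runs do not share checkpoints.\n",
"sweep_results = {}\n",
"for learning_rate in [0.5, 0.01, 0.001]:\n",
"    sweep_estimator = tf.estimator.Estimator(\n",
"        model_fn=model_fn, model_dir=tempfile.mkdtemp(),\n",
"        params={'learning_rate': learning_rate})\n",
"    sweep_estimator.train(make_input_fn(), steps=500)\n",
"    sweep_metrics = sweep_estimator.evaluate(input_fn=make_input_fn(1, 0))\n",
"    sweep_results[learning_rate] = sweep_metrics['accuracy']\n",
"\n",
"print(sweep_results)"
]
},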
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"INFO:tensorflow:Using config: {'_model_dir': 'C:\\\\Users\\\\wenda\\\\AppData\\\\Local\\\\Temp\\\\tmpg0isuv8y', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x0000015209AE6198>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
"\n",
"---- Starting training ------\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Saving checkpoints for 1 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpg0isuv8y\\model.ckpt.\n",
"INFO:tensorflow:loss = 2.3851187, step = 0\n",
"INFO:tensorflow:global_step/sec: 273.729\n",
"INFO:tensorflow:loss = 0.68694407, step = 100 (0.370 sec)\n",
"INFO:tensorflow:global_step/sec: 258.145\n",
"INFO:tensorflow:loss = 0.24378338, step = 200 (0.387 sec)\n",
"INFO:tensorflow:global_step/sec: 272.136\n",
"INFO:tensorflow:loss = 0.372921, step = 300 (0.367 sec)\n",
"INFO:tensorflow:global_step/sec: 275.949\n",
"INFO:tensorflow:loss = 0.4812731, step = 400 (0.362 sec)\n",
"INFO:tensorflow:Saving checkpoints for 500 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpg0isuv8y\\model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 0.14714332.\n",
"---- Training Done ------\n"
]
}
],
"source": [
"# What if we would like to use a non-constant learning rate, say one that decays over time?\n",
"# We cannot create the decay tensor before passing it to the estimator, as it would not\n",
"# be part of the graph the estimator builds. Again, the solution is to pass in a function!\n",
"\n",
"# In general, I like using the following helper, which allows me to pass in either\n",
"# constants or functions, as I desire.\n",
"\n",
"def _evaluate(fn_or_value):\n",
"    if callable(fn_or_value):\n",
"        return fn_or_value()\n",
"    else:\n",
"        return fn_or_value\n",
"\n",
"def model_fn(features, labels, mode, params):\n",
"    # Here, the params parameter is passed in from tensorflow\n",
"    images = features\n",
"    images = tf.layers.flatten(images)\n",
"    logits = tf.layers.dense(images, units=10)\n",
"    \n",
"    predictions = tf.argmax(logits, axis=1, output_type=tf.int32)\n",
"    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)\n",
"    \n",
"    metrics = {'accuracy': accuracy}\n",
"    \n",
"    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)\n",
"    \n",
"    optimizer = tf.train.AdamOptimizer(\n",
"        learning_rate=_evaluate(params['learning_rate']))\n",
"    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())\n",
"    \n",
"    # This gives tensorflow the description of what must be done\n",
"    # to construct our model.\n",
"    return tf.estimator.EstimatorSpec(\n",
"        loss=loss,\n",
"        mode=mode,\n",
"        train_op=train_op,\n",
"        predictions=predictions,\n",
"        eval_metric_ops=metrics)\n",
"\n",
"tempdir = tempfile.mkdtemp()\n",
"estimator = tf.estimator.Estimator(\n",
"    model_fn=model_fn, model_dir=tempdir,\n",
"    params={\n",
"        'learning_rate': lambda: tf.train.inverse_time_decay(\n",
"            0.1, tf.train.get_or_create_global_step(),\n",
"            decay_steps=10,\n",
"            decay_rate=1,\n",
"            staircase=True)\n",
"    })\n",
"print('\\n---- Starting training ------')\n",
"estimator.train(make_input_fn(), steps=500)\n",
"print('---- Training Done ------')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The summary APIs\n",
"\n",
"One of the strengths of tensorflow is `tensorboard`, a tool for visualizing training. While designing our model, we are responsible for indicating which quantities to record, using the\n",
"`tf.summary` API. There is also a newer, unstable API in `tf.contrib.summary`, but it is not as\n",
"well integrated (it is mostly useful for TPU training). The summary APIs can record many different\n",
"types of data (scalars, histograms, images, audio). However, be aware that recording too much data can\n",
"significantly impact performance."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"INFO:tensorflow:Using config: {'_model_dir': 'C:\\\\Users\\\\wenda\\\\AppData\\\\Local\\\\Temp\\\\tmpytbj11mt', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000001520868E748>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
"\n",
"---- Starting training ------\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Saving checkpoints for 1 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpytbj11mt\\model.ckpt.\n",
"INFO:tensorflow:loss = 2.3976173, step = 0\n",
"INFO:tensorflow:global_step/sec: 231.885\n",
"INFO:tensorflow:loss = 0.31916904, step = 100 (0.434 sec)\n",
"INFO:tensorflow:global_step/sec: 202.021\n",
"INFO:tensorflow:loss = 0.6408271, step = 200 (0.495 sec)\n",
"INFO:tensorflow:global_step/sec: 220.275\n",
"INFO:tensorflow:loss = 0.54005706, step = 300 (0.454 sec)\n",
"INFO:tensorflow:global_step/sec: 231.179\n",
"INFO:tensorflow:loss = 0.36934638, step = 400 (0.434 sec)\n",
"INFO:tensorflow:Saving checkpoints for 500 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpytbj11mt\\model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 0.17129017.\n",
"---- Training Done ------\n"
]
}
],
"source": [
"# This is the same model as in the previous cell, but it now also records the loss and the\n",
"# learning rate with tf.summary.scalar. The estimator writes these summaries to model_dir\n",
"# every save_summary_steps steps (100 with the default config).\n",
"\n",
"# Helper that accepts either a constant or a function returning a value.\n",
"\n",
"def _evaluate(fn_or_value):\n",
"    if callable(fn_or_value):\n",
"        return fn_or_value()\n",
"    else:\n",
"        return fn_or_value\n",
"\n",
"def model_fn(features, labels, mode, params):\n",
"    # Here, the params parameter is passed in from tensorflow\n",
"    images = features\n",
"    images = tf.layers.flatten(images)\n",
"    logits = tf.layers.dense(images, units=10)\n",
"    \n",
"    predictions = tf.argmax(logits, axis=1, output_type=tf.int32)\n",
"    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)\n",
"    \n",
"    metrics = {'accuracy': accuracy}\n",
"    \n",
"    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)\n",
"    \n",
"    tf.summary.scalar('loss', loss) # Record the loss\n",
"    \n",
"    learning_rate = _evaluate(params['learning_rate'])\n",
"    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
"    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())\n",
"    \n",
"    tf.summary.scalar('learning_rate', learning_rate) # Record the learning rate\n",
"    \n",
"    # This gives tensorflow the description of what must be done\n",
"    # to construct our model.\n",
"    return tf.estimator.EstimatorSpec(\n",
"        loss=loss,\n",
"        mode=mode,\n",
"        train_op=train_op,\n",
"        predictions=predictions,\n",
"        eval_metric_ops=metrics)\n",
"\n",
"tempdir = tempfile.mkdtemp()\n",
"estimator = tf.estimator.Estimator(\n",
"    model_fn=model_fn, model_dir=tempdir,\n",
"    params={\n",
"        'learning_rate': lambda: tf.train.inverse_time_decay(\n",
"            0.1, tf.train.get_or_create_global_step(),\n",
"            decay_steps=10,\n",
"            decay_rate=1,\n",
"            staircase=True)\n",
"    })\n",
"print('\\n---- Starting training ------')\n",
"estimator.train(make_input_fn(), steps=500)\n",
"print('---- Training Done ------')"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensorboard --logdir='C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpytbj11mt'\n"
]
}
],
"source": [
"# Run the command output by this cell to bring up tensorboard.\n",
"print(\"tensorboard --logdir='{0}'\".format(tempdir))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Layers\n",
"\n",
"Tensorflow provides numerous pre-existing layers that implement common neural network functionality. There are two main APIs in `tf.layers`: a functional API, with lowercase names (e.g. `conv2d`), and a keras-like object API, with capitalized names (e.g. `Conv2D`). The functional API is somewhat sleeker, but does not give direct access to the layer's variables, whereas the object API does.\n",
"\n",
"A note about two-dimensional convolutions: when working with images, we must choose a data format, either NCHW (`channels_first`) or NHWC (`channels_last`). In general, tensorflow is faster on GPU with the NCHW format (with some exceptions for 1x1 convolutions), while several CPU kernels only support NHWC. The best strategy is therefore to create a network that can be instantiated with either format."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"INFO:tensorflow:Using config: {'_model_dir': 'C:\\\\Users\\\\wenda\\\\AppData\\\\Local\\\\Temp\\\\tmpv4hiowgr', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x00000152099FF2E8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
"\n",
"---- Starting training ------\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Saving checkpoints for 1 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpv4hiowgr\\model.ckpt.\n",
"INFO:tensorflow:loss = 2.295919, step = 0\n",
"INFO:tensorflow:global_step/sec: 128.522\n",
"INFO:tensorflow:loss = 1.8987412, step = 100 (0.779 sec)\n",
"INFO:tensorflow:global_step/sec: 148.252\n",
"INFO:tensorflow:loss = 1.6321212, step = 200 (0.676 sec)\n",
"INFO:tensorflow:global_step/sec: 154.076\n",
"INFO:tensorflow:loss = 1.4923857, step = 300 (0.649 sec)\n",
"INFO:tensorflow:global_step/sec: 150.277\n",
"INFO:tensorflow:loss = 1.4923654, step = 400 (0.665 sec)\n",
"INFO:tensorflow:Saving checkpoints for 500 into C:\\Users\\wenda\\AppData\\Local\\Temp\\tmpv4hiowgr\\model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 1.6062481.\n",
"---- Training Done ------\n"
]
}
],
"source": [
"def _evaluate(fn_or_value):\n",
"    if callable(fn_or_value):\n",
"        return fn_or_value()\n",
"    else:\n",
"        return fn_or_value\n",
"\n",
"def model_fn(features, labels, mode, params):\n",
"    data_format = params.get('data_format', 'channels_first')\n",
"    images = tf.reshape(features, [-1, 28, 28, 1])\n",
"    \n",
"    if data_format == 'channels_first':\n",
"        # Transpose NHWC -> NCHW if necessary.\n",
"        # Note: with a single channel we could simply reshape instead,\n",
"        # but transposing is required in general, e.g. for RGB data.\n",
"        images = tf.transpose(images, [0, 3, 1, 2])\n",
"    \n",
"    # Use some convolutions (with ReLU activations) and max pooling\n",
"    x = images\n",
"    x = tf.layers.conv2d(x, filters=20, kernel_size=3, strides=1,\n",
"                         padding='same', activation=tf.nn.relu, data_format=data_format)\n",
"    x = tf.layers.max_pooling2d(x, pool_size=2, strides=2, data_format=data_format)\n",
"    x = tf.layers.conv2d(x, filters=50, kernel_size=3, strides=1,\n",
"                         padding='same', activation=tf.nn.relu, data_format=data_format)\n",
"    x = tf.layers.max_pooling2d(x, pool_size=2, strides=2, data_format=data_format)\n",
"    # After two 2x2 poolings the 28x28 maps are 7x7, so this is a global average pool\n",
"    x = tf.layers.average_pooling2d(x, pool_size=(7, 7), strides=(1, 1), data_format=data_format)\n",
"    x = tf.layers.flatten(x)\n",
"    x = tf.layers.dense(x, units=100, activation=tf.nn.relu)\n",
"    logits = tf.layers.dense(x, units=10)\n",
"    \n",
"    predictions = tf.argmax(logits, axis=1, output_type=tf.int32)\n",
"    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)\n",
"    \n",
"    metrics = {'accuracy': accuracy}\n",
"    \n",
"    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)\n",
"    \n",
"    tf.summary.scalar('loss', loss) # Record the loss\n",
"    \n",
"    learning_rate = _evaluate(params['learning_rate'])\n",
"    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
"    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())\n",
"    \n",
"    tf.summary.scalar('learning_rate', learning_rate) # Record the learning rate\n",
"    \n",
"    # This gives tensorflow the description of what must be done\n",
"    # to construct our model.\n",
"    return tf.estimator.EstimatorSpec(\n",
"        loss=loss,\n",
"        mode=mode,\n",
"        train_op=train_op,\n",
"        predictions=predictions,\n",
"        eval_metric_ops=metrics)\n",
"\n",
"tempdir = tempfile.mkdtemp()\n",
"estimator = tf.estimator.Estimator(\n",
"    model_fn=model_fn, model_dir=tempdir,\n",
"    params={\n",
"        'learning_rate': lambda: tf.train.inverse_time_decay(\n",
"            0.1, tf.train.get_or_create_global_step(),\n",
"            decay_steps=10,\n",
"            decay_rate=1,\n",
"            staircase=True),\n",
"        'data_format': 'channels_first'\n",
"    })\n",
"print('\\n---- Starting training ------')\n",
"estimator.train(make_input_fn(), steps=500)\n",
"print('---- Training Done ------')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""tf.data.Dataset interface to the MNIST dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import shutil
import tempfile

import numpy as np
from six.moves import urllib
import tensorflow as tf


def read32(bytestream):
  """Read 4 bytes from bytestream as an unsigned 32-bit integer."""
  dt = np.dtype(np.uint32).newbyteorder('>')
  return np.frombuffer(bytestream.read(4), dtype=dt)[0]


def check_image_file_header(filename):
  """Validate that filename corresponds to images for the MNIST dataset."""
  with tf.gfile.Open(filename, 'rb') as f:
    magic = read32(f)
    read32(f)  # num_images, unused
    rows = read32(f)
    cols = read32(f)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))
    if rows != 28 or cols != 28:
      raise ValueError(
          'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
          (f.name, rows, cols))


def check_labels_file_header(filename):
  """Validate that filename corresponds to labels for the MNIST dataset."""
  with tf.gfile.Open(filename, 'rb') as f:
    magic = read32(f)
    read32(f)  # num_items, unused
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))


def download(directory, filename):
  """Download (and unzip) a file from the MNIST dataset if not already done."""
  filepath = os.path.join(directory, filename)
  if tf.gfile.Exists(filepath):
    return filepath
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
  fd, zipped_filepath = tempfile.mkstemp(suffix='.gz')
  # Close the unused descriptor returned by mkstemp so it does not leak
  # (an open handle also makes os.remove fail on Windows).
  os.close(fd)
  print('Downloading %s to %s' % (url, zipped_filepath))
  urllib.request.urlretrieve(url, zipped_filepath)
  with gzip.open(zipped_filepath, 'rb') as f_in, \
      tf.gfile.Open(filepath, 'wb') as f_out:
    shutil.copyfileobj(f_in, f_out)
  os.remove(zipped_filepath)
  return filepath


def dataset(directory, images_file, labels_file):
  """Download and parse MNIST dataset."""

  images_file = download(directory, images_file)
  labels_file = download(directory, labels_file)

  check_image_file_header(images_file)
  check_labels_file_header(labels_file)

  def decode_image(image):
    # Normalize from [0, 255] to [0.0, 1.0]
    image = tf.decode_raw(image, tf.uint8)
    image = tf.cast(image, tf.float32)
    image = tf.reshape(image, [784])
    return image / 255.0

  def decode_label(label):
    label = tf.decode_raw(label, tf.uint8)  # tf.string -> [tf.uint8]
    label = tf.reshape(label, [])  # label is a scalar
    return tf.to_int32(label)

  images = tf.data.FixedLengthRecordDataset(
      images_file, 28 * 28, header_bytes=16).map(decode_image)
  labels = tf.data.FixedLengthRecordDataset(
      labels_file, 1, header_bytes=8).map(decode_label)
  return tf.data.Dataset.zip((images, labels))


def train(directory):
  """tf.data.Dataset object for MNIST training data."""
  return dataset(directory, 'train-images-idx3-ubyte',
                 'train-labels-idx1-ubyte')


def test(directory):
  """tf.data.Dataset object for MNIST test data."""
  return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
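The MNIST module above relies on the IDX file layout: a big-endian header of unsigned 32-bit integers (magic number, item count and, for images, the row and column counts), followed by one byte per pixel or label. That is why `read32` uses the `'>'` byte order, why the expected magic numbers are 2051 and 2049, and why `FixedLengthRecordDataset` skips 16 and 8 header bytes. The following is a minimal sketch of the same header parsing, assuming the standard-library `struct` module in place of NumPy and a tiny synthetic file in place of the downloaded data (so it runs offline); `parse_idx_header` is a hypothetical helper, not part of the module.

```python
import struct

IMAGE_MAGIC = 2051  # value checked by check_image_file_header
LABEL_MAGIC = 2049  # value checked by check_labels_file_header


def parse_idx_header(data, num_dims):
    """Return (magic, sizes) from the big-endian uint32 header of an IDX blob."""
    num_fields = 1 + num_dims
    fields = struct.unpack('>' + 'I' * num_fields, data[:4 * num_fields])
    return fields[0], fields[1:]


# Synthetic image file: 2 images of 28x28 pixels, every pixel set to 255.
image_blob = struct.pack('>IIII', IMAGE_MAGIC, 2, 28, 28) + bytes([255]) * (2 * 28 * 28)
# Synthetic label file: 3 labels.
label_blob = struct.pack('>II', LABEL_MAGIC, 3) + bytes([7, 2, 1])

magic, (num_images, rows, cols) = parse_idx_header(image_blob, num_dims=3)
image_header_bytes = 4 * 4   # corresponds to header_bytes=16 for the images
record_bytes = rows * cols   # corresponds to the 28 * 28 record length
first_image = image_blob[image_header_bytes:image_header_bytes + record_bytes]
# Same normalization as decode_image: map [0, 255] to [0.0, 1.0].
first_pixels = [b / 255.0 for b in first_image]

label_magic, (num_labels,) = parse_idx_header(label_blob, num_dims=1)
labels = list(label_blob[4 * 2:])  # corresponds to header_bytes=8 for the labels

print(magic, num_images, rows, cols)
print(label_magic, num_labels, labels)
```

Reading the whole header with a single `struct.unpack` call is only a convenience for this sketch; the module itself reads one 32-bit field at a time from an open file.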