Created
February 6, 2019 10:17
-
-
Save neil-tan/c132b7b82273c30038fe0da0e05b8f96 to your computer and use it in GitHub Desktop.
This file is located at uTensor/utensor-mnist-demo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-04-16T14:29:56.593505Z", | |
"start_time": "2018-04-16T14:29:52.879458Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# This script is based on:\n", | |
"# https://www.tensorflow.org/get_started/mnist/pros\n", | |
"\n", | |
"import sys\n", | |
"import tensorflow as tf\n", | |
"from tensorflow.examples.tutorials.mnist import input_data\n", | |
"from tensorflow.python.framework import graph_util as gu\n", | |
"from tensorflow.python.framework.graph_util import remove_training_nodes\n", | |
"from tensorflow.tools.graph_transforms import TransformGraph" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:29:57.272673Z",
"start_time": "2018-04-16T14:29:56.595420Z"
}
},
"outputs": [],
"source": [
"# Read MNIST using mnist_data/ as the data directory.\n",
"# one_hot=True yields each label as a 10-element indicator vector, matching\n",
"# the [None, 10] label placeholder defined further down.\n",
"mnist = input_data.read_data_sets(train_dir=\"mnist_data/\", one_hot=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define the TensorFlow model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:29:57.278092Z",
"start_time": "2018-04-16T14:29:57.274735Z"
}
},
"outputs": [],
"source": [
"# Mini-batch size: number of training examples drawn for each optimizer step.\n",
"batch_size = 50"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fully connected network with two hidden layers (784-128-64-10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:29:57.288994Z",
"start_time": "2018-04-16T14:29:57.280660Z"
}
},
"outputs": [],
"source": [
"def deepnn(x):\n",
"    \"\"\"Build a fully connected 784-128-64-10 ReLU network.\n",
"\n",
"    Args:\n",
"        x: input tensor of shape [batch, 784] (the `x` placeholder).\n",
"\n",
"    Returns:\n",
"        (y_pred, logits): y_pred is the argmax class index of each row\n",
"        (op name 'y_pred'); logits is the [batch, 10] pre-softmax output\n",
"        (op name 'logits').\n",
"    \"\"\"\n",
"    # Hidden layer 1: 784 -> 128, ReLU.\n",
"    W_fc1 = weight_variable([784, 128], name='W_fc1')\n",
"    b_fc1 = bias_variable([128], name='b_fc1')\n",
"    a_fc1 = tf.add(tf.matmul(x, W_fc1), b_fc1, name=\"zscore\")\n",
"    h_fc1 = tf.nn.relu(a_fc1)\n",
"\n",
"    # Hidden layer 2: 128 -> 64, ReLU.\n",
"    W_fc2 = weight_variable([128, 64], name='W_fc2')\n",
"    b_fc2 = bias_variable([64], name='b_fc2')\n",
"    # NOTE(review): reuses the op name \"zscore\" from layer 1, so TensorFlow\n",
"    # uniquifies it (e.g. \"zscore_1\"). Left as-is to avoid renaming exported ops.\n",
"    a_fc2 = tf.add(tf.matmul(h_fc1, W_fc2), b_fc2, name=\"zscore\")\n",
"    h_fc2 = tf.nn.relu(a_fc2)\n",
"\n",
"    # Output layer: 64 -> 10 class scores, then the predicted class index.\n",
"    W_fc3 = weight_variable([64, 10], name='W_fc3')\n",
"    b_fc3 = bias_variable([10], name='b_fc3')\n",
"    logits = tf.add(tf.matmul(h_fc2, W_fc3), b_fc3, name=\"logits\")\n",
"    y_pred = tf.argmax(logits, 1, name='y_pred')\n",
"\n",
"    return y_pred, logits\n",
"\n",
"\n",
"def weight_variable(shape, name):\n",
"    \"\"\"Create a weight variable of the given shape with graph name `name`.\n",
"\n",
"    Initial values come from a truncated normal with stddev 0.1.\n",
"    \"\"\"\n",
"    initial = tf.truncated_normal(shape, stddev=0.1)\n",
"    # `name` must be passed by keyword: the second positional parameter of\n",
"    # tf.Variable is `trainable`, so tf.Variable(initial, name) bound the name\n",
"    # to `trainable` and the variable fell back to TF's default \"Variable_N\" name.\n",
"    return tf.Variable(initial, name=name)\n",
"\n",
"\n",
"def bias_variable(shape, name):\n",
"    \"\"\"Create a bias variable of the given shape, filled with 0.1, named `name`.\"\"\"\n",
"    initial = tf.constant(0.1, shape=shape)\n",
"    # Pass `name` by keyword; positionally it would bind to `trainable`.\n",
"    return tf.Variable(initial, name=name)"
]
},
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Specify inputs, outputs, and a cost function" | |
] | |
}, | |
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:29:57.440031Z",
"start_time": "2018-04-16T14:29:57.291656Z"
}
},
"outputs": [],
"source": [
"# Reset the default graph so re-running this cell starts from an empty graph\n",
"tf.reset_default_graph()\n",
"\n",
"# Create the model input: one 784-element row per example\n",
"x = tf.placeholder(tf.float32, [None, 784], name=\"x\")\n",
"\n",
"# Labels for the loss: one 10-element one-hot row per example\n",
"y_ = tf.placeholder(tf.float32, [None, 10], name=\"y\")\n",
"\n",
"# Build the graph for the deep net\n",
"y_pred, logits = deepnn(x)\n",
"\n",
"with tf.name_scope(\"Loss\"):\n",
"    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(\n",
"        labels=y_, logits=logits)\n",
"    loss = tf.reduce_mean(cross_entropy, name=\"cross_entropy_loss\")\n",
"train_step = tf.train.AdamOptimizer(1e-4).minimize(loss, name=\"train_step\")\n",
"\n",
"# The inference output is the top-level op \"y_pred\" created inside deepnn();\n",
"# it is NOT under the Prediction scope. That op name is what gets exported later.\n",
"# The Prediction scope below only holds the accuracy metric used for evaluation.\n",
"with tf.name_scope(\"Prediction\"):\n",
"    correct_prediction = tf.equal(y_pred,\n",
"                                  tf.argmax(y_, 1))\n",
"    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name=\"accuracy\")"
]
},
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Train the model" | |
] | |
}, | |
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:24.491091Z",
"start_time": "2018-04-16T14:29:57.574591Z"
}
},
"outputs": [],
"source": [
"sess = tf.Session()\n",
"# Give every variable its initial value before the first training step\n",
"sess.run(tf.global_variables_initializer())\n",
"saver = tf.train.Saver()\n",
"\n",
"# Steps are numbered from 1 so the report fires on steps 1000, 2000, ..., 20000\n",
"for step in range(1, 20000 + 1):\n",
"    images, labels = mnist.train.next_batch(batch_size)\n",
"    batch_feed = {x: images, y_: labels}\n",
"    # Accuracy is measured on the current batch, before it is used for the update\n",
"    if not step % 1000:\n",
"        train_accuracy = sess.run(accuracy, feed_dict=batch_feed)\n",
"        print('step %d, training accuracy %g' % (step, train_accuracy))\n",
"    sess.run(train_step, feed_dict=batch_feed)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## What is the final test accuracy?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:24.560096Z",
"start_time": "2018-04-16T14:30:24.493809Z"
}
},
"outputs": [],
"source": [
"# Evaluate on the whole test split in a single run\n",
"test_feed = {x: mnist.test.images, y_: mnist.test.labels}\n",
"test_accuracy = sess.run(accuracy, feed_dict=test_feed)\n",
"print('test accuracy %g' % test_accuracy)"
]
},
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Freeze the graph" | |
] | |
}, | |
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:24.614593Z",
"start_time": "2018-04-16T14:30:24.564158Z"
}
},
"outputs": [],
"source": [
"# Checkpoint the trained variables (independent of the freeze steps below)\n",
"saver.save(sess, save_path=\"./chkps/mnist_model\")\n",
"\n",
"# Output node name(s) handed to the freeze step below\n",
"out_nodes = [y_pred.op.name]\n",
"print(out_nodes)"
]
},
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Remove unnecessary training nodes" | |
] | |
}, | |
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:36.629747Z",
"start_time": "2018-04-16T14:30:36.606952Z"
}
},
"outputs": [],
"source": [
"# Run remove_training_nodes on a GraphDef snapshot of the session's graph\n",
"sub_graph_def = remove_training_nodes(input_graph=sess.graph_def)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Freeze variables into constants"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:37.568211Z",
"start_time": "2018-04-16T14:30:37.552980Z"
}
},
"outputs": [],
"source": [
"# Bake the session's current variable values into the graph for out_nodes\n",
"sub_graph_def = gu.convert_variables_to_constants(\n",
"    sess=sess,\n",
"    input_graph_def=sub_graph_def,\n",
"    output_node_names=out_nodes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Save the graph to a protobuf (.pb) file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:46.068240Z",
"start_time": "2018-04-16T14:30:46.061283Z"
}
},
"outputs": [],
"source": [
"# Serialize the frozen GraphDef in binary form (as_text=False)\n",
"graph_path = tf.train.write_graph(graph_or_graph_def=sub_graph_def,\n",
"                                  logdir=\"./mnist_model\",\n",
"                                  name=\"deep_mlp.pb\",\n",
"                                  as_text=False)\n",
"\n",
"print('written graph to: %s' % graph_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:31:10.946225Z",
"start_time": "2018-04-16T14:31:10.942746Z"
}
},
"outputs": [],
"source": [
"# Export is finished; release the session and its resources\n",
"sess.close()"
]
},
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.14" | |
}, | |
"toc": { | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": "block", | |
"toc_window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment