Skip to content

Instantly share code, notes, and snippets.

@neil-tan
Created February 6, 2019 10:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save neil-tan/c132b7b82273c30038fe0da0e05b8f96 to your computer and use it in GitHub Desktop.
Save neil-tan/c132b7b82273c30038fe0da0e05b8f96 to your computer and use it in GitHub Desktop.
This file is located at uTensor/utensor-mnist-demo
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:29:56.593505Z",
"start_time": "2018-04-16T14:29:52.879458Z"
}
},
"outputs": [],
"source": [
"# This script is based on:\n",
"# https://www.tensorflow.org/get_started/mnist/pros\n",
"\n",
"import sys\n",
"import tensorflow as tf\n",
"from tensorflow.examples.tutorials.mnist import input_data\n",
"from tensorflow.python.framework import graph_util as gu\n",
"from tensorflow.python.framework.graph_util import remove_training_nodes\n",
"from tensorflow.tools.graph_transforms import TransformGraph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import training data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:29:57.272673Z",
"start_time": "2018-04-16T14:29:56.595420Z"
}
},
"outputs": [],
"source": [
"# Load the MNIST data set from mnist_data/ with labels as one-hot vectors\n",
"mnist = input_data.read_data_sets(\n",
"    \"mnist_data/\",\n",
"    one_hot=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define TensorFlow Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:29:57.278092Z",
"start_time": "2018-04-16T14:29:57.274735Z"
}
},
"outputs": [],
"source": [
"# Number of training examples drawn for each optimization step\n",
"batch_size = 50"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fully connected NN with 2 hidden layers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:29:57.288994Z",
"start_time": "2018-04-16T14:29:57.280660Z"
}
},
"outputs": [],
"source": [
"def deepnn(x):\n",
"    \"\"\"Build a fully connected classifier with layer sizes 784 -> 128 -> 64 -> 10.\n",
"\n",
"    Both hidden layers apply a ReLU activation; the output layer is linear.\n",
"\n",
"    Args:\n",
"        x: 2-D input tensor with 784 features per row.\n",
"\n",
"    Returns:\n",
"        Tuple `(y_pred, logits)`: `logits` is the output of the final affine\n",
"        layer (10 values per row) and `y_pred` is the argmax of `logits`\n",
"        along axis 1.\n",
"    \"\"\"\n",
"    # Hidden layer 1: 784 -> 128\n",
"    weights_1 = weight_variable([784, 128], name='W_fc1')\n",
"    biases_1 = bias_variable([128], name='b_fc1')\n",
"    hidden_1 = tf.nn.relu(\n",
"        tf.add(tf.matmul(x, weights_1), biases_1, name=\"zscore\"))\n",
"\n",
"    # Hidden layer 2: 128 -> 64\n",
"    # NOTE: this reuses the op name \"zscore\"; TensorFlow de-duplicates it by\n",
"    # appending a numeric suffix.\n",
"    weights_2 = weight_variable([128, 64], name='W_fc2')\n",
"    biases_2 = bias_variable([64], name='b_fc2')\n",
"    hidden_2 = tf.nn.relu(\n",
"        tf.add(tf.matmul(hidden_1, weights_2), biases_2, name=\"zscore\"))\n",
"\n",
"    # Output layer: 64 -> 10 (raw logits, no activation)\n",
"    weights_out = weight_variable([64, 10], name='W_fc3')\n",
"    biases_out = bias_variable([10], name='b_fc3')\n",
"    logits = tf.add(tf.matmul(hidden_2, weights_out), biases_out, name=\"logits\")\n",
"\n",
"    return tf.argmax(logits, 1, name='y_pred'), logits\n",
"\n",
"\n",
"def weight_variable(shape, name):\n",
"    \"\"\"Create a weight variable of the given shape.\n",
"\n",
"    Args:\n",
"        shape: list of ints giving the shape of the variable.\n",
"        name: name to assign to the variable in the graph.\n",
"\n",
"    Returns:\n",
"        A `tf.Variable` initialised from a truncated normal with stddev 0.1.\n",
"    \"\"\"\n",
"    initial = tf.truncated_normal(shape, stddev=0.1)\n",
"    # `name` must be passed by keyword: the second positional parameter of\n",
"    # tf.Variable is `trainable`, so a name passed positionally never reaches\n",
"    # the `name` argument and the variable gets an auto-generated name.\n",
"    return tf.Variable(initial, name=name)\n",
"\n",
"\n",
"def bias_variable(shape, name):\n",
"    \"\"\"Create a bias variable of the given shape.\n",
"\n",
"    Args:\n",
"        shape: list of ints giving the shape of the variable.\n",
"        name: name to assign to the variable in the graph.\n",
"\n",
"    Returns:\n",
"        A `tf.Variable` with every element initialised to 0.1.\n",
"    \"\"\"\n",
"    initial = tf.constant(0.1, shape=shape)\n",
"    # `name` must be passed by keyword: the second positional parameter of\n",
"    # tf.Variable is `trainable`, so a name passed positionally never reaches\n",
"    # the `name` argument and the variable gets an auto-generated name.\n",
"    return tf.Variable(initial, name=name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Specify inputs, outputs, and a cost function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:29:57.440031Z",
"start_time": "2018-04-16T14:29:57.291656Z"
}
},
"outputs": [],
"source": [
"# Reset default graph\n",
"tf.reset_default_graph()\n",
"\n",
"# Input placeholder: one row of 784 values per example\n",
"x = tf.placeholder(tf.float32, [None, 784], name=\"x\")\n",
"\n",
"# Label placeholder: one row of 10 values per example (one-hot labels)\n",
"y_ = tf.placeholder(tf.float32, [None, 10], name=\"y\")\n",
"\n",
"# Build the graph for the deep net\n",
"y_pred, logits = deepnn(x)\n",
"\n",
"# Define loss and optimizer\n",
"with tf.name_scope(\"Loss\"):\n",
"    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_,\n",
"                                                               logits=logits)\n",
"    loss = tf.reduce_mean(cross_entropy, name=\"cross_entropy_loss\")\n",
"train_step = tf.train.AdamOptimizer(1e-4).minimize(loss, name=\"train_step\")\n",
"\n",
"# Only the evaluation ops live in the \"Prediction\" scope. y_pred is created\n",
"# inside deepnn(), outside this scope, so the output node is named \"y_pred\"\n",
"# (not \"Prediction/y_pred\"); that name is needed later when freezing the graph.\n",
"with tf.name_scope(\"Prediction\"):\n",
"    correct_prediction = tf.equal(y_pred,\n",
"                                  tf.argmax(y_, 1))\n",
"    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name=\"accuracy\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:24.491091Z",
"start_time": "2018-04-16T14:29:57.574591Z"
}
},
"outputs": [],
"source": [
"sess = tf.Session()\n",
"# Initialize the variables (i.e. assign their default value)\n",
"sess.run(tf.global_variables_initializer())\n",
"# Saver used further down to write the checkpoint\n",
"saver = tf.train.Saver()\n",
"\n",
"for i in range(1, 20001):\n",
"    batch_images, batch_labels = mnist.train.next_batch(batch_size)\n",
"    feed_dict = {x: batch_images, y_: batch_labels}\n",
"    # Every 1000 steps, report accuracy on the current batch before training on it\n",
"    if i % 1000 == 0:\n",
"        train_accuracy = sess.run(accuracy, feed_dict=feed_dict)\n",
"        print('step %d, training accuracy %g' % (i, train_accuracy))\n",
"    # The optimizer step runs on every iteration, not only on reporting steps\n",
"    sess.run(train_step, feed_dict=feed_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## What is the final accuracy?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:24.560096Z",
"start_time": "2018-04-16T14:30:24.493809Z"
}
},
"outputs": [],
"source": [
"# Evaluate the accuracy op on the full test split\n",
"test_feed = {x: mnist.test.images, y_: mnist.test.labels}\n",
"test_accuracy = sess.run(accuracy, feed_dict=test_feed)\n",
"print('test accuracy %g' % test_accuracy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Freeze the graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:24.614593Z",
"start_time": "2018-04-16T14:30:24.564158Z"
}
},
"outputs": [],
"source": [
"# Write a checkpoint of the trained variables\n",
"checkpoint_prefix = \"./chkps/mnist_model\"\n",
"saver.save(sess, checkpoint_prefix)\n",
"\n",
"# Name of the output op; the graph-freezing step below needs it\n",
"out_nodes = [y_pred.op.name]\n",
"print(out_nodes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Remove unnecessary training nodes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:36.629747Z",
"start_time": "2018-04-16T14:30:36.606952Z"
}
},
"outputs": [],
"source": [
"# Strip training-only nodes from the session's graph definition\n",
"sub_graph_def = remove_training_nodes(input_graph=sess.graph_def)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Freeze Constants"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:37.568211Z",
"start_time": "2018-04-16T14:30:37.552980Z"
}
},
"outputs": [],
"source": [
"# Replace the variables needed by the output nodes with constants holding\n",
"# their current values in the session\n",
"sub_graph_def = gu.convert_variables_to_constants(\n",
"    sess=sess,\n",
"    input_graph_def=sub_graph_def,\n",
"    output_node_names=out_nodes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Save the graph to PB file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:30:46.068240Z",
"start_time": "2018-04-16T14:30:46.061283Z"
}
},
"outputs": [],
"source": [
"# Serialize the frozen graph to ./mnist_model/deep_mlp.pb in binary form\n",
"graph_path = tf.train.write_graph(\n",
"    sub_graph_def, \"./mnist_model\", \"deep_mlp.pb\", as_text=False)\n",
"\n",
"print('written graph to: %s' % graph_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-04-16T14:31:10.946225Z",
"start_time": "2018-04-16T14:31:10.942746Z"
}
},
"outputs": [],
"source": [
"# Training and export are done; close the TensorFlow session\n",
"sess.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.14"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": false,
"toc_position": {},
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment