Skip to content

Instantly share code, notes, and snippets.

@nulledge
Last active November 28, 2017 15:00
Show Gist options
  • Save nulledge/68a6b9a27a2e8c140b3cb1b01caa2df2 to your computer and use it in GitHub Desktop.
Save nulledge/68a6b9a27a2e8c140b3cb1b01caa2df2 to your computer and use it in GitHub Desktop.
MNIST toy project, run on an NVIDIA Quadro M6000.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/nulledge/pyenv/tf-36/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [
"'''Load required modules.\n",
"Modules:\n",
" tqdm: A progress-bar tool for visualizing loop progress.\n",
" tensorflow: A framework for machine learning.\n",
" numpy: An array utility.\n",
"'''\n",
"from tqdm import tqdm as tqdm\n",
"import tensorflow as tf\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting /home/nulledge/data/MNIST/train-images-idx3-ubyte.gz\n",
"Extracting /home/nulledge/data/MNIST/train-labels-idx1-ubyte.gz\n",
"Extracting /home/nulledge/data/MNIST/t10k-images-idx3-ubyte.gz\n",
"Extracting /home/nulledge/data/MNIST/t10k-labels-idx1-ubyte.gz\n"
]
}
],
"source": [
"'''Load the MNIST training and test data from TensorFlow.\n",
"'''\n",
"from tensorflow.examples.tutorials.mnist import input_data as mnist_data\n",
"mnist = mnist_data.read_data_sets(\"/home/nulledge/data/MNIST/\", one_hot=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"'''Define flags.\n",
"'''\n",
"flags = tf.app.flags\n",
"flags.DEFINE_integer('epoches', 10, 'The number of train epoches.')\n",
"flags.DEFINE_integer('mnist_train', 10000, 'The number of train dataset.')\n",
"flags.DEFINE_integer('mnist_test', 5000, 'The number of test dataset.')\n",
"flags.DEFINE_integer('batch', 25, 'The batch size.')\n",
"flags.DEFINE_integer('labels', 10, 'The number of labels.')\n",
"flags.DEFINE_integer('resolution', 28*28, 'The resolution of input image in flatten shape.')\n",
"flags.DEFINE_boolean('phase', False, 'Whether train mode or not.')\n",
"flags.DEFINE_string('checkpoint', '/home/nulledge/ckpt/MNIST/mnist.ckpt', 'The path of checkpoint.')\n",
"flags.DEFINE_string('summary', '/home/nulledge/log/MNIST/', 'The path of log.')\n",
"\n",
"FLAGS = flags.FLAGS"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"'''Placeholders\n",
"\n",
"Tensors:\n",
" images: The input images of MNIST in flattened shape.\n",
" labels_gt: The groundtruth labels of MNIST in one-hot encoding.\n",
" phase: The boolean tensor fed to the batch-normalization is_training parameter (True for training, False for testing).\n",
"'''\n",
"with tf.variable_scope('placeholder'):\n",
" images = tf.placeholder(\n",
" name = 'images',\n",
" shape = [None, FLAGS.resolution],\n",
" dtype = tf.float32\n",
" )\n",
" labels_gt = tf.placeholder(\n",
" name = 'labels_gt',\n",
" shape = [None, FLAGS.labels],\n",
" dtype = tf.float32\n",
" )\n",
" phase = tf.placeholder(\n",
" name = 'train',\n",
" shape = (),\n",
" dtype = tf.bool\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"'''Build the network.\n",
"\n",
"Input:\n",
" images: The input tensor from the placeholder scope.\n",
"\n",
"Structure:\n",
" layer_01\n",
" fully_connected\n",
" batch_normalization\n",
" relu\n",
" layer_02\n",
" fully_connected\n",
" batch_normalization\n",
" relu\n",
" predict\n",
" fully_connected\n",
"\n",
"Output:\n",
" logits: The unnormalized scores (to be passed through softmax) in shape [None, FLAGS.labels], i.e. [None, 10].\n",
"'''\n",
"net = images\n",
"\n",
"with tf.variable_scope('layer_01'):\n",
" net = tf.contrib.layers.fully_connected(net, 100, activation_fn = None, scope = 'fc')\n",
" net = tf.contrib.layers.batch_norm(net, center = True, scale = True, is_training = phase, scope = 'bn')\n",
" net = tf.nn.relu(net, 'relu')\n",
" \n",
"with tf.variable_scope('layer_02'):\n",
" net = tf.contrib.layers.fully_connected(net, 100, activation_fn = None, scope = 'fc')\n",
" net = tf.contrib.layers.batch_norm(net, center = True, scale = True, is_training = phase, scope = 'bn')\n",
" net = tf.nn.relu(net, 'relu')\n",
"\n",
"with tf.variable_scope('predict'):\n",
" logits = tf.contrib.layers.fully_connected(net, FLAGS.labels, activation_fn = None, scope = 'logits')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"'''Build the accuracy, loss, summary, and optimizer ops on top of the network.\n",
"\n",
"Tensor:\n",
" accuracy: The fraction (0 to 1) of correct classifications in the fed batch.\n",
" loss: The cross-entropy between groundtruth and predicted labels.\n",
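" optimizer: The gradient-descent (learning rate 0.01) training op over loss; it runs after the ops in the UPDATE_OPS collection (the batch-normalization statistic updates).\n",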
" \n",
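" summary_accuracy: The scalar summary of accuracy.\n",
" summary_loss: The scalar summary of loss.\n",
" summary_merged: All summaries merged into one op; its output is written to the log.\n",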
"'''\n",
"with tf.name_scope('accuracy'):\n",
" accuracy = tf.reduce_mean(\n",
" tf.cast(\n",
" tf.equal(\n",
" tf.argmax(labels_gt, 1),\n",
" tf.argmax(logits, 1)\n",
" ),\n",
" dtype = tf.float32\n",
" )\n",
" )\n",
" summary_accuracy = tf.summary.scalar('accuracy', accuracy)\n",
" \n",
"with tf.name_scope('loss'):\n",
" loss = tf.reduce_mean(\n",
" tf.nn.softmax_cross_entropy_with_logits(\n",
" logits = logits,\n",
" labels = labels_gt\n",
" )\n",
" )\n",
" summary_loss = tf.summary.scalar('loss', loss)\n",
" \n",
"summary_merged = tf.summary.merge_all()\n",
" \n",
"update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n",
"with tf.control_dependencies(update_ops):\n",
" optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /home/nulledge/ckpt/MNIST/mnist.ckpt\n",
"Load failed. Initialize variables.\n"
]
}
],
"source": [
"'''Build graph.\n",
"\n",
"Open a session and restore the variables from FLAGS.checkpoint. If restoring\n",
"fails, initialize all variables in the graph instead.\n",
"\n",
"Objects:\n",
" sess: The session to interact.\n",
" saver: The saver which saves and loads the checkpoint in FLAGS.checkpoint.\n",
" writer: The log file writer.\n",
"'''\n",
"sess = tf.InteractiveSession()\n",
"saver = tf.train.Saver()\n",
"writer = tf.summary.FileWriter(FLAGS.summary, sess.graph)\n",
"\n",
"try:\n",
" saver.restore(sess, FLAGS.checkpoint)\n",
"except:\n",
" print('Load failed. Initialize variables.')\n",
" sess.run(tf.global_variables_initializer())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch(1/10): 100%|██████████| 10000/10000 [00:01<00:00, 5812.23it/s, accuracy=0.96]\n",
"epoch(2/10): 100%|██████████| 10000/10000 [00:01<00:00, 9009.94it/s, accuracy=0.88]\n",
"epoch(3/10): 100%|██████████| 10000/10000 [00:01<00:00, 8882.72it/s, accuracy=0.92]\n",
"epoch(4/10): 100%|██████████| 10000/10000 [00:01<00:00, 9003.28it/s, accuracy=0.92]\n",
"epoch(5/10): 100%|██████████| 10000/10000 [00:01<00:00, 8941.69it/s, accuracy=0.92]\n",
"epoch(6/10): 100%|██████████| 10000/10000 [00:01<00:00, 8502.59it/s, accuracy=1] \n",
"epoch(7/10): 100%|██████████| 10000/10000 [00:01<00:00, 9013.86it/s, accuracy=0.92]\n",
"epoch(8/10): 100%|██████████| 10000/10000 [00:01<00:00, 8821.30it/s, accuracy=0.96]\n",
"epoch(9/10): 100%|██████████| 10000/10000 [00:01<00:00, 9028.97it/s, accuracy=0.96]\n",
"epoch(10/10): 100%|██████████| 10000/10000 [00:01<00:00, 9018.71it/s, accuracy=0.92]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"save path: /home/nulledge/ckpt/MNIST/mnist.ckpt\n"
]
}
],
"source": [
"'''Train and save the network.\n",
"'''\n",
"idx = 0\n",
"for epoch in range(FLAGS.epoches):\n",
" train_iterator = tqdm(total = FLAGS.mnist_train)\n",
" train_iterator.set_description('epoch(' + str(epoch+1) + '/' + str(FLAGS.epoches) + ')')\n",
" for _ in range(FLAGS.mnist_train // FLAGS.batch):\n",
" train_images, train_labels = mnist.train.next_batch(FLAGS.batch)\n",
" _, train_accuracy, train_summary = sess.run([optimizer, accuracy, summary_merged],\n",
" feed_dict = {\n",
" images: train_images,\n",
" labels_gt: train_labels,\n",
" phase: True\n",
" })\n",
" train_iterator.set_postfix(accuracy = train_accuracy)\n",
" train_iterator.update(FLAGS.batch)\n",
" writer.add_summary(train_summary, idx)\n",
" idx += 1\n",
" train_iterator.close()\n",
"\n",
"saved_path = saver.save(sess, FLAGS.checkpoint)\n",
"print('save path:', saved_path)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"test: 100%|██████████| 5000/5000 [00:00<00:00, 28902.36it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean accuracy: 0.9572\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"'''Test the network.\n",
"'''\n",
"\n",
"test_result = []\n",
"test_iterator = tqdm(total = FLAGS.mnist_test, desc = 'test')\n",
"for index in range(FLAGS.mnist_test // FLAGS.batch):\n",
" test_images, test_labels = mnist.test.next_batch(FLAGS.batch)\n",
" test_accuracy = sess.run([accuracy],\n",
" feed_dict = {\n",
" images: test_images,\n",
" labels_gt: test_labels,\n",
" phase: False\n",
" })\n",
" test_iterator.update(FLAGS.batch)\n",
" test_result.append(test_accuracy)\n",
"test_iterator.close()\n",
"\n",
"print('mean accuracy:', np.mean(test_result))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf-36",
"language": "python",
"name": "th-36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@nulledge
Copy link
Author

tensorboard
tensorboard result.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment