Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save nikashitsa/75057179e244d5ca085beaa1b9c12a5d to your computer and use it in GitHub Desktop.
Save nikashitsa/75057179e244d5ca085beaa1b9c12a5d to your computer and use it in GitHub Desktop.
How to remove dropout from frozen model
Display the source blob
Display the rendered blob
Raw
{
"cells": [
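{
"metadata": {},
"cell_type": "markdown",
"source": "How to remove dropout from a frozen model\n===\n\nDropout is only needed during training: this notebook loads a frozen TensorFlow graph, removes the `dropout/*` nodes together with the `keep_prob` placeholder, saves the result, and checks that the accuracy on MNIST test images is unchanged."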
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "from __future__ import print_function\nfrom tensorflow.core.framework import graph_pb2\nimport tensorflow as tf\nimport numpy as np\nfrom tensorflow.examples.tutorials.mnist import input_data\n\nmnist = input_data.read_data_sets('/tmp/data/', one_hot=True)\n\ndef display_nodes(nodes):\n for i, node in enumerate(nodes):\n print('%d %s %s' % (i, node.name, node.op))\n [print(u'└─── %d ─ %s' % (i, n)) for i, n in enumerate(node.input)]\n \ndef accuracy(predictions, labels):\n return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1)) / predictions.shape[0])\n\ndef test_graph(graph_path, use_dropout):\n tf.reset_default_graph()\n graph_def = tf.GraphDef()\n \n with tf.gfile.FastGFile(graph_path, 'rb') as f:\n graph_def.ParseFromString(f.read())\n \n _ = tf.import_graph_def(graph_def, name='')\n sess = tf.Session() \n prediction_tensor = sess.graph.get_tensor_by_name('final_result:0') \n \n feed_dict = {'input:0': mnist.test.images[:256]}\n if use_dropout:\n feed_dict['keep_prob:0'] = 1.0\n \n predictions = sess.run(prediction_tensor, feed_dict)\n result = accuracy(predictions, mnist.test.labels[:256])\n return result",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "Extracting /tmp/data/train-images-idx3-ubyte.gz\nExtracting /tmp/data/train-labels-idx1-ubyte.gz\nExtracting /tmp/data/t10k-images-idx3-ubyte.gz\nExtracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
}
]
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "# read frozen graph and display nodes\ngraph = tf.GraphDef()\nwith tf.gfile.Open('./frozen_model.pb', 'r') as f:\n data = f.read()\n graph.ParseFromString(data)\n \ndisplay_nodes(graph.node)",
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "0 input Placeholder\n1 keep_prob Placeholder\n2 Variable Const\n3 Variable/read Identity\n└─── 0 ─ Variable\n4 Variable_1 Const\n5 Variable_1/read Identity\n└─── 0 ─ Variable_1\n6 Variable_2 Const\n7 Variable_2/read Identity\n└─── 0 ─ Variable_2\n8 Variable_3 Const\n9 Variable_3/read Identity\n└─── 0 ─ Variable_3\n10 Variable_4 Const\n11 Variable_4/read Identity\n└─── 0 ─ Variable_4\n12 Variable_5 Const\n13 Variable_5/read Identity\n└─── 0 ─ Variable_5\n14 Variable_6 Const\n15 Variable_6/read Identity\n└─── 0 ─ Variable_6\n16 Variable_7 Const\n17 Variable_7/read Identity\n└─── 0 ─ Variable_7\n18 Reshape/shape Const\n19 Reshape Reshape\n└─── 0 ─ input\n└─── 1 ─ Reshape/shape\n20 Conv2D Conv2D\n└─── 0 ─ Reshape\n└─── 1 ─ Variable/read\n21 BiasAdd BiasAdd\n└─── 0 ─ Conv2D\n└─── 1 ─ Variable_4/read\n22 Relu Relu\n└─── 0 ─ BiasAdd\n23 MaxPool MaxPool\n└─── 0 ─ Relu\n24 Conv2D_1 Conv2D\n└─── 0 ─ MaxPool\n└─── 1 ─ Variable_1/read\n25 BiasAdd_1 BiasAdd\n└─── 0 ─ Conv2D_1\n└─── 1 ─ Variable_5/read\n26 Relu_1 Relu\n└─── 0 ─ BiasAdd_1\n27 MaxPool_1 MaxPool\n└─── 0 ─ Relu_1\n28 Reshape_1/shape Const\n29 Reshape_1 Reshape\n└─── 0 ─ MaxPool_1\n└─── 1 ─ Reshape_1/shape\n30 MatMul MatMul\n└─── 0 ─ Reshape_1\n└─── 1 ─ Variable_2/read\n31 Add Add\n└─── 0 ─ MatMul\n└─── 1 ─ Variable_6/read\n32 Relu_2 Relu\n└─── 0 ─ Add\n33 dropout/Shape Shape\n└─── 0 ─ Relu_2\n34 dropout/random_uniform/min Const\n35 dropout/random_uniform/max Const\n36 dropout/random_uniform/RandomUniform RandomUniform\n└─── 0 ─ dropout/Shape\n37 dropout/random_uniform/sub Sub\n└─── 0 ─ dropout/random_uniform/max\n└─── 1 ─ dropout/random_uniform/min\n38 dropout/random_uniform/mul Mul\n└─── 0 ─ dropout/random_uniform/RandomUniform\n└─── 1 ─ dropout/random_uniform/sub\n39 dropout/random_uniform Add\n└─── 0 ─ dropout/random_uniform/mul\n└─── 1 ─ dropout/random_uniform/min\n40 dropout/add Add\n└─── 0 ─ keep_prob\n└─── 1 ─ dropout/random_uniform\n41 dropout/Floor Floor\n└─── 0 ─ dropout/add\n42 dropout/Div 
Div\n└─── 0 ─ Relu_2\n└─── 1 ─ keep_prob\n43 dropout/mul Mul\n└─── 0 ─ dropout/Div\n└─── 1 ─ dropout/Floor\n44 MatMul_1 MatMul\n└─── 0 ─ dropout/mul\n└─── 1 ─ Variable_3/read\n45 Add_1 Add\n└─── 0 ─ MatMul_1\n└─── 1 ─ Variable_7/read\n46 final_result Softmax\n└─── 0 ─ Add_1\n"
}
]
},
{
"metadata": {
"collapsed": true,
"trusted": true
},
"cell_type": "code",
"source": "# Connect 'MatMul_1' with 'Relu_2'\ngraph.node[44].input[0] = 'Relu_2' # 44 -> MatMul_1\n# Remove dropout nodes\nnodes = graph.node[:33] + graph.node[44:] # 33 -> MatMul_1 \ndel nodes[1] # 1 -> keep_prob\n\n# Save graph\noutput_graph = graph_pb2.GraphDef()\noutput_graph.node.extend(nodes)\nwith tf.gfile.GFile('./frozen_model_without_dropout.pb', 'w') as f:\n f.write(output_graph.SerializeToString())",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "# test graph via simple test\nresult_1 = test_graph('./frozen_model.pb', use_dropout=True)\nresult_2 = test_graph('./frozen_model_without_dropout.pb', use_dropout=False)\n\nprint('with dropout: %f' % result_1)\nprint('without dropout: %f' % result_2)",
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "with dropout: 80.859375\nwithout dropout: 80.859375\n"
}
]
},
{
"metadata": {
"collapsed": true,
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python2",
"display_name": "Python 2",
"language": "python"
},
"language_info": {
"mimetype": "text/x-python",
"nbconvert_exporter": "python",
"name": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12",
"file_extension": ".py",
"codemirror_mode": {
"version": 2,
"name": "ipython"
}
},
"gist": {
"id": "4498bb2174d85104c4396d3f48a0a09d",
"data": {
"description": "How to remove dropout from frozen model",
"public": true
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/4498bb2174d85104c4396d3f48a0a09d"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment