Created
August 30, 2018 09:08
-
-
Save JanRuettinger/6ba8662c4b8df86213bfc2ec6ee426ca to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Gist to recreate a Batchnorm Layer bug. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/usr/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n", | |
" return f(*args, **kwds)\n", | |
"/usr/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n", | |
" return f(*args, **kwds)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Python version: sys.version_info(major=3, minor=5, micro=2, releaselevel='final', serial=0)\n", | |
"Tensorflow version: 1.9.0\n", | |
"Keras (tf) version: 2.1.6-tf\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/usr/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n", | |
" return f(*args, **kwds)\n" | |
] | |
} | |
], | |
"source": [ | |
"# auto reload modules so that your changes are reloaded without restarting the notebook\n", | |
"%load_ext autoreload\n", | |
"%autoreload 2\n", | |
"\n", | |
"import os, sys, pathlib\n", | |
"import tensorflow as tf\n", | |
"import tensorflow.keras as keras\n", | |
"from tensorflow.keras.layers import Conv2D, Input, BatchNormalization \n", | |
"from tensorflow.keras.models import Model\n", | |
"\n", | |
"# Specify which GPU to use (needs to be commented out when you want to use multiple GPUs)\n", | |
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n", | |
"\n", | |
"# Makes Keras use only one GPU as opposed to reserving all of them by default, even though they are not used for computation\n", | |
"config = tf.ConfigProto()\n", | |
"config.gpu_options.allow_growth = True\n", | |
"\n", | |
"print(\"Python version: {}\".format(sys.version_info))\n", | |
"print(\"Tensorflow version: {}\".format(tf.__version__))\n", | |
"print(\"Keras (tf) version: {}\".format(tf.keras.__version__))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def freeze_graph(model_dir, output_dir, output_node_names):\n", | |
" \"\"\"Extract the sub graph defined by the output nodes and convert \n", | |
" all its variables into constant \n", | |
" Args:\n", | |
" model_dir: the root folder containing the checkpoint state file\n", | |
" output_node_names: a string, containing all the output node's names, \n", | |
" comma separated\n", | |
" \"\"\"\n", | |
" if not tf.gfile.Exists(model_dir):\n", | |
" raise AssertionError(\n", | |
" \"Export directory doesn't exists. Please specify an export \"\n", | |
" \"directory: %s\" % model_dir)\n", | |
"\n", | |
" if not output_node_names:\n", | |
" print(\"You need to supply the name of a node to --output_node_names.\")\n", | |
" return -1\n", | |
"\n", | |
" # We retrieve our checkpoint file fullpath\n", | |
" checkpoint = tf.train.get_checkpoint_state(model_dir)\n", | |
" input_checkpoint = checkpoint.model_checkpoint_path\n", | |
" \n", | |
" # We precise the file fullname of our freezed graph\n", | |
" #absolute_model_dir = \"/\".join(input_checkpoint.split('/')[:-1])\n", | |
" output_graph = output_dir + \"/frozen_model.pb\"\n", | |
"\n", | |
" # We clear devices to allow TensorFlow to control on which device it will load operations\n", | |
" clear_devices = True\n", | |
"\n", | |
" # We start a session using a temporary fresh Graph\n", | |
" with tf.Session(graph=tf.Graph()) as sess:\n", | |
" # We import the meta graph in the current default Graph\n", | |
" saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)\n", | |
"\n", | |
" # We restore the weights\n", | |
" saver.restore(sess, input_checkpoint)\n", | |
"\n", | |
" gd = tf.get_default_graph().as_graph_def()\n", | |
"\n", | |
" # We use a built-in TF helper to export variables to constants\n", | |
" output_graph_def = tf.graph_util.convert_variables_to_constants(\n", | |
" sess, # The session is used to retrieve the weights\n", | |
" gd, # The graph_def is used to retrieve the nodes \n", | |
" output_node_names.split(\",\") # The output node names are used to select the usefull nodes\n", | |
" ) \n", | |
"\n", | |
" # Finally we serialize and dump the output graph to the filesystem\n", | |
" with tf.gfile.GFile(output_graph, \"wb\") as f:\n", | |
" f.write(output_graph_def.SerializeToString())\n", | |
" print(\"{} ops in the final graph.\".format(len(output_graph_def.node)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def load_frozen_graph(frozen_graph_filename):\n", | |
" # We load the protobuf file from the disk and parse it to retrieve the \n", | |
" # unserialized graph_def\n", | |
" with tf.gfile.GFile(frozen_graph_filename, \"rb\") as f:\n", | |
" graph_def = tf.GraphDef()\n", | |
" graph_def.ParseFromString(f.read())\n", | |
"\n", | |
" with tf.Graph().as_default() as graph:\n", | |
" tf.import_graph_def(graph_def, name='')\n", | |
"\n", | |
" return graph" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def build_bn(shape, n_classes):\n", | |
" inp = Input(shape=shape)\n", | |
" y = Conv2D(32, 3, strides=2, padding='same', activation='relu', name='conv')(inp)\n", | |
" y = BatchNormalization(name='bn')(y)\n", | |
" y_ = Conv2D(32, 3, padding='same', activation='relu', name='conv_final')(y)\n", | |
"\n", | |
" model = Model(inputs=inp, outputs=y_)\n", | |
" \n", | |
" return model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = build_bn(shape=(410,640,3), n_classes=7)\n", | |
"model.compile(loss='mean_squared_error', optimizer='sgd')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model output name: Tensor(\"conv_final/Relu:0\", shape=(?, 205, 320, 32), dtype=float32)\n", | |
"You need the output name for the creation of the frozen graph\n", | |
"\n", | |
"Model input name: Tensor(\"input_1:0\", shape=(?, 410, 640, 3), dtype=float32)\n", | |
"You need the input name for the creation of the frozen graph\n", | |
"\n", | |
"Keras model was saved at /tmp/model_batchnorm_bug.h5\n", | |
"\n", | |
"Checkpoint files were saved at /tmp/checkpoint_files/model_batchnorm_bug.ckpt\n" | |
] | |
} | |
], | |
"source": [ | |
"# Save model as checkpoint files\n", | |
"\n", | |
"model_folder_path = '/tmp/'\n", | |
"\n", | |
"print(\"Model output name: {}\".format(model.output))\n", | |
"print(\"You need the output name for the creation of the frozen graph\")\n", | |
"print(\"\")\n", | |
"print(\"Model input name: {}\".format(model.input))\n", | |
"print(\"You need the input name for the creation of the frozen graph\")\n", | |
"print(\"\")\n", | |
"\n", | |
"# Save model with graph structure as hdf5 and checkpoint files\n", | |
"\n", | |
"# Save model as h5 file\n", | |
"# In order to load it you need to define the custom cost function again and pass it to the load function \n", | |
"model.save(model_folder_path + 'model_batchnorm_bug.h5')\n", | |
"print(\"Keras model was saved at {}\".format(model_folder_path + 'model_batchnorm_bug.h5'))\n", | |
"\n", | |
"# Save model as checkpoint files\n", | |
"sess=tf.keras.backend.get_session() \n", | |
"saver = tf.train.Saver()\n", | |
"# create directory if it doesn't exist yet\n", | |
"os.makedirs(model_folder_path + 'checkpoint_files/', exist_ok=True)\n", | |
"save_path = saver.save(sess, model_folder_path + 'checkpoint_files/model_batchnorm_bug.ckpt')\n", | |
"print(\"\")\n", | |
"print(\"Checkpoint files were saved at {}\".format(model_folder_path + 'checkpoint_files/model_batchnorm_bug.ckpt'))\n", | |
"\n", | |
"with open(model_folder_path + \"params.txt\", 'a') as file:\n", | |
" file.write(\"model_input:\" + str(model.input) + '\\n')\n", | |
" file.write(\"model_output:\" + str(model.output) + '\\n')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Restoring parameters from /tmp/checkpoint_files/model_batchnorm_bug.ckpt\n", | |
"Model was restored successfully!\n", | |
"Input shape: (?, 410, 640, 3)\n", | |
"Output shape: (?, 205, 320, 32)\n", | |
"INFO:tensorflow:Restoring parameters from /tmp/checkpoint_files/model_batchnorm_bug.ckpt\n", | |
"INFO:tensorflow:Froze 8 variables.\n", | |
"INFO:tensorflow:Converted 8 variables to const ops.\n", | |
"43 ops in the final graph.\n", | |
"Frozen graph of your model was created and stored at /tmp/frozen_model.pb\n" | |
] | |
} | |
], | |
"source": [ | |
"# Create frozen graph from checkpoint files\n", | |
"\n", | |
"tf.reset_default_graph() \n", | |
"sess = tf.Session(config=config)\n", | |
"\n", | |
"saver = tf.train.import_meta_graph('/tmp/checkpoint_files/model_batchnorm_bug.ckpt.meta', clear_devices=True)\n", | |
"\n", | |
"#saver.restore(sess,'./benchmarking/{}/checkpoint_files/{}.ckpt'.format(folder_name, model_name))\n", | |
"saver.restore(sess,tf.train.latest_checkpoint('/tmp/checkpoint_files/'))\n", | |
"print(\"Model was restored successfully!\")\n", | |
"\n", | |
"graph = tf.get_default_graph()\n", | |
"\n", | |
"output_name = 'conv_final/Relu:0'\n", | |
"input_name = \"input_1:0\"\n", | |
"\n", | |
"input_frozen_graph = graph.get_tensor_by_name(input_name)\n", | |
"output_frozen_graph = graph.get_tensor_by_name(output_name)\n", | |
"print(\"Input shape: {}\".format(input_frozen_graph.get_shape()))\n", | |
"print(\"Output shape: {}\".format(output_frozen_graph.get_shape()))\n", | |
"\n", | |
"\n", | |
"output_op = graph.get_operation_by_name(output_name[:-2])\n", | |
"#output = graph.get_tensor_by_name(output_name)\n", | |
"freeze_graph('/tmp/checkpoint_files','/tmp/', output_op.name)\n", | |
"\n", | |
"print(\"Frozen graph of your model was created and stored at {}\".format('/tmp/frozen_model.pb'))\n", | |
" \n", | |
"sess.close()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "ValueError", | |
"evalue": "Input 0 of node bn/cond/ReadVariableOp/Switch was passed float from bn/gamma:0 incompatible with expected resource.", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mInvalidArgumentError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m~/Student_Jan/virtual_environments/python3_5/lib/python3.5/site-packages/tensorflow/python/framework/importer.py\u001b[0m in \u001b[0;36mimport_graph_def\u001b[0;34m(graph_def, input_map, return_elements, name, op_dict, producer_op_list)\u001b[0m\n\u001b[1;32m 417\u001b[0m results = c_api.TF_GraphImportGraphDefWithResults(\n\u001b[0;32m--> 418\u001b[0;31m graph._c_graph, serialized, options) # pylint: disable=protected-access\n\u001b[0m\u001b[1;32m 419\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mc_api_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mScopedTFImportGraphDefResults\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mInvalidArgumentError\u001b[0m: Input 0 of node bn/cond/ReadVariableOp/Switch was passed float from bn/gamma:0 incompatible with expected resource.", | |
"\nDuring handling of the above exception, another exception occurred:\n", | |
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-8-774e38dda43d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;31m# We use our \"load_graph\" function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mgraph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_frozen_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfrozen_model_filepath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Frozen graph was loaded succesfully!\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-3-a9e6aeaafa2a>\u001b[0m in \u001b[0;36mload_frozen_graph\u001b[0;34m(frozen_graph_filename)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_graph_def\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph_def\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Student_Jan/virtual_environments/python3_5/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py\u001b[0m in \u001b[0;36mnew_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 430\u001b[0m \u001b[0;34m'in a future version'\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdate\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'after %s'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mdate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 431\u001b[0m instructions)\n\u001b[0;32m--> 432\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 433\u001b[0m return tf_decorator.make_decorator(func, new_func, 'deprecated',\n\u001b[1;32m 434\u001b[0m _add_deprecated_arg_notice_to_docstring(\n", | |
"\u001b[0;32m~/Student_Jan/virtual_environments/python3_5/lib/python3.5/site-packages/tensorflow/python/framework/importer.py\u001b[0m in \u001b[0;36mimport_graph_def\u001b[0;34m(graph_def, input_map, return_elements, name, op_dict, producer_op_list)\u001b[0m\n\u001b[1;32m 420\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mInvalidArgumentError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 421\u001b[0m \u001b[0;31m# Convert to ValueError for backwards compatibility.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 422\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 423\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 424\u001b[0m \u001b[0;31m# Create _DefinedFunctions for any imported functions.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mValueError\u001b[0m: Input 0 of node bn/cond/ReadVariableOp/Switch was passed float from bn/gamma:0 incompatible with expected resource." | |
] | |
} | |
], | |
"source": [ | |
"tf.reset_default_graph()\n", | |
"\n", | |
"# Load frozen graph\n", | |
"tf.reset_default_graph()\n", | |
"frozen_model_filepath = \"/tmp/frozen_model.pb\"\n", | |
"\n", | |
"# We use our \"load_frozen_graph\" function\n", | |
"graph = load_frozen_graph(frozen_model_filepath)\n", | |
"print(\"Frozen graph was loaded succesfully!\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment