Skip to content

Instantly share code, notes, and snippets.

@JanRuettinger
Created August 30, 2018 09:08
Show Gist options
  • Save JanRuettinger/6ba8662c4b8df86213bfc2ec6ee426ca to your computer and use it in GitHub Desktop.
Save JanRuettinger/6ba8662c4b8df86213bfc2ec6ee426ca to your computer and use it in GitHub Desktop.
Gist to recreate a Batchnorm Layer bug.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
" return f(*args, **kwds)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python version: sys.version_info(major=3, minor=5, micro=2, releaselevel='final', serial=0)\n",
"Tensorflow version: 1.9.0\n",
"Keras (tf) version: 2.1.6-tf\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [
"# auto reload modules so that your changes are reloaded without restarting the notebook\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import os, sys, pathlib\n",
"import tensorflow as tf\n",
"import tensorflow.keras as keras\n",
"from tensorflow.keras.layers import Conv2D, Input, BatchNormalization \n",
"from tensorflow.keras.models import Model\n",
"\n",
"# Specify which GPU to use (needs to be commented out when you want to use multiple GPUs)\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
"\n",
"# Makes TensorFlow/Keras allocate GPU memory on demand (allow_growth) as opposed to reserving all of it by default even though it is not used for computation\n",
"config = tf.ConfigProto()\n",
"config.gpu_options.allow_growth = True\n",
"\n",
"print(\"Python version: {}\".format(sys.version_info))\n",
"print(\"Tensorflow version: {}\".format(tf.__version__))\n",
"print(\"Keras (tf) version: {}\".format(tf.keras.__version__))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def freeze_graph(model_dir, output_dir, output_node_names):\n",
" \"\"\"Extract the sub graph defined by the output nodes and convert \n",
" all its variables into constants.\n",
" Args:\n",
" model_dir: the root folder containing the checkpoint state file\n",
" output_dir: the folder in which frozen_model.pb is written\n",
" output_node_names: a string, containing all the output node's names, \n",
" comma separated\n",
" \"\"\"\n",
" if not tf.gfile.Exists(model_dir):\n",
" raise AssertionError(\n",
" \"Export directory doesn't exists. Please specify an export \"\n",
" \"directory: %s\" % model_dir)\n",
"\n",
" if not output_node_names:\n",
" print(\"You need to supply the name of a node to --output_node_names.\")\n",
" return -1\n",
"\n",
" # We retrieve our checkpoint file fullpath\n",
" checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
" input_checkpoint = checkpoint.model_checkpoint_path\n",
" \n",
" # We specify the full file name of our frozen graph\n",
" #absolute_model_dir = \"/\".join(input_checkpoint.split('/')[:-1])\n",
" output_graph = output_dir + \"/frozen_model.pb\"\n",
"\n",
" # We clear devices to allow TensorFlow to control on which device it will load operations\n",
" clear_devices = True\n",
"\n",
" # We start a session using a temporary fresh Graph\n",
" with tf.Session(graph=tf.Graph()) as sess:\n",
" # We import the meta graph in the current default Graph\n",
" saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)\n",
"\n",
" # We restore the weights\n",
" saver.restore(sess, input_checkpoint)\n",
"\n",
" gd = tf.get_default_graph().as_graph_def()\n",
"\n",
" # We use a built-in TF helper to export variables to constants\n",
" output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
" sess, # The session is used to retrieve the weights\n",
" gd, # The graph_def is used to retrieve the nodes \n",
" output_node_names.split(\",\") # The output node names are used to select the useful nodes\n",
" ) \n",
"\n",
" # Finally we serialize and dump the output graph to the filesystem\n",
" with tf.gfile.GFile(output_graph, \"wb\") as f:\n",
" f.write(output_graph_def.SerializeToString())\n",
" print(\"{} ops in the final graph.\".format(len(output_graph_def.node)))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def load_frozen_graph(frozen_graph_filename):\n",
" # We load the protobuf file from the disk and parse it to retrieve the \n",
" # unserialized graph_def\n",
" with tf.gfile.GFile(frozen_graph_filename, \"rb\") as f:\n",
" graph_def = tf.GraphDef()\n",
" graph_def.ParseFromString(f.read())\n",
"\n",
" with tf.Graph().as_default() as graph:\n",
" tf.import_graph_def(graph_def, name='')\n",
"\n",
" return graph"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def build_bn(shape, n_classes):\n",
" inp = Input(shape=shape)\n",
" y = Conv2D(32, 3, strides=2, padding='same', activation='relu', name='conv')(inp)\n",
" y = BatchNormalization(name='bn')(y)\n",
" y_ = Conv2D(32, 3, padding='same', activation='relu', name='conv_final')(y)\n",
"\n",
" model = Model(inputs=inp, outputs=y_)\n",
" \n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"model = build_bn(shape=(410,640,3), n_classes=7)\n",
"model.compile(loss='mean_squared_error', optimizer='sgd')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model output name: Tensor(\"conv_final/Relu:0\", shape=(?, 205, 320, 32), dtype=float32)\n",
"You need the output name for the creation of the frozen graph\n",
"\n",
"Model input name: Tensor(\"input_1:0\", shape=(?, 410, 640, 3), dtype=float32)\n",
"You need the input name for the creation of the frozen graph\n",
"\n",
"Keras model was saved at /tmp/model_batchnorm_bug.h5\n",
"\n",
"Checkpoint files were saved at /tmp/checkpoint_files/model_batchnorm_bug.ckpt\n"
]
}
],
"source": [
"# Save model as checkpoint files\n",
"\n",
"model_folder_path = '/tmp/'\n",
"\n",
"print(\"Model output name: {}\".format(model.output))\n",
"print(\"You need the output name for the creation of the frozen graph\")\n",
"print(\"\")\n",
"print(\"Model input name: {}\".format(model.input))\n",
"print(\"You need the input name for the creation of the frozen graph\")\n",
"print(\"\")\n",
"\n",
"# Save model with graph structure as hdf5 and checkpoint files\n",
"\n",
"# Save model as h5 file\n",
"# In order to load it you need to define the custom cost function again and pass it to the load function \n",
"model.save(model_folder_path + 'model_batchnorm_bug.h5')\n",
"print(\"Keras model was saved at {}\".format(model_folder_path + 'model_batchnorm_bug.h5'))\n",
"\n",
"# Save model as checkpoint files\n",
"sess=tf.keras.backend.get_session() \n",
"saver = tf.train.Saver()\n",
"# create directory if it doesn't exist yet\n",
"os.makedirs(model_folder_path + 'checkpoint_files/', exist_ok=True)\n",
"save_path = saver.save(sess, model_folder_path + 'checkpoint_files/model_batchnorm_bug.ckpt')\n",
"print(\"\")\n",
"print(\"Checkpoint files were saved at {}\".format(model_folder_path + 'checkpoint_files/model_batchnorm_bug.ckpt'))\n",
"\n",
"with open(model_folder_path + \"params.txt\", 'a') as file:\n",
" file.write(\"model_input:\" + str(model.input) + '\\n')\n",
" file.write(\"model_output:\" + str(model.output) + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /tmp/checkpoint_files/model_batchnorm_bug.ckpt\n",
"Model was restored successfully!\n",
"Input shape: (?, 410, 640, 3)\n",
"Output shape: (?, 205, 320, 32)\n",
"INFO:tensorflow:Restoring parameters from /tmp/checkpoint_files/model_batchnorm_bug.ckpt\n",
"INFO:tensorflow:Froze 8 variables.\n",
"INFO:tensorflow:Converted 8 variables to const ops.\n",
"43 ops in the final graph.\n",
"Frozen graph of your model was created and stored at /tmp/frozen_model.pb\n"
]
}
],
"source": [
"# Create frozen graph from checkpoint files\n",
"\n",
"tf.reset_default_graph() \n",
"sess = tf.Session(config=config)\n",
"\n",
"saver = tf.train.import_meta_graph('/tmp/checkpoint_files/model_batchnorm_bug.ckpt.meta', clear_devices=True)\n",
"\n",
"#saver.restore(sess,'./benchmarking/{}/checkpoint_files/{}.ckpt'.format(folder_name, model_name))\n",
"saver.restore(sess,tf.train.latest_checkpoint('/tmp/checkpoint_files/'))\n",
"print(\"Model was restored successfully!\")\n",
"\n",
"graph = tf.get_default_graph()\n",
"\n",
"output_name = 'conv_final/Relu:0'\n",
"input_name = \"input_1:0\"\n",
"\n",
"input_frozen_graph = graph.get_tensor_by_name(input_name)\n",
"output_frozen_graph = graph.get_tensor_by_name(output_name)\n",
"print(\"Input shape: {}\".format(input_frozen_graph.get_shape()))\n",
"print(\"Output shape: {}\".format(output_frozen_graph.get_shape()))\n",
"\n",
"\n",
"output_op = graph.get_operation_by_name(output_name[:-2])\n",
"#output = graph.get_tensor_by_name(output_name)\n",
"freeze_graph('/tmp/checkpoint_files','/tmp/', output_op.name)\n",
"\n",
"print(\"Frozen graph of your model was created and stored at {}\".format('/tmp/frozen_model.pb'))\n",
" \n",
"sess.close()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "Input 0 of node bn/cond/ReadVariableOp/Switch was passed float from bn/gamma:0 incompatible with expected resource.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mInvalidArgumentError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m~/Student_Jan/virtual_environments/python3_5/lib/python3.5/site-packages/tensorflow/python/framework/importer.py\u001b[0m in \u001b[0;36mimport_graph_def\u001b[0;34m(graph_def, input_map, return_elements, name, op_dict, producer_op_list)\u001b[0m\n\u001b[1;32m 417\u001b[0m results = c_api.TF_GraphImportGraphDefWithResults(\n\u001b[0;32m--> 418\u001b[0;31m graph._c_graph, serialized, options) # pylint: disable=protected-access\n\u001b[0m\u001b[1;32m 419\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mc_api_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mScopedTFImportGraphDefResults\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mInvalidArgumentError\u001b[0m: Input 0 of node bn/cond/ReadVariableOp/Switch was passed float from bn/gamma:0 incompatible with expected resource.",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-8-774e38dda43d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;31m# We use our \"load_graph\" function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mgraph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_frozen_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfrozen_model_filepath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Frozen graph was loaded succesfully!\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-3-a9e6aeaafa2a>\u001b[0m in \u001b[0;36mload_frozen_graph\u001b[0;34m(frozen_graph_filename)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_graph_def\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph_def\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Student_Jan/virtual_environments/python3_5/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py\u001b[0m in \u001b[0;36mnew_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 430\u001b[0m \u001b[0;34m'in a future version'\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdate\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'after %s'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mdate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 431\u001b[0m instructions)\n\u001b[0;32m--> 432\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 433\u001b[0m return tf_decorator.make_decorator(func, new_func, 'deprecated',\n\u001b[1;32m 434\u001b[0m _add_deprecated_arg_notice_to_docstring(\n",
"\u001b[0;32m~/Student_Jan/virtual_environments/python3_5/lib/python3.5/site-packages/tensorflow/python/framework/importer.py\u001b[0m in \u001b[0;36mimport_graph_def\u001b[0;34m(graph_def, input_map, return_elements, name, op_dict, producer_op_list)\u001b[0m\n\u001b[1;32m 420\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mInvalidArgumentError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 421\u001b[0m \u001b[0;31m# Convert to ValueError for backwards compatibility.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 422\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 423\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 424\u001b[0m \u001b[0;31m# Create _DefinedFunctions for any imported functions.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Input 0 of node bn/cond/ReadVariableOp/Switch was passed float from bn/gamma:0 incompatible with expected resource."
]
}
],
"source": [
"tf.reset_default_graph()\n",
"\n",
"# Load frozen graph\n",
"tf.reset_default_graph()\n",
"frozen_model_filepath = \"/tmp/frozen_model.pb\"\n",
"\n",
"# We use our \"load_frozen_graph\" function\n",
"graph = load_frozen_graph(frozen_model_filepath)\n",
"print(\"Frozen graph was loaded succesfully!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment